├── .bandit
├── .dockerignore
├── .gitignore
├── .travis.yml
├── Dockerfile
├── LICENSE
├── README.md
├── api
│   ├── __init__.py
│   ├── metadata.py
│   └── predict.py
├── app.py
├── config.py
├── core
│   ├── __init__.py
│   ├── getpoint
│   │   ├── attention_decoder.py
│   │   ├── batcher.py
│   │   ├── beam_search.py
│   │   ├── convert.py
│   │   ├── data.py
│   │   ├── decode.py
│   │   ├── inspect_checkpoint.py
│   │   ├── model.py
│   │   ├── run_summarization.py
│   │   └── util.py
│   ├── model.py
│   └── util.py
├── dependabot.yml
├── docs
│   ├── deploy-max-to-ibm-cloud-with-kubernetes-button.png
│   └── swagger-screenshot.png
├── max-text-summarizer.yaml
├── requirements-test.txt
├── requirements.txt
├── samples
│   ├── README.md
│   ├── sample1.json
│   ├── sample2.json
│   └── sample3.json
├── sha512sums.txt
└── tests
    └── test.py
/.bandit: -------------------------------------------------------------------------------- 1 | [bandit] 2 | exclude: /tests,/training 3 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | README.* 2 | __pycache__ 3 | .gitignore 4 | .git/ 5 | tests/ 6 | samples/ 7 | docs/ 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | .idea 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - 3.6 4 | services: 5 | - docker 6 | install: 7 | - docker build -t max-text-summarizer . 8 | - docker run -it -d --rm -p 5000:5000 max-text-summarizer 9 | - pip install -r requirements-test.txt 10 | script: 11 | - flake8 . --max-line-length=127 --exclude=./core/getpoint 12 | - bandit -r .
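  # Wait for the model-serving container started in the install phase to finish booting before the API tests below run.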
13 | - sleep 30 14 | # "python -m" will add current working directory to sys.path. 15 | # See https://docs.pytest.org/en/latest/usage.html#calling-pytest-through-python-m-pytest 16 | - python -m pytest tests/test.py 17 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | FROM quay.io/codait/max-base:v1.1.3 18 | 19 | # Upgrade packages to meet security criteria 20 | RUN apt-get update && apt-get upgrade -y && rm -rf /var/lib/apt/lists/* 21 | 22 | # Fill in these with a link to the bucket containing the model and the model file name 23 | ARG model_bucket=https://max-cdn.cdn.appdomain.cloud/max-text-summarizer/1.0.0 24 | ARG model_file=assets.tar.gz 25 | 26 | RUN useradd --create-home max 27 | RUN chown -R max:max /opt/conda 28 | USER max 29 | WORKDIR /home/max 30 | RUN mkdir assets 31 | 32 | RUN wget -nv --show-progress --progress=bar:force:noscroll ${model_bucket}/${model_file} --output-document=assets/${model_file} && \ 33 | tar -x -C assets/ -f assets/${model_file} -v && rm assets/${model_file} 34 | 35 | COPY requirements.txt . 36 | RUN pip install -r requirements.txt 37 | 38 | COPY . . 39 | 40 | # check file integrity 41 | RUN sha512sum -c sha512sums.txt 42 | 43 | EXPOSE 5000 44 | 45 | CMD python app.py 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.com/IBM/MAX-Text-Summarizer.svg?branch=master)](https://travis-ci.com/IBM/MAX-Text-Summarizer) 2 | [![API demo](https://img.shields.io/website/http/max-text-summarizer.codait-prod-41208c73af8fca213512856c7a09db52-0000.us-east.containers.appdomain.cloud/swagger.json.svg?label=API%20demo&down_message=down&up_message=up)](http://max-text-summarizer.codait-prod-41208c73af8fca213512856c7a09db52-0000.us-east.containers.appdomain.cloud) 3 | 4 | [![Deploy to IBM Cloud with Kubernetes](docs/deploy-max-to-ibm-cloud-with-kubernetes-button.png)](http://ibm.biz/max-to-ibm-cloud-tutorial) 5 | 6 | # IBM Developer Model Asset Exchange: Text Summarizer 7 | 8 | This repository contains code to instantiate and deploy a text summarization model. The model takes a JSON input that encapsulates some text snippets and returns a text summary that represents the key information or message in the input text. 9 | The model was trained on the [CNN / Daily Mail](https://github.com/JafferWilson/Process-Data-of-CNN-DailyMail) dataset. 10 | The model has a vocabulary of approximately 200k words. The model is based on the ACL 2017 paper, [Get To The Point: Summarization with Pointer-Generator Networks](https://arxiv.org/abs/1704.04368). 11 | 12 | The model files are hosted on [IBM Cloud Object Storage](https://max-cdn.cdn.appdomain.cloud/max-text-summarizer/1.0.0/assets.tar.gz). The code in this repository deploys the model as a web service in a Docker container. This repository was developed as part of the [IBM Developer Model Asset Exchange](https://developer.ibm.com/code/exchanges/models/) and the public API is powered by [IBM Cloud](https://ibm.biz/Bdz2XM). 13 | 14 | ## Model Metadata 15 | | Domain | Application | Industry | Framework | Training Data | Input Data Format | 16 | | ------------- | -------- | -------- | --------- | --------- | -------------- | 17 | | NLP | Text Summarization | General | TensorFlow | CNN / Daily Mail | Text | 18 | 19 | 20 | ## References: 21 | 22 | * _A. See, P. J. Liu, C. D. Manning_, [Get To The Point: Summarization with Pointer-Generator Networks](https://arxiv.org/abs/1704.04368), ACL, 2017. 23 | 24 | * [The text summarization repository](https://github.com/becxer/pointer-generator/) 25 | 26 | ## Licenses 27 | 28 | | Component | License | Link | 29 | | ------------- | -------- | -------- | 30 | | This Repository | [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) | [LICENSE](https://github.com/IBM/MAX-Text-Summarizer/blob/master/LICENSE) | 31 | | Third Party Code | [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) | [LICENSE](https://github.com/becxer/pointer-generator/blob/master/LICENSE.txt) | 32 | | Pre-Trained Model Weights | [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) | [LICENSE](https://github.com/becxer/pointer-generator/) | 33 | | Training Data ([CNN / Daily Mail](https://github.com/abisee/cnn-dailymail)) | [MIT](https://opensource.org/licenses/MIT) | [LICENSE](https://github.com/abisee/cnn-dailymail/blob/master/LICENSE.md) | 34 | 35 | ## Pre-requisites: 36 | 37 | * `docker`: The [Docker](https://www.docker.com/) command-line interface. Follow the [installation instructions](https://docs.docker.com/install/) for your system. 38 | * The minimum recommended resources for this model are 4 GB of memory and 4 CPUs. 39 | * If you are on x86-64/AMD64, your CPU must support AVX at a minimum. 40 | 41 | # Steps 42 | 43 | 1. 
[Deploy from Quay](#deploy-from-quay) 44 | 2. [Deploy on Kubernetes](#deploy-on-kubernetes) 45 | 3. [Run Locally](#run-locally) 46 | 47 | ## Deploy from Quay 48 | 49 | To run the Docker image, which automatically starts the model serving API, run: 50 | 51 | ``` 52 | $ docker run -it -p 5000:5000 quay.io/codait/max-text-summarizer 53 | ``` 54 | 55 | This will pull a pre-built image from the Quay.io container registry (or use an existing image if already cached locally) and run it. 56 | If you'd rather check out and build the model locally, you can follow the [run locally](#run-locally) steps below. 57 | 58 | ## Deploy on Kubernetes 59 | 60 | You can also deploy the model on Kubernetes using the latest Docker image on Quay. 61 | 62 | On your Kubernetes cluster, run the following command: 63 | 64 | ``` 65 | $ kubectl apply -f https://github.com/IBM/MAX-Text-Summarizer/raw/master/max-text-summarizer.yaml 66 | ``` 67 | 68 | The model will be available internally at port `5000`, but can also be accessed externally through the `NodePort`. 69 | 70 | A more elaborate tutorial on how to deploy this MAX model to production on [IBM Cloud](https://ibm.biz/Bdz2XM) can be found [here](http://ibm.biz/max-to-ibm-cloud-tutorial). 71 | 72 | ## Run Locally 73 | 74 | 1. [Build the Model](#1-build-the-model) 75 | 2. [Deploy the Model](#2-deploy-the-model) 76 | 3. [Use the Model](#3-use-the-model) 77 | 4. [Development](#4-development) 78 | 5. [Cleanup](#5-cleanup) 79 | 80 | 81 | ### 1. Build the Model 82 | 83 | Clone this repository locally. In a terminal, run the following command: 84 | 85 | ``` 86 | $ git clone https://github.com/IBM/MAX-Text-Summarizer.git 87 | ``` 88 | 89 | Change directory into the repository base folder: 90 | 91 | ``` 92 | $ cd MAX-Text-Summarizer 93 | ``` 94 | 95 | To build the Docker image locally, run: 96 | 97 | ``` 98 | $ docker build -t max-text-summarizer . 99 | ``` 100 | 101 | All required model assets will be downloaded during the build process. _Note_ that currently this Docker image is CPU only (we will add support for GPU images later). 102 | 103 | 104 | ### 2. Deploy the Model 105 | 106 | To run the Docker image, which automatically starts the model serving API, run: 107 | 108 | ``` 109 | $ docker run -it -p 5000:5000 max-text-summarizer 110 | ``` 111 | 112 | ### 3. Use the Model 113 | 114 | The API server automatically generates an interactive Swagger documentation page. Go to `http://localhost:5000` to load it. From there you can explore the API and also create test requests. 115 | 116 | Use the `model/predict` endpoint to post input text (you can use one of the test files from the `samples` folder) and get the predicted summary from the API. 117 | 118 | ![Swagger UI Screenshot](docs/swagger-screenshot.png) 119 | 120 | You can also test it on the command line, for example: 121 | 122 | ```bash 123 | $ curl -d @samples/sample1.json -H "Content-Type: application/json" -XPOST http://localhost:5000/model/predict 124 | ``` 125 | 126 | You should see a JSON response like the one below: 127 | 128 | ```json 129 | { 130 | "status": "ok", 131 | "summary_text": [ 132 | ["nick gordon 's father -lrb- left and right -rrb- gave an interview about the 25-year-old fiance of bobbi kristina brown . it has been reported that gordon , 25 , has threatened suicide and has been taking xanax since . whitney houston 's daughter was found unconscious in a bathtub in january . on wednesday , access hollywood spoke exclusively to gordon 's stepfather about his son 's state of mind . 
"] 133 | ] 134 | } 135 | ``` 136 | 137 | The text summarizer preserves in the output summary text some special characters such as `-lrb-` (representing `(`), `-rrb-` (representing `)`), etc. that appear in the input [sample](samples/sample1.json), which is excerpted from the [Daily Mail](https://github.com/abisee/cnn-dailymail) dataset. 138 | 139 | ### 4. Development 140 | 141 | To run the Flask API app in debug mode, edit `config.py` to set `DEBUG = True` under the application settings. You will then need to rebuild the docker image (see [step 1](#1-build-the-model)). 142 | 143 | ### 5. Cleanup 144 | 145 | To stop the Docker container, type `CTRL` + `C` in your terminal. 146 | 147 | ## Resources and Contributions 148 | 149 | If you are interested in contributing to the Model Asset Exchange project or have any queries, please follow the instructions [here](https://github.com/CODAIT/max-central-repo). 150 | -------------------------------------------------------------------------------- /api/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .metadata import ModelMetadataAPI # noqa 18 | from .predict import ModelPredictAPI # noqa 19 | -------------------------------------------------------------------------------- /api/metadata.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from core.model import ModelWrapper 18 | from maxfw.core import MAX_API, MetadataAPI, METADATA_SCHEMA 19 | 20 | 21 | class ModelMetadataAPI(MetadataAPI): 22 | 23 | @MAX_API.marshal_with(METADATA_SCHEMA) 24 | def get(self): 25 | """Return the metadata associated with the model""" 26 | return ModelWrapper.MODEL_META_DATA 27 | -------------------------------------------------------------------------------- /api/predict.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import logging 18 | from core.model import ModelWrapper 19 | from maxfw.core import MAX_API, PredictAPI 20 | from flask_restplus import fields 21 | 22 | input_parser = MAX_API.model('ModelInput', { 23 | 'text': fields.List(fields.String, required=True, description=( 24 | 'A list of input text to be summarized. ' 25 | 'Each entry in the list is treated as a separate input text and so a summary result will be returned for each entry.')) 26 | }) 27 | 28 | predict_response = MAX_API.model('ModelPredictResponse', { 29 | 'status': fields.String(required=True, description='Response status message.'), 30 | 'summary_text': fields.List(fields.String, required=True, description=( 31 | 'Generated summary text. Each entry in the list is the summary result for the corresponding entry in the input list.')) 32 | }) 33 | 34 | logger = logging.getLogger() 35 | 36 | 37 | class ModelPredictAPI(PredictAPI): 38 | 39 | model_wrapper = ModelWrapper() 40 | 41 | @MAX_API.doc('predict') 42 | @MAX_API.expect(input_parser, validate=True) 43 | @MAX_API.marshal_with(predict_response) 44 | def post(self): 45 | """Make a prediction given input data""" 46 | result = {'status': 'error'} 47 | result['summary_text'] = [] 48 | 49 | input_json = MAX_API.payload 50 | texts = input_json['text'] 51 | for text in texts: 52 | preds = self.model_wrapper.predict(text) 53 | result['summary_text'].append(preds) 54 | 55 | result['status'] = 'ok' 56 | 57 | return result 58 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from maxfw.core import MAXApp 18 | from api import ModelMetadataAPI, ModelPredictAPI 19 | from config import API_TITLE, API_DESC, API_VERSION 20 | 21 | max = MAXApp(API_TITLE, API_DESC, API_VERSION) 22 | max.add_api(ModelMetadataAPI, '/metadata') 23 | max.add_api(ModelPredictAPI, '/predict') 24 | max.run() 25 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import os 18 | import pathlib 19 | 20 | # Flask settings 21 | DEBUG = False 22 | 23 | # Flask-restplus settings 24 | RESTPLUS_MASK_SWAGGER = False 25 | SWAGGER_UI_DOC_EXPANSION = 'none' 26 | 27 | # API metadata 28 | API_TITLE = 'MAX Text Summarizer' 29 | API_DESC = 'Generate a summarized description of a body of text.' 30 | API_VERSION = '1.0.0' 31 | 32 | # default model 33 | MODEL_NAME = 'get_to_the_point' 34 | ASSET_DIR = pathlib.Path('./assets').absolute() 35 | DEFAULT_MODEL_PATH = os.path.join(ASSET_DIR, MODEL_NAME) 36 | DEFAULT_VOCAB_PATH = os.path.join(ASSET_DIR, 'vocab') 37 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/MAX-Text-Summarizer/d775ac8ad2bdc5c7054685ba2e4cdf96ffaf3aae/core/__init__.py -------------------------------------------------------------------------------- /core/getpoint/attention_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """This file defines the decoder""" 18 | 19 | import tensorflow as tf 20 | from tensorflow.python.ops import variable_scope 21 | from tensorflow.python.ops import array_ops 22 | from tensorflow.python.ops import nn_ops 23 | from tensorflow.python.ops import math_ops 24 | 25 | # Note: this function is based on tf.contrib.legacy_seq2seq_attention_decoder, which is now outdated. 26 | # In the future, it would make more sense to write variants on the attention mechanism using the new seq2seq library for tensorflow 1.0: https://www.tensorflow.org/api_guides/python/contrib.seq2seq#Attention 27 | def attention_decoder(decoder_inputs, initial_state, encoder_states, enc_padding_mask, cell, initial_state_attention=False, pointer_gen=True, use_coverage=False, prev_coverage=None): 28 | """ 29 | Args: 30 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 31 | initial_state: 2D Tensor [batch_size x cell.state_size]. 32 | encoder_states: 3D Tensor [batch_size x attn_length x attn_size]. 33 | enc_padding_mask: 2D Tensor [batch_size x attn_length] containing 1s and 0s; indicates which of the encoder locations are padding (0) or a real token (1). 
34 | cell: rnn_cell.RNNCell defining the cell function and size. 35 | initial_state_attention: 36 | Note that this attention decoder passes each decoder input through a linear layer with the previous step's context vector to get a modified version of the input. If initial_state_attention is False, on the first decoder step the "previous context vector" is just a zero vector. If initial_state_attention is True, we use initial_state to (re)calculate the previous step's context vector. We set this to False for train/eval mode (because we call attention_decoder once for all decoder steps) and True for decode mode (because we call attention_decoder once for each decoder step). 37 | pointer_gen: boolean. If True, calculate the generation probability p_gen for each decoder step. 38 | use_coverage: boolean. If True, use coverage mechanism. 39 | prev_coverage: 40 | If not None, a tensor with shape (batch_size, attn_length). The previous step's coverage vector. This is only not None in decode mode when using coverage. 41 | 42 | Returns: 43 | outputs: A list of the same length as decoder_inputs of 2D Tensors of 44 | shape [batch_size x cell.output_size]. The output vectors. 45 | state: The final state of the decoder. A tensor shape [batch_size x cell.state_size]. 46 | attn_dists: A list containing tensors of shape (batch_size,attn_length). 47 | The attention distributions for each decoder step. 48 | p_gens: List of scalars. The values of p_gen for each decoder step. Empty list if pointer_gen=False. 49 | coverage: Coverage vector on the last step computed. None if use_coverage=False. 50 | """ 51 | with variable_scope.variable_scope("attention_decoder") as scope: 52 | batch_size = encoder_states.get_shape()[0].value # if this line fails, it's because the batch size isn't defined 53 | attn_size = encoder_states.get_shape()[2].value # if this line fails, it's because the attention length isn't defined 54 | 55 | # Reshape encoder_states (need to insert a dim) 56 | encoder_states = tf.expand_dims(encoder_states, axis=2) # now is shape (batch_size, attn_len, 1, attn_size) 57 | 58 | # To calculate attention, we calculate 59 | # v^T tanh(W_h h_i + W_s s_t + b_attn) 60 | # where h_i is an encoder state, and s_t a decoder state. 61 | # attn_vec_size is the length of the vectors v, b_attn, (W_h h_i) and (W_s s_t). 62 | # We set it to be equal to the size of the encoder states. 63 | attention_vec_size = attn_size 64 | 65 | # Get the weight matrix W_h and apply it to each encoder state to get (W_h h_i), the encoder features 66 | W_h = variable_scope.get_variable("W_h", [1, 1, attn_size, attention_vec_size]) 67 | encoder_features = nn_ops.conv2d(encoder_states, W_h, [1, 1, 1, 1], "SAME") # shape (batch_size,attn_length,1,attention_vec_size) 68 | 69 | # Get the weight vectors v and w_c (w_c is for coverage) 70 | v = variable_scope.get_variable("v", [attention_vec_size]) 71 | if use_coverage: 72 | with variable_scope.variable_scope("coverage"): 73 | w_c = variable_scope.get_variable("w_c", [1, 1, 1, attention_vec_size]) 74 | 75 | if prev_coverage is not None: # for beam search mode with coverage 76 | # reshape from (batch_size, attn_length) to (batch_size, attn_len, 1, 1) 77 | prev_coverage = tf.expand_dims(tf.expand_dims(prev_coverage,2),3) 78 | 79 | def attention(decoder_state, coverage=None): 80 | """Calculate the context vector and attention distribution from the decoder state. 81 | 82 | Args: 83 | decoder_state: state of the decoder 84 | coverage: Optional. 
Previous timestep's coverage vector, shape (batch_size, attn_len, 1, 1). 85 | 86 | Returns: 87 | context_vector: weighted sum of encoder_states 88 | attn_dist: attention distribution 89 | coverage: new coverage vector. shape (batch_size, attn_len, 1, 1) 90 | """ 91 | with variable_scope.variable_scope("Attention"): 92 | # Pass the decoder state through a linear layer (this is W_s s_t + b_attn in the paper) 93 | decoder_features = linear(decoder_state, attention_vec_size, True) # shape (batch_size, attention_vec_size) 94 | decoder_features = tf.expand_dims(tf.expand_dims(decoder_features, 1), 1) # reshape to (batch_size, 1, 1, attention_vec_size) 95 | 96 | def masked_attention(e): 97 | """Take softmax of e then apply enc_padding_mask and re-normalize""" 98 | attn_dist = nn_ops.softmax(e) # take softmax. shape (batch_size, attn_length) 99 | attn_dist *= enc_padding_mask # apply mask 100 | masked_sums = tf.reduce_sum(attn_dist, axis=1) # shape (batch_size) 101 | return attn_dist / tf.reshape(masked_sums, [-1, 1]) # re-normalize 102 | 103 | if use_coverage and coverage is not None: # non-first step of coverage 104 | # Multiply coverage vector by w_c to get coverage_features. 105 | coverage_features = nn_ops.conv2d(coverage, w_c, [1, 1, 1, 1], "SAME") # c has shape (batch_size, attn_length, 1, attention_vec_size) 106 | 107 | # Calculate v^T tanh(W_h h_i + W_s s_t + w_c c_i^t + b_attn) 108 | e = math_ops.reduce_sum(v * math_ops.tanh(encoder_features + decoder_features + coverage_features), [2, 3]) # shape (batch_size,attn_length) 109 | 110 | # Calculate attention distribution 111 | attn_dist = masked_attention(e) 112 | 113 | # Update coverage vector 114 | coverage += array_ops.reshape(attn_dist, [batch_size, -1, 1, 1]) 115 | else: 116 | # Calculate v^T tanh(W_h h_i + W_s s_t + b_attn) 117 | e = math_ops.reduce_sum(v * math_ops.tanh(encoder_features + decoder_features), [2, 3]) # calculate e 118 | 119 | # Calculate attention distribution 120 | attn_dist = masked_attention(e) 121 | 122 | if use_coverage: # first step of training 123 | coverage = tf.expand_dims(tf.expand_dims(attn_dist,2),2) # initialize coverage 124 | 125 | # Calculate the context vector from attn_dist and encoder_states 126 | context_vector = math_ops.reduce_sum(array_ops.reshape(attn_dist, [batch_size, -1, 1, 1]) * encoder_states, [1, 2]) # shape (batch_size, attn_size). 127 | context_vector = array_ops.reshape(context_vector, [-1, attn_size]) 128 | 129 | return context_vector, attn_dist, coverage 130 | 131 | outputs = [] 132 | attn_dists = [] 133 | p_gens = [] 134 | state = initial_state 135 | coverage = prev_coverage # initialize coverage to None or whatever was passed in 136 | context_vector = array_ops.zeros([batch_size, attn_size]) 137 | context_vector.set_shape([None, attn_size]) # Ensure the second shape of attention vectors is set. 
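    # Decode mode vs. train/eval: when initial_state_attention is True (decode), the call below runs attention on initial_state once so the first decoder input is merged with a real context vector; when False (train/eval), the first step uses the zero context vector initialized above.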
138 | if initial_state_attention: # true in decode mode 139 | # Re-calculate the context vector from the previous step so that we can pass it through a linear layer with this step's input to get a modified version of the input 140 | context_vector, _, coverage = attention(initial_state, coverage) # in decode mode, this is what updates the coverage vector 141 | for i, inp in enumerate(decoder_inputs): 142 | tf.logging.info("Adding attention_decoder timestep %i of %i", i, len(decoder_inputs)) 143 | if i > 0: 144 | variable_scope.get_variable_scope().reuse_variables() 145 | 146 | # Merge input and previous attentions into one vector x of the same size as inp 147 | input_size = inp.get_shape().with_rank(2)[1] 148 | if input_size.value is None: 149 | raise ValueError("Could not infer input size from input: %s" % inp.name) 150 | x = linear([inp] + [context_vector], input_size, True) 151 | 152 | # Run the decoder RNN cell. cell_output = decoder state 153 | cell_output, state = cell(x, state) 154 | 155 | # Run the attention mechanism. 156 | if i == 0 and initial_state_attention: # always true in decode mode 157 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), reuse=True): # you need this because you've already run the initial attention(...) call 158 | context_vector, attn_dist, _ = attention(state, coverage) # don't allow coverage to update 159 | else: 160 | context_vector, attn_dist, coverage = attention(state, coverage) 161 | attn_dists.append(attn_dist) 162 | 163 | # Calculate p_gen 164 | if pointer_gen: 165 | with tf.variable_scope('calculate_pgen'): 166 | p_gen = linear([context_vector, state.c, state.h, x], 1, True) # a scalar 167 | p_gen = tf.sigmoid(p_gen) 168 | p_gens.append(p_gen) 169 | 170 | # Concatenate the cell_output (= decoder state) and the context vector, and pass them through a linear layer 171 | # This is V[s_t, h*_t] + b in the paper 172 | with variable_scope.variable_scope("AttnOutputProjection"): 173 | output = linear([cell_output] + [context_vector], cell.output_size, True) 174 | outputs.append(output) 175 | 176 | # If using coverage, reshape it 177 | if coverage is not None: 178 | coverage = array_ops.reshape(coverage, [batch_size, -1]) 179 | 180 | return outputs, state, attn_dists, p_gens, coverage 181 | 182 | 183 | 184 | def linear(args, output_size, bias, bias_start=0.0, scope=None): 185 | """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. 186 | 187 | Args: 188 | args: a 2D Tensor or a list of 2D, batch x n, Tensors. 189 | output_size: int, second dimension of W[i]. 190 | bias: boolean, whether to add a bias term or not. 191 | bias_start: starting value to initialize the bias; 0 by default. 192 | scope: VariableScope for the created subgraph; defaults to "Linear". 193 | 194 | Returns: 195 | A 2D Tensor with shape [batch x output_size] equal to 196 | sum_i(args[i] * W[i]), where W[i]s are newly created matrices. 197 | 198 | Raises: 199 | ValueError: if some of the arguments has unspecified or wrong shape. 200 | """ 201 | if args is None or (isinstance(args, (list, tuple)) and not args): 202 | raise ValueError("`args` must be specified") 203 | if not isinstance(args, (list, tuple)): 204 | args = [args] 205 | 206 | # Calculate the total size of arguments on dimension 1. 
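  # Illustration (widths assumed, not taken from the call sites): args of shapes [batch, 128] and [batch, 256] give total_arg_size = 384, so the single "Matrix" variable below has shape [384, output_size], and multiplying the dim-1 concatenation of the args by it equals sum_i(args[i] * W[i]).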
207 | total_arg_size = 0 208 | shapes = [a.get_shape().as_list() for a in args] 209 | for shape in shapes: 210 | if len(shape) != 2: 211 | raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes)) 212 | if not shape[1]: 213 | raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes)) 214 | else: 215 | total_arg_size += shape[1] 216 | 217 | # Now the computation. 218 | with tf.variable_scope(scope or "Linear"): 219 | matrix = tf.get_variable("Matrix", [total_arg_size, output_size]) 220 | if len(args) == 1: 221 | res = tf.matmul(args[0], matrix) 222 | else: 223 | res = tf.matmul(tf.concat(axis=1, values=args), matrix) 224 | if not bias: 225 | return res 226 | bias_term = tf.get_variable( 227 | "Bias", [output_size], initializer=tf.constant_initializer(bias_start)) 228 | return res + bias_term 229 | -------------------------------------------------------------------------------- /core/getpoint/batcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """This file contains code to process data into batches""" 18 | 19 | import queue as Queue 20 | from random import shuffle 21 | from threading import Thread 22 | import time 23 | import numpy as np 24 | import tensorflow as tf 25 | import data 26 | 27 | 28 | class Example(object): 29 | """Class representing a train/val/test example for text summarization.""" 30 | 31 | def __init__(self, article, abstract_sentences, vocab, hps): 32 | """Initializes the Example, performing tokenization and truncation to produce the encoder, decoder and target sequences, which are stored in self. 33 | 34 | Args: 35 | article: source text; a string. each token is separated by a single space. 36 | abstract_sentences: list of strings, one per abstract sentence. In each sentence, each token is separated by a single space. 
37 | vocab: Vocabulary object 38 | hps: hyperparameters 39 | """ 40 | self.hps = hps 41 | 42 | # Get ids of special tokens 43 | start_decoding = vocab.word2id(data.START_DECODING) 44 | stop_decoding = vocab.word2id(data.STOP_DECODING) 45 | 46 | # Process the article 47 | article_words = article.split() 48 | if len(article_words) > hps.max_enc_steps: 49 | article_words = article_words[:hps.max_enc_steps] 50 | self.enc_len = len(article_words) # store the length after truncation but before padding 51 | self.enc_input = [vocab.word2id(w) for w in 52 | article_words] # list of word ids; OOVs are represented by the id for UNK token 53 | 54 | # Process the abstract 55 | abstract = ' '.join(abstract_sentences) # string 56 | abstract_words = abstract.split() # list of strings 57 | abs_ids = [vocab.word2id(w) for w in 58 | abstract_words] # list of word ids; OOVs are represented by the id for UNK token 59 | 60 | # Get the decoder input sequence and target sequence 61 | self.dec_input, self.target = self.get_dec_inp_targ_seqs(abs_ids, hps.max_dec_steps, start_decoding, 62 | stop_decoding) 63 | self.dec_len = len(self.dec_input) 64 | 65 | # If using pointer-generator mode, we need to store some extra info 66 | if hps.pointer_gen: 67 | # Store a version of the enc_input where in-article OOVs are represented by their temporary OOV id; also store the in-article OOV words themselves 68 | self.enc_input_extend_vocab, self.article_oovs = data.article2ids(article_words, vocab) 69 | 70 | # Get a version of the reference summary where in-article OOVs are represented by their temporary article OOV id 71 | abs_ids_extend_vocab = data.abstract2ids(abstract_words, vocab, self.article_oovs) 72 | 73 | # Overwrite decoder target sequence so it uses the temp article OOV ids 74 | _, self.target = self.get_dec_inp_targ_seqs(abs_ids_extend_vocab, hps.max_dec_steps, start_decoding, 75 | stop_decoding) 76 | 77 | # Store the original strings 78 | self.original_article = article 79 | self.original_abstract = abstract 80 | self.original_abstract_sents = abstract_sentences 81 | 82 | def get_dec_inp_targ_seqs(self, sequence, max_len, start_id, stop_id): 83 | """Given the reference summary as a sequence of tokens, return the input sequence for the decoder, and the target sequence which we will use to calculate loss. The sequence will be truncated if it is longer than max_len. The input sequence must start with the start_id and the target sequence must end with the stop_id (but not if it's been truncated). 
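Illustrative example (made-up ids): sequence=[8, 9, 10], max_len=4, start_id=1, stop_id=2 yields inp=[1, 8, 9, 10] and target=[8, 9, 10, 2]; with max_len=3, both are truncated to inp=[1, 8, 9] and target=[8, 9, 10], with no stop_id appended.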
84 | 85 | Args: 86 | sequence: List of ids (integers) 87 | max_len: integer 88 | start_id: integer 89 | stop_id: integer 90 | 91 | Returns: 92 | inp: sequence length <=max_len starting with start_id 93 | target: sequence same length as input, ending with stop_id only if there was no truncation 94 | """ 95 | inp = [start_id] + sequence[:] 96 | target = sequence[:] 97 | if len(inp) > max_len: # truncate 98 | inp = inp[:max_len] 99 | target = target[:max_len] # no end_token 100 | else: # no truncation 101 | target.append(stop_id) # end token 102 | if len(inp) != len(target): 103 | raise ValueError("len(inp) != len(target)") 104 | return inp, target 105 | 106 | def pad_decoder_inp_targ(self, max_len, pad_id): 107 | """Pad decoder input and target sequences with pad_id up to max_len.""" 108 | while len(self.dec_input) < max_len: 109 | self.dec_input.append(pad_id) 110 | while len(self.target) < max_len: 111 | self.target.append(pad_id) 112 | 113 | def pad_encoder_input(self, max_len, pad_id): 114 | """Pad the encoder input sequence with pad_id up to max_len.""" 115 | while len(self.enc_input) < max_len: 116 | self.enc_input.append(pad_id) 117 | if self.hps.pointer_gen: 118 | while len(self.enc_input_extend_vocab) < max_len: 119 | self.enc_input_extend_vocab.append(pad_id) 120 | 121 | 122 | class Batch(object): 123 | """Class representing a minibatch of train/val/test examples for text summarization.""" 124 | 125 | def __init__(self, example_list, hps, vocab): 126 | """Turns the example_list into a Batch object. 127 | 128 | Args: 129 | example_list: List of Example objects 130 | hps: hyperparameters 131 | vocab: Vocabulary object 132 | """ 133 | self.pad_id = vocab.word2id(data.PAD_TOKEN) # id of the PAD token used to pad sequences 134 | self.init_encoder_seq(example_list, hps) # initialize the input to the encoder 135 | self.init_decoder_seq(example_list, hps) # initialize the input and targets for the decoder 136 | self.store_orig_strings(example_list) # store the original strings 137 | 138 | def init_encoder_seq(self, example_list, hps): 139 | """Initializes the following: 140 | self.enc_batch: 141 | numpy array of shape (batch_size, <=max_enc_steps) containing integer ids (all OOVs represented by UNK id), padded to length of longest sequence in the batch 142 | self.enc_lens: 143 | numpy array of shape (batch_size) containing integers. The (truncated) length of each encoder input sequence (pre-padding). 144 | self.enc_padding_mask: 145 | numpy array of shape (batch_size, <=max_enc_steps), containing 1s and 0s. 1s correspond to real tokens in enc_batch and target_batch; 0s correspond to padding. 146 | 147 | If hps.pointer_gen, additionally initializes the following: 148 | self.max_art_oovs: 149 | maximum number of in-article OOVs in the batch 150 | self.art_oovs: 151 | list of list of in-article OOVs (strings), for each example in the batch 152 | self.enc_batch_extend_vocab: 153 | Same as self.enc_batch, but in-article OOVs are represented by their temporary article OOV number. 154 | """ 155 | # Determine the maximum length of the encoder input sequence in this batch 156 | max_enc_seq_len = max([ex.enc_len for ex in example_list]) 157 | 158 | # Pad the encoder input sequences up to the length of the longest sequence 159 | for ex in example_list: 160 | ex.pad_encoder_input(max_enc_seq_len, self.pad_id) 161 | 162 | # Initialize the numpy arrays 163 | # Note: our enc_batch can have different length (second dimension) for each batch because we use dynamic_rnn for the encoder. 
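        # Illustration (made-up lengths): for two examples with enc_len 3 and 2 and max_enc_seq_len = 4, the loops below produce enc_padding_mask rows [1, 1, 1, 0] and [1, 1, 0, 0].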
164 | self.enc_batch = np.zeros((hps.batch_size, max_enc_seq_len), dtype=np.int32) 165 | self.enc_lens = np.zeros((hps.batch_size), dtype=np.int32) 166 | self.enc_padding_mask = np.zeros((hps.batch_size, max_enc_seq_len), dtype=np.float32) 167 | 168 | # Fill in the numpy arrays 169 | for i, ex in enumerate(example_list): 170 | self.enc_batch[i, :] = ex.enc_input[:] 171 | self.enc_lens[i] = ex.enc_len 172 | for j in range(ex.enc_len): 173 | self.enc_padding_mask[i][j] = 1 174 | 175 | # For pointer-generator mode, need to store some extra info 176 | if hps.pointer_gen: 177 | # Determine the max number of in-article OOVs in this batch 178 | self.max_art_oovs = max([len(ex.article_oovs) for ex in example_list]) 179 | # Store the in-article OOVs themselves 180 | self.art_oovs = [ex.article_oovs for ex in example_list] 181 | # Store the version of the enc_batch that uses the article OOV ids 182 | self.enc_batch_extend_vocab = np.zeros((hps.batch_size, max_enc_seq_len), dtype=np.int32) 183 | for i, ex in enumerate(example_list): 184 | self.enc_batch_extend_vocab[i, :] = ex.enc_input_extend_vocab[:] 185 | 186 | def init_decoder_seq(self, example_list, hps): 187 | """Initializes the following: 188 | self.dec_batch: 189 | numpy array of shape (batch_size, max_dec_steps), containing integer ids as input for the decoder, padded to max_dec_steps length. 190 | self.target_batch: 191 | numpy array of shape (batch_size, max_dec_steps), containing integer ids for the target sequence, padded to max_dec_steps length. 192 | self.dec_padding_mask: 193 | numpy array of shape (batch_size, max_dec_steps), containing 1s and 0s. 1s correspond to real tokens in dec_batch and target_batch; 0s correspond to padding. 194 | """ 195 | # Pad the inputs and targets 196 | for ex in example_list: 197 | ex.pad_decoder_inp_targ(hps.max_dec_steps, self.pad_id) 198 | 199 | # Initialize the numpy arrays. 200 | # Note: our decoder inputs and targets must be the same length for each batch (second dimension = max_dec_steps) because we do not use a dynamic_rnn for decoding. However I believe this is possible, or will soon be possible, with Tensorflow 1.0, in which case it may be best to upgrade to that. 201 | self.dec_batch = np.zeros((hps.batch_size, hps.max_dec_steps), dtype=np.int32) 202 | self.target_batch = np.zeros((hps.batch_size, hps.max_dec_steps), dtype=np.int32) 203 | self.dec_padding_mask = np.zeros((hps.batch_size, hps.max_dec_steps), dtype=np.float32) 204 | 205 | # Fill in the numpy arrays 206 | for i, ex in enumerate(example_list): 207 | self.dec_batch[i, :] = ex.dec_input[:] 208 | self.target_batch[i, :] = ex.target[:] 209 | for j in range(ex.dec_len): 210 | self.dec_padding_mask[i][j] = 1 211 | 212 | def store_orig_strings(self, example_list): 213 | """Store the original article and abstract strings in the Batch object""" 214 | self.original_articles = [ex.original_article for ex in example_list] # list of strings 215 | self.original_abstracts = [ex.original_abstract for ex in example_list] # list of strings 216 | self.original_abstracts_sents = [ex.original_abstract_sents for ex in example_list] # list of lists of strings 217 | 218 | 219 | class Batcher(object): 220 | """A class to generate minibatches of data. Buckets examples together based on length of the encoder sequence.""" 221 | 222 | BATCH_QUEUE_MAX = 100 # max number of batches the batch_queue can hold 223 | 224 | def __init__(self, data_path, vocab, hps, single_pass): 225 | """Initialize the batcher. Start threads that process the data into batches. 
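Internally, reader threads fill an example queue from the data files, and batching threads group examples (bucketed by encoder sequence length) into a batch queue that next_batch() consumes.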
226 | 227 | Args: 228 | data_path: tf.Example filepattern. 229 | vocab: Vocabulary object 230 | hps: hyperparameters 231 | single_pass: If True, run through the dataset exactly once (useful for when you want to run evaluation on the dev or test set). Otherwise generate random batches indefinitely (useful for training). 232 | """ 233 | self._data_path = data_path 234 | self._vocab = vocab 235 | self._hps = hps 236 | self._single_pass = single_pass 237 | 238 | # Initialize a queue of Batches waiting to be used, and a queue of Examples waiting to be batched 239 | self._batch_queue = Queue.Queue(self.BATCH_QUEUE_MAX) 240 | self._example_queue = Queue.Queue(self.BATCH_QUEUE_MAX * self._hps.batch_size) 241 | 242 | # Different settings depending on whether we're in single_pass mode or not 243 | if single_pass: 244 | self._num_example_q_threads = 1 # just one thread, so we read through the dataset just once 245 | self._num_batch_q_threads = 1 # just one thread to batch examples 246 | self._bucketing_cache_size = 1 # only load one batch's worth of examples before bucketing; this essentially means no bucketing 247 | self._finished_reading = False # this will tell us when we're finished reading the dataset 248 | else: 249 | self._num_example_q_threads = 16 # num threads to fill example queue 250 | self._num_batch_q_threads = 4 # num threads to fill batch queue 251 | self._bucketing_cache_size = 100 # how many batches-worth of examples to load into cache before bucketing 252 | 253 | # Start the threads that load the queues 254 | self._example_q_threads = [] 255 | for _ in range(self._num_example_q_threads): 256 | self._example_q_threads.append(Thread(target=self.fill_example_queue)) 257 | self._example_q_threads[-1].daemon = True 258 | self._example_q_threads[-1].start() 259 | self._batch_q_threads = [] 260 | for _ in range(self._num_batch_q_threads): 261 | self._batch_q_threads.append(Thread(target=self.fill_batch_queue)) 262 | self._batch_q_threads[-1].daemon = True 263 | self._batch_q_threads[-1].start() 264 | 265 | # Start a thread that watches the other threads and restarts them if they're dead 266 | if not single_pass: # We don't want a watcher in single_pass mode because the threads shouldn't run forever 267 | self._watch_thread = Thread(target=self.watch_threads) 268 | self._watch_thread.daemon = True 269 | self._watch_thread.start() 270 | 271 | def next_batch(self): 272 | """Return a Batch from the batch queue. 273 | 274 | If mode='decode' then each batch contains a single example repeated beam_size-many times; this is necessary for beam search. 275 | 276 | Returns: 277 | batch: a Batch object, or None if we're in single_pass mode and we've exhausted the dataset. 278 | """ 279 | # If the batch queue is empty, print a warning 280 | if self._batch_queue.qsize() == 0: 281 | tf.logging.warning( 282 | 'Bucket input queue is empty when calling next_batch. 
Bucket queue size: %i, Input queue size: %i',
283 |                 self._batch_queue.qsize(), self._example_queue.qsize())
284 |             if self._single_pass and self._finished_reading:
285 |                 tf.logging.info("Finished reading dataset in single_pass mode.")
286 |                 return None
287 | 
288 |         batch = self._batch_queue.get() # get the next Batch
289 |         return batch
290 | 
291 |     def fill_example_queue(self):
292 |         """Reads data from file and processes into Examples which are then placed into the example queue."""
293 | 
294 |         input_gen = self.text_generator(data.example_generator(self._data_path, self._single_pass))
295 | 
296 |         while True:
297 |             try:
298 |                 # (article, abstract) = next(input_gen) # read the next example from file. article and abstract are both strings.
299 |                 article = next(input_gen) # read the next example from file. article is a string (this fork no longer reads an abstract).
300 |             except StopIteration: # if there are no more examples:
301 |                 tf.logging.info("The example generator for this example queue filling thread has exhausted data.")
302 |                 if self._single_pass:
303 |                     tf.logging.info(
304 |                         "single_pass mode is on, so we've finished reading dataset. This thread is stopping.")
305 |                     self._finished_reading = True
306 |                     break
307 |                 else:
308 |                     raise Exception("single_pass mode is off but the example generator is out of data; error.")
309 | 
310 |             # abstract_sentences = [sent.strip() for sent in data.abstract2sents(abstract)] # Use the <s> and </s> tags in abstract to get a list of sentences.
311 |             example = Example(article, article, self._vocab, self._hps) # Process into an Example; the article is also passed in place of the unused abstract.
312 |             self._example_queue.put(example) # place the Example in the example queue.
313 | 
314 |     def fill_batch_queue(self):
315 |         """Takes Examples out of example queue, sorts them by encoder sequence length, processes into Batches and places them in the batch queue.
316 | 
317 |         In decode mode, makes batches that each contain a single example repeated.
318 |         """
319 |         while True:
320 |             if self._hps.mode != 'decode':
321 |                 # Get bucketing_cache_size-many batches of Examples into a list, then sort
322 |                 inputs = []
323 |                 for _ in range(self._hps.batch_size * self._bucketing_cache_size):
324 |                     inputs.append(self._example_queue.get())
325 |                 inputs = sorted(inputs, key=lambda inp: inp.enc_len) # sort by length of encoder sequence
326 | 
327 |                 # Group the sorted Examples into batches, optionally shuffle the batches, and place in the batch queue.
328 |                 batches = []
329 |                 for i in range(0, len(inputs), self._hps.batch_size):
330 |                     batches.append(inputs[i:i + self._hps.batch_size])
331 |                 if not self._single_pass:
332 |                     shuffle(batches)
333 |                 for b in batches: # each b is a list of Example objects
334 |                     self._batch_queue.put(Batch(b, self._hps, self._vocab))
335 | 
336 |             else: # beam search decode mode
337 |                 ex = self._example_queue.get()
338 |                 b = [ex for _ in range(self._hps.batch_size)]
339 |                 self._batch_queue.put(Batch(b, self._hps, self._vocab))
340 | 
341 |     def watch_threads(self):
342 |         """Watch example queue and batch queue threads and restart if dead."""
343 |         while True:
344 |             time.sleep(60)
345 |             for idx, t in enumerate(self._example_q_threads):
346 |                 if not t.is_alive(): # if the thread is dead
347 |                     tf.logging.error('Found example queue thread dead. Restarting.')
348 |                     new_t = Thread(target=self.fill_example_queue)
349 |                     self._example_q_threads[idx] = new_t
350 |                     new_t.daemon = True
351 |                     new_t.start()
352 |             for idx, t in enumerate(self._batch_q_threads):
353 |                 if not t.is_alive(): # if the thread is dead
354 |                     tf.logging.error('Found batch queue thread dead. 
Restarting.')
355 |                     new_t = Thread(target=self.fill_batch_queue)
356 |                     self._batch_q_threads[idx] = new_t
357 |                     new_t.daemon = True
358 |                     new_t.start()
359 | 
360 |     def text_generator(self, example_generator):
361 |         """Generates article text from tf.Example (the abstract is no longer read here).
362 | 
363 |         Args:
364 |             example_generator: a generator of tf.Examples from file. See data.example_generator"""
365 |         while True:
366 |             e = next(example_generator) # e is a tf.Example
367 |             try:
368 |                 article_text = e.features.feature['article'].bytes_list.value[
369 |                     0].decode() # the article text was saved under the key 'article' in the data files
370 |                 # abstract_text = e.features.feature['abstract'].bytes_list.value[0].decode() # the abstract text was saved under the key 'abstract' in the data files
371 |             except (ValueError, IndexError): # a missing 'article' feature surfaces as IndexError from value[0]
372 |                 tf.logging.error('Failed to get article or abstract from example')
373 |                 continue
374 |             if len(article_text) == 0: # See https://github.com/abisee/pointer-generator/issues/1
375 |                 tf.logging.warning('Found an example with empty article text. Skipping it.')
376 |             else:
377 |                 # yield (article_text, abstract_text)
378 |                 yield article_text
379 | 
--------------------------------------------------------------------------------
/core/getpoint/beam_search.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | # 
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | # 
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | # 
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | 
17 | """This file contains code to run beam search decoding"""
18 | 
19 | import tensorflow as tf
20 | import numpy as np
21 | import data
22 | 
23 | FLAGS = tf.app.flags.FLAGS
24 | 
25 | class Hypothesis(object):
26 |     """Class to represent a hypothesis during beam search. Holds all the information needed for the hypothesis."""
27 | 
28 |     def __init__(self, tokens, log_probs, state, attn_dists, p_gens, coverage):
29 |         """Hypothesis constructor.
30 | 
31 |         Args:
32 |             tokens: List of integers. The ids of the tokens that form the summary so far.
33 |             log_probs: List, same length as tokens, of floats, giving the log probabilities of the tokens so far.
34 |             state: Current state of the decoder, a LSTMStateTuple.
35 |             attn_dists: List, same length as tokens, of numpy arrays with shape (attn_length). These are the attention distributions so far.
36 |             p_gens: List, same length as tokens, of floats, or None if not using pointer-generator model. The values of the generation probability so far.
37 |             coverage: Numpy array of shape (attn_length), or None if not using coverage. The current coverage vector. 
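
        Example (illustrative only; mirrors how run_beam_search below seeds the beam,
        where attn_length is the encoder input length):
            h = Hypothesis(tokens=[vocab.word2id(data.START_DECODING)], log_probs=[0.0],
                           state=dec_in_state, attn_dists=[], p_gens=[],
                           coverage=np.zeros([attn_length]))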
38 | """ 39 | self.tokens = tokens 40 | self.log_probs = log_probs 41 | self.state = state 42 | self.attn_dists = attn_dists 43 | self.p_gens = p_gens 44 | self.coverage = coverage 45 | 46 | def extend(self, token, log_prob, state, attn_dist, p_gen, coverage): 47 | """Return a NEW hypothesis, extended with the information from the latest step of beam search. 48 | 49 | Args: 50 | token: Integer. Latest token produced by beam search. 51 | log_prob: Float. Log prob of the latest token. 52 | state: Current decoder state, a LSTMStateTuple. 53 | attn_dist: Attention distribution from latest step. Numpy array shape (attn_length). 54 | p_gen: Generation probability on latest step. Float. 55 | coverage: Latest coverage vector. Numpy array shape (attn_length), or None if not using coverage. 56 | Returns: 57 | New Hypothesis for next step. 58 | """ 59 | return Hypothesis(tokens = self.tokens + [token], 60 | log_probs = self.log_probs + [log_prob], 61 | state = state, 62 | attn_dists = self.attn_dists + [attn_dist], 63 | p_gens = self.p_gens + [p_gen], 64 | coverage = coverage) 65 | 66 | @property 67 | def latest_token(self): 68 | return self.tokens[-1] 69 | 70 | @property 71 | def log_prob(self): 72 | # the log probability of the hypothesis so far is the sum of the log probabilities of the tokens so far 73 | return sum(self.log_probs) 74 | 75 | @property 76 | def avg_log_prob(self): 77 | # normalize log probability by number of tokens (otherwise longer sequences always have lower probability) 78 | return self.log_prob / len(self.tokens) 79 | 80 | 81 | def run_beam_search(sess, model, vocab, batch): 82 | """Performs beam search decoding on the given example. 83 | 84 | Args: 85 | sess: a tf.Session 86 | model: a seq2seq model 87 | vocab: Vocabulary object 88 | batch: Batch object that is the same example repeated across the batch 89 | 90 | Returns: 91 | best_hyp: Hypothesis object; the best hypothesis found by beam search. 92 | """ 93 | # Run the encoder to get the encoder hidden states and decoder initial state 94 | enc_states, dec_in_state = model.run_encoder(sess, batch) 95 | # dec_in_state is a LSTMStateTuple 96 | # enc_states has shape [batch_size, <=max_enc_steps, 2*hidden_dim]. 
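    # A toy illustration of the scoring used below (hypothetical numbers): sort_hyps ranks
    # by avg_log_prob rather than raw sum, so a 2-token hypothesis with log-probs
    # [-0.5, -0.5] (avg -0.5) ranks below a 4-token one with [-0.4, -0.4, -0.4, -0.4]
    # (avg -0.4), even though its raw sum -1.0 beats the longer hypothesis's -1.6.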
97 | 
98 |     # Initialize beam_size-many hypotheses
99 |     hyps = [Hypothesis(tokens=[vocab.word2id(data.START_DECODING)],
100 |                        log_probs=[0.0],
101 |                        state=dec_in_state,
102 |                        attn_dists=[],
103 |                        p_gens=[],
104 |                        coverage=np.zeros([batch.enc_batch.shape[1]]) # zero vector of length attention_length
105 |                        ) for _ in range(FLAGS.beam_size)]
106 |     results = [] # this will contain finished hypotheses (those that have emitted the [STOP] token)
107 | 
108 |     steps = 0
109 |     while steps < FLAGS.max_dec_steps and len(results) < FLAGS.beam_size:
110 |         latest_tokens = [h.latest_token for h in hyps] # latest token produced by each hypothesis
111 |         latest_tokens = [t if t in range(vocab.size()) else vocab.word2id(data.UNKNOWN_TOKEN) for t in latest_tokens] # change any in-article temporary OOV ids to [UNK] id, so that we can look up word embeddings
112 |         states = [h.state for h in hyps] # list of current decoder states of the hypotheses
113 |         prev_coverage = [h.coverage for h in hyps] # list of coverage vectors (or None)
114 | 
115 |         # Run one step of the decoder to get the new info
116 |         (topk_ids, topk_log_probs, new_states, attn_dists, p_gens, new_coverage) = model.decode_onestep(sess=sess,
117 |                         batch=batch,
118 |                         latest_tokens=latest_tokens,
119 |                         enc_states=enc_states,
120 |                         dec_init_states=states,
121 |                         prev_coverage=prev_coverage)
122 | 
123 |         # Extend each hypothesis and collect them all in all_hyps
124 |         all_hyps = []
125 |         num_orig_hyps = 1 if steps == 0 else len(hyps) # On the first step, we only had one original hypothesis (the initial hypothesis). On subsequent steps, all original hypotheses are distinct.
126 |         for i in range(num_orig_hyps):
127 |             h, new_state, attn_dist, p_gen, new_coverage_i = hyps[i], new_states[i], attn_dists[i], p_gens[i], new_coverage[i] # take the ith hypothesis and new decoder state info
128 |             for j in range(FLAGS.beam_size * 2): # for each of the top 2*beam_size hyps:
129 |                 # Extend the ith hypothesis with the jth option
130 |                 new_hyp = h.extend(token=topk_ids[i, j],
131 |                                    log_prob=topk_log_probs[i, j],
132 |                                    state=new_state,
133 |                                    attn_dist=attn_dist,
134 |                                    p_gen=p_gen,
135 |                                    coverage=new_coverage_i)
136 |                 all_hyps.append(new_hyp)
137 | 
138 |         # Filter and collect any hypotheses that have produced the end token.
139 |         hyps = [] # will contain hypotheses for the next step
140 |         for h in sort_hyps(all_hyps): # in order of most likely h
141 |             if h.latest_token == vocab.word2id(data.STOP_DECODING): # if stop token is reached...
142 |                 # If this hypothesis is sufficiently long, put in results. Otherwise discard.
143 |                 if steps >= FLAGS.min_dec_steps:
144 |                     results.append(h)
145 |             else: # hasn't reached stop token, so continue to extend this hypothesis
146 |                 hyps.append(h)
147 |             if len(hyps) == FLAGS.beam_size or len(results) == FLAGS.beam_size:
148 |                 # Once we've collected beam_size-many hypotheses for the next step, or beam_size-many complete hypotheses, stop. 
149 | break 150 | 151 | steps += 1 152 | 153 | # At this point, either we've got beam_size results, or we've reached maximum decoder steps 154 | 155 | if len(results)==0: # if we don't have any complete results, add all current hypotheses (incomplete summaries) to results 156 | results = hyps 157 | 158 | # Sort hypotheses by average log probability 159 | hyps_sorted = sort_hyps(results) 160 | 161 | # Return the hypothesis with highest average log prob 162 | return hyps_sorted[0] 163 | 164 | def sort_hyps(hyps): 165 | """Return a list of Hypothesis objects, sorted by descending average log probability""" 166 | return sorted(hyps, key=lambda h: h.avg_log_prob, reverse=True) 167 | -------------------------------------------------------------------------------- /core/getpoint/convert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import struct 5 | import tensorflow as tf 6 | from tensorflow.core.example import example_pb2 7 | 8 | 9 | END_TOKENS = frozenset(['.', '!', '?', '...', "'", "`", '"', ")"]) # acceptable ways to end a sentence 10 | 11 | def fix_missing_period(line): 12 | """Adds a period to a line that is missing a period""" 13 | if "@highlight" in line: return line 14 | if line=="": return line 15 | if line[-1] in END_TOKENS: return line 16 | return line + " ." 17 | 18 | def get_art_abs(input_string): 19 | lines = input_string.splitlines() 20 | lines = [line.lower() for line in lines] 21 | lines = [fix_missing_period(line) for line in lines] 22 | 23 | article_lines = [] 24 | highlights = [] 25 | next_is_highlight = False 26 | for idx,line in enumerate(lines): 27 | if line == "": 28 | continue # no line 29 | elif line.startswith("@highlight"): 30 | next_is_highlight = True 31 | elif next_is_highlight: 32 | highlights.append(line) 33 | else: 34 | article_lines.append(line) 35 | 36 | # To a string 37 | article = ' '.join(article_lines) 38 | 39 | return article 40 | 41 | def convert_to_bin(input_string, out_file): 42 | 43 | with open(out_file, 'wb') as writer: 44 | # start to write .bin file 45 | article = get_art_abs(input_string) 46 | 47 | article=tf.compat.as_bytes(article, encoding='utf-8') 48 | # tf.Example write 49 | tf_example = example_pb2.Example() 50 | tf_example.features.feature['article'].bytes_list.value.extend([article]) 51 | tf_example_str = tf_example.SerializeToString() 52 | str_len = len(tf_example_str) 53 | writer.write(struct.pack('q', str_len)) 54 | writer.write(struct.pack('%ds' % str_len, tf_example_str)) 55 | -------------------------------------------------------------------------------- /core/getpoint/data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ==============================================================================
16 | 
17 | """This file contains code to read the train/eval/test data from file and process it, and read the vocab data from file and process it"""
18 | 
19 | import glob
20 | import random
21 | import struct
22 | import csv
23 | from tensorflow.core.example import example_pb2
24 | 
25 | # <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids.
26 | SENTENCE_START = '<s>'
27 | SENTENCE_END = '</s>'
28 | 
29 | PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence # nosec - B105:hardcoded_password_string
30 | UNKNOWN_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words # nosec - B105:hardcoded_password_string
31 | START_DECODING = '[START]' # This has a vocab id, which is used at the start of every decoder input sequence
32 | STOP_DECODING = '[STOP]' # This has a vocab id, which is used at the end of untruncated target sequences
33 | 
34 | 
35 | # Note: none of <s>, </s>, [PAD], [UNK], [START], [STOP] should appear in the vocab file.
36 | 
37 | 
38 | class Vocab(object):
39 |     """Vocabulary class for mapping between words and ids (integers)"""
40 | 
41 |     def __init__(self, vocab_file, max_size):
42 |         """Creates a vocab of up to max_size words, reading from the vocab_file. If max_size is 0, reads the entire vocab file.
43 | 
44 |         Args:
45 |             vocab_file: path to the vocab file, which is assumed to contain "<word> <frequency>" on each line, sorted with most frequent word first. This code doesn't actually use the frequencies, though.
46 |             max_size: integer. The maximum size of the resulting Vocabulary."""
47 |         self._word_to_id = {}
48 |         self._id_to_word = {}
49 |         self._count = 0 # keeps track of total number of words in the Vocab
50 | 
51 |         # [UNK], [PAD], [START] and [STOP] get the ids 0,1,2,3.
52 |         for w in [UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
53 |             self._word_to_id[w] = self._count
54 |             self._id_to_word[self._count] = w
55 |             self._count += 1
56 | 
57 |         # Read the vocab file and add words up to max_size
58 |         with open(vocab_file, 'r') as vocab_f:
59 |             for line in vocab_f:
60 |                 pieces = line.split()
61 |                 if len(pieces) != 2:
62 |                     # print('Warning: incorrectly formatted line in vocabulary file: %s\n' % line)
63 |                     continue
64 |                 w = pieces[0]
65 |                 if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
66 |                     raise Exception(
67 |                         '<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w)
68 |                 if w in self._word_to_id:
69 |                     raise Exception('Duplicated word in vocabulary file: %s' % w)
70 |                 self._word_to_id[w] = self._count
71 |                 self._id_to_word[self._count] = w
72 |                 self._count += 1
73 |                 if max_size != 0 and self._count >= max_size:
74 |                     # print("max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count))
75 |                     break
76 | 
77 |         # print("Finished constructing vocabulary of %i total words. Last word added: %s" % (self._count, self._id_to_word[self._count-1]))
78 | 
79 |     def word2id(self, word):
80 |         """Returns the id (integer) of a word (string). 
Returns [UNK] id if word is OOV."""
81 |         if word not in self._word_to_id:
82 |             return self._word_to_id[UNKNOWN_TOKEN]
83 |         return self._word_to_id[word]
84 | 
85 |     def id2word(self, word_id):
86 |         """Returns the word (string) corresponding to an id (integer)."""
87 |         if word_id not in self._id_to_word:
88 |             raise ValueError('Id not found in vocab: %d' % word_id)
89 |         return self._id_to_word[word_id]
90 | 
91 |     def size(self):
92 |         """Returns the total size of the vocabulary"""
93 |         return self._count
94 | 
95 |     def write_metadata(self, fpath):
96 |         """Writes metadata file for Tensorboard word embedding visualizer as described here:
97 |         https://www.tensorflow.org/get_started/embedding_viz
98 | 
99 |         Args:
100 |             fpath: place to write the metadata file
101 |         """
102 |         # print("Writing word embedding metadata file to %s..." % (fpath))
103 |         with open(fpath, "w") as f:
104 |             fieldnames = ['word']
105 |             writer = csv.DictWriter(f, delimiter="\t", fieldnames=fieldnames)
106 |             for i in range(self.size()):
107 |                 writer.writerow({"word": self._id_to_word[i]})
108 | 
109 | 
110 | def example_generator(data_path, single_pass):
111 |     """Generates tf.Examples from data files.
112 | 
113 |     Binary data format: <length><blob>. <length> represents the byte size
114 |     of <blob>. <blob> is serialized tf.Example proto. The tf.Example contains
115 |     the tokenized article text and summary.
116 | 
117 |     Args:
118 |         data_path:
119 |             Path to tf.Example data files. Can include wildcards, e.g. if you have several training data chunk files train_001.bin, train_002.bin, etc, then pass data_path=train_* to access them all.
120 |         single_pass:
121 |             Boolean. If True, go through the dataset exactly once, generating examples in the order they appear, then return. Otherwise, generate random examples indefinitely.
122 | 
123 |     Yields:
124 |         Deserialized tf.Example.
125 |     """
126 |     while True:
127 |         filelist = glob.glob(data_path) # get the list of datafiles
128 |         if not filelist:
129 |             raise ValueError('Error: Empty filelist at %s' % data_path) # check filelist isn't empty
130 |         if single_pass:
131 |             filelist = sorted(filelist)
132 |         else:
133 |             random.shuffle(filelist)
134 |         for f in filelist:
135 |             reader = open(f, 'rb')
136 |             while True:
137 |                 len_bytes = reader.read(8)
138 |                 if not len_bytes: break # finished reading this file
139 |                 str_len = struct.unpack('q', len_bytes)[0]
140 |                 example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
141 |                 yield example_pb2.Example.FromString(example_str)
142 |         if single_pass:
143 |             # print("example_generator completed reading all datafiles. No more data.")
144 |             break
145 | 
146 | 
147 | def article2ids(article_words, vocab):
148 |     """Map the article words to their ids. Also return a list of OOVs in the article.
149 | 
150 |     Args:
151 |         article_words: list of words (strings)
152 |         vocab: Vocabulary object
153 | 
154 |     Returns:
155 |         ids:
156 |             A list of word ids (integers); OOVs are represented by their temporary article OOV number. If the vocabulary size is 50k and the article has 3 OOVs, then these temporary OOV numbers will be 50000, 50001, 50002.
157 |         oovs:
158 |             A list of the OOV words in the article (strings), in the order corresponding to their temporary article OOV numbers."""
159 |     ids = []
160 |     oovs = []
161 |     unk_id = vocab.word2id(UNKNOWN_TOKEN)
162 |     for w in article_words:
163 |         i = vocab.word2id(w)
164 |         if i == unk_id: # If w is OOV
165 |             if w not in oovs: # Add to list of OOVs
166 |                 oovs.append(w)
167 |             oov_num = oovs.index(w) # This is 0 for the first article OOV, 1 for the second article OOV... 
168 |             ids.append(vocab.size() + oov_num) # This is e.g. 50000 for the first article OOV, 50001 for the second...
169 |         else:
170 |             ids.append(i)
171 |     return ids, oovs
172 | 
173 | 
174 | def abstract2ids(abstract_words, vocab, article_oovs):
175 |     """Map the abstract words to their ids. In-article OOVs are mapped to their temporary OOV numbers.
176 | 
177 |     Args:
178 |         abstract_words: list of words (strings)
179 |         vocab: Vocabulary object
180 |         article_oovs: list of in-article OOV words (strings), in the order corresponding to their temporary article OOV numbers
181 | 
182 |     Returns:
183 |         ids: List of ids (integers). In-article OOV words are mapped to their temporary OOV numbers. Out-of-article OOV words are mapped to the UNK token id."""
184 |     ids = []
185 |     unk_id = vocab.word2id(UNKNOWN_TOKEN)
186 |     for w in abstract_words:
187 |         i = vocab.word2id(w)
188 |         if i == unk_id: # If w is an OOV word
189 |             if w in article_oovs: # If w is an in-article OOV
190 |                 vocab_idx = vocab.size() + article_oovs.index(w) # Map to its temporary article OOV number
191 |                 ids.append(vocab_idx)
192 |             else: # If w is an out-of-article OOV
193 |                 ids.append(unk_id) # Map to the UNK token id
194 |         else:
195 |             ids.append(i)
196 |     return ids
197 | 
198 | 
199 | def outputids2words(id_list, vocab, article_oovs):
200 |     """Maps output ids to words, including mapping in-article OOVs from their temporary ids to the original OOV string (applicable in pointer-generator mode).
201 | 
202 |     Args:
203 |         id_list: list of ids (integers)
204 |         vocab: Vocabulary object
205 |         article_oovs: list of OOV words (strings) in the order corresponding to their temporary article OOV ids (that have been assigned in pointer-generator mode), or None (in baseline mode)
206 | 
207 |     Returns:
208 |         words: list of words (strings)
209 |     """
210 |     words = []
211 |     for i in id_list:
212 |         try:
213 |             w = vocab.id2word(i) # might be [UNK]
214 |         except ValueError: # w is OOV
215 |             if article_oovs is None:
216 |                 raise ValueError(
217 |                     "Error: model produced a word ID that isn't in the vocabulary. This should not happen in "
218 |                     "baseline (no pointer-generator) mode")
219 |             article_oov_idx = i - vocab.size()
220 |             try:
221 |                 w = article_oovs[article_oov_idx]
222 |             except IndexError: # list indexing raises IndexError (not ValueError) when i doesn't correspond to an article oov
223 |                 raise ValueError(
224 |                     'Error: model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (
225 |                         i, article_oov_idx, len(article_oovs)))
226 |         words.append(w)
227 |     return words
228 | 
229 | 
230 | def abstract2sents(abstract):
231 |     """Splits abstract text from datafile into list of sentences. 
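    Example (illustrative): abstract2sents('<s>first sent .</s><s>second sent .</s>')
    returns ['first sent .', 'second sent .'].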
232 | 
233 |     Args:
234 |         abstract: string containing <s> and </s> tags for starts and ends of sentences
235 | 
236 |     Returns:
237 |         sents: List of sentence strings (no tags)"""
238 |     cur = 0
239 |     sents = []
240 |     while True:
241 |         try:
242 |             start_p = abstract.index(SENTENCE_START, cur)
243 |             end_p = abstract.index(SENTENCE_END, start_p + 1)
244 |             cur = end_p + len(SENTENCE_END)
245 |             sents.append(abstract[start_p + len(SENTENCE_START):end_p])
246 |         except ValueError: # no more sentences
247 |             return sents
248 | 
249 | 
250 | def show_art_oovs(article, vocab):
251 |     """Returns the article string, highlighting the OOVs by placing __underscores__ around them"""
252 |     unk_token = vocab.word2id(UNKNOWN_TOKEN)
253 |     words = article.split(' ')
254 |     words = [("__%s__" % w) if vocab.word2id(w) == unk_token else w for w in words]
255 |     out_str = ' '.join(words)
256 |     return out_str
257 | 
258 | 
259 | def show_abs_oovs(abstract, vocab, article_oovs):
260 |     """Returns the abstract string, highlighting the article OOVs with __underscores__.
261 | 
262 |     If a list of article_oovs is provided, non-article OOVs are differentiated like !!__this__!!.
263 | 
264 |     Args:
265 |         abstract: string
266 |         vocab: Vocabulary object
267 |         article_oovs: list of words (strings), or None (in baseline mode)
268 |     """
269 |     unk_token = vocab.word2id(UNKNOWN_TOKEN)
270 |     words = abstract.split(' ')
271 |     new_words = []
272 |     for w in words:
273 |         if vocab.word2id(w) == unk_token: # w is oov
274 |             if article_oovs is None: # baseline mode
275 |                 new_words.append("__%s__" % w)
276 |             else: # pointer-generator mode
277 |                 if w in article_oovs:
278 |                     new_words.append("__%s__" % w)
279 |                 else:
280 |                     new_words.append("!!__%s__!!" % w)
281 |         else: # w is in-vocab word
282 |             new_words.append(w)
283 |     out_str = ' '.join(new_words)
284 |     return out_str
285 | 
--------------------------------------------------------------------------------
/core/getpoint/decode.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | # 
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | # 
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | # 
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | """This file contains code to run beam search decoding, including running ROUGE evaluation and producing JSON datafiles for the in-browser attention visualizer, which can be found here https://github.com/abisee/attn_vis""" 18 | 19 | import os 20 | import sys 21 | import time 22 | import tensorflow as tf 23 | import beam_search 24 | import data 25 | import json 26 | import pyrouge 27 | import util 28 | import logging 29 | import numpy as np 30 | 31 | FLAGS = tf.app.flags.FLAGS 32 | 33 | SECS_UNTIL_NEW_CKPT = 60 # max number of seconds before loading new checkpoint 34 | 35 | 36 | class BeamSearchDecoder(object): 37 | """Beam search decoder.""" 38 | 39 | def __init__(self, model, batcher, vocab): 40 | """Initialize decoder. 41 | 42 | Args: 43 | model: a Seq2SeqAttentionModel object. 44 | batcher: a Batcher object. 45 | vocab: Vocabulary object 46 | """ 47 | self._model = model 48 | self._model.build_graph() 49 | self._batcher = batcher 50 | self._vocab = vocab 51 | self._saver = tf.train.Saver() # we use this to load checkpoints for decoding 52 | self._sess = tf.Session(config=util.get_config()) 53 | 54 | # Load an initial checkpoint to use for decoding 55 | ckpt_path = util.load_ckpt(self._saver, self._sess) 56 | 57 | 58 | # if FLAGS.single_pass: 59 | # # Make a descriptive decode directory name 60 | # ckpt_name = "ckpt-" + ckpt_path.split('-')[-1] # this is something of the form "ckpt-123456" 61 | # self._decode_dir = os.path.join(FLAGS.log_root, get_decode_dir_name(ckpt_name)) 62 | # if os.path.exists(self._decode_dir): 63 | # raise Exception("single_pass decode directory %s should not already exist" % self._decode_dir) 64 | # 65 | # else: # Generic decode dir name 66 | self._decode_dir = os.path.join(FLAGS.log_root, "decode") 67 | 68 | # Make the decode dir if necessary 69 | if not os.path.exists(self._decode_dir): os.mkdir(self._decode_dir) 70 | 71 | # if FLAGS.single_pass: 72 | # # Make the dirs to contain output written in the correct format for pyrouge 73 | # self._rouge_ref_dir = os.path.join(self._decode_dir, "reference") 74 | # if not os.path.exists(self._rouge_ref_dir): os.mkdir(self._rouge_ref_dir) 75 | # self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded") 76 | # if not os.path.exists(self._rouge_dec_dir): os.mkdir(self._rouge_dec_dir) 77 | 78 | 79 | def decode(self): 80 | """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals""" 81 | # t0 = time.time() 82 | batch = self._batcher.next_batch() # 1 example repeated across batch 83 | 84 | original_article = batch.original_articles[0] # string 85 | original_abstract = batch.original_abstracts[0] # string 86 | 87 | # input data 88 | article_withunks = data.show_art_oovs(original_article, self._vocab) # string 89 | abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string 90 | 91 | # Run beam search to get best Hypothesis 92 | best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch) 93 | 94 | # Extract the output ids from the hypothesis and convert back to words 95 | output_ids = [int(t) for t in best_hyp.tokens[1:]] 96 | decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) 97 | 98 | # Remove the [STOP] token from decoded_words, if necessary 99 | try: 100 | fst_stop_idx = 
decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
101 |             decoded_words = decoded_words[:fst_stop_idx]
102 |         except ValueError:
103 |             pass # no [STOP] token was produced; keep the full decoded output
104 |         decoded_output = ' '.join(decoded_words) # single string
105 | 
106 |         # tf.logging.info('ARTICLE: %s', article)
107 |         # tf.logging.info('GENERATED SUMMARY: %s', decoded_output)
108 | 
109 |         sys.stdout.write(decoded_output)
110 | 
--------------------------------------------------------------------------------
/core/getpoint/inspect_checkpoint.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple script that checks if a checkpoint is corrupted with any inf/NaN values. Run like this:
3 |     python inspect_checkpoint.py model.12345
4 | """
5 | 
6 | import tensorflow as tf
7 | import sys
8 | import numpy as np
9 | 
10 | 
11 | if __name__ == '__main__':
12 |     if len(sys.argv) != 2:
13 |         raise Exception("Usage: python inspect_checkpoint.py <file_name>\nNote: Do not include the .data .index or .meta part of the model checkpoint in file_name.")
14 |     file_name = sys.argv[1]
15 |     reader = tf.train.NewCheckpointReader(file_name)
16 |     var_to_shape_map = reader.get_variable_to_shape_map()
17 | 
18 |     finite = []
19 |     all_infnan = []
20 |     some_infnan = []
21 | 
22 |     for key in sorted(var_to_shape_map.keys()):
23 |         tensor = reader.get_tensor(key)
24 |         if np.all(np.isfinite(tensor)):
25 |             finite.append(key)
26 |         else:
27 |             if not np.any(np.isfinite(tensor)):
28 |                 all_infnan.append(key)
29 |             else:
30 |                 some_infnan.append(key)
31 | 
32 |     print("\nFINITE VARIABLES:")
33 |     for key in finite: print(key)
34 | 
35 |     print("\nVARIABLES THAT ARE ALL INF/NAN:")
36 |     for key in all_infnan: print(key)
37 | 
38 |     print("\nVARIABLES THAT CONTAIN SOME FINITE, SOME INF/NAN VALUES:")
39 |     for key in some_infnan: print(key)
40 | 
41 |     if not all_infnan and not some_infnan:
42 |         print("CHECK PASSED: checkpoint contains no inf/NaN values")
43 |     else:
44 |         print("CHECK FAILED: checkpoint contains some inf/NaN values")
45 | 
--------------------------------------------------------------------------------
/core/getpoint/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | # 
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | # 
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | # 
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | 
17 | """This file contains code to build and run the tensorflow graph for the sequence-to-sequence model"""
18 | 
19 | import os
20 | import time
21 | import numpy as np
22 | import tensorflow as tf
23 | from attention_decoder import attention_decoder
24 | from tensorflow.contrib.tensorboard.plugins import projector
25 | 
26 | FLAGS = tf.app.flags.FLAGS
27 | 
28 | 
29 | class SummarizationModel(object):
30 |     """A class to represent a sequence-to-sequence model for text summarization. 
Supports baseline mode, pointer-generator mode, and coverage"""
31 | 
32 |     def __init__(self, hps, vocab):
33 |         self._hps = hps
34 |         self._vocab = vocab
35 | 
36 |     def _add_placeholders(self):
37 |         """Add placeholders to the graph. These are entry points for any input data."""
38 |         hps = self._hps
39 | 
40 |         # encoder part
41 |         self._enc_batch = tf.placeholder(tf.int32, [hps.batch_size, None], name='enc_batch')
42 |         self._enc_lens = tf.placeholder(tf.int32, [hps.batch_size], name='enc_lens')
43 |         self._enc_padding_mask = tf.placeholder(tf.float32, [hps.batch_size, None], name='enc_padding_mask')
44 |         if FLAGS.pointer_gen:
45 |             self._enc_batch_extend_vocab = tf.placeholder(tf.int32, [hps.batch_size, None],
46 |                                                           name='enc_batch_extend_vocab')
47 |             self._max_art_oovs = tf.placeholder(tf.int32, [], name='max_art_oovs')
48 | 
49 |         # decoder part
50 |         self._dec_batch = tf.placeholder(tf.int32, [hps.batch_size, hps.max_dec_steps], name='dec_batch')
51 |         self._target_batch = tf.placeholder(tf.int32, [hps.batch_size, hps.max_dec_steps], name='target_batch')
52 |         self._dec_padding_mask = tf.placeholder(tf.float32, [hps.batch_size, hps.max_dec_steps],
53 |                                                 name='dec_padding_mask')
54 | 
55 |         if hps.mode == "decode" and hps.coverage:
56 |             self.prev_coverage = tf.placeholder(tf.float32, [hps.batch_size, None], name='prev_coverage')
57 | 
58 |     def _make_feed_dict(self, batch, just_enc=False):
59 |         """Make a feed dictionary mapping parts of the batch to the appropriate placeholders.
60 | 
61 |         Args:
62 |             batch: Batch object
63 |             just_enc: Boolean. If True, only feed the parts needed for the encoder.
64 |         """
65 |         feed_dict = {}
66 |         feed_dict[self._enc_batch] = batch.enc_batch
67 |         feed_dict[self._enc_lens] = batch.enc_lens
68 |         feed_dict[self._enc_padding_mask] = batch.enc_padding_mask
69 |         if FLAGS.pointer_gen:
70 |             feed_dict[self._enc_batch_extend_vocab] = batch.enc_batch_extend_vocab
71 |             feed_dict[self._max_art_oovs] = batch.max_art_oovs
72 |         if not just_enc:
73 |             feed_dict[self._dec_batch] = batch.dec_batch
74 |             feed_dict[self._target_batch] = batch.target_batch
75 |             feed_dict[self._dec_padding_mask] = batch.dec_padding_mask
76 |         return feed_dict
77 | 
78 |     def _add_encoder(self, encoder_inputs, seq_len):
79 |         """Add a single-layer bidirectional LSTM encoder to the graph.
80 | 
81 |         Args:
82 |             encoder_inputs: A tensor of shape [batch_size, <=max_enc_steps, emb_size].
83 |             seq_len: Lengths of encoder_inputs (before padding). A tensor of shape [batch_size].
84 | 
85 |         Returns:
86 |             encoder_outputs:
87 |                 A tensor of shape [batch_size, <=max_enc_steps, 2*hidden_dim]. It's 2*hidden_dim because it's the concatenation of the forwards and backwards states. 
88 | fw_state, bw_state: 89 | Each are LSTMStateTuples of shape ([batch_size,hidden_dim],[batch_size,hidden_dim]) 90 | """ 91 | with tf.variable_scope('encoder'): 92 | cell_fw = tf.contrib.rnn.LSTMCell(self._hps.hidden_dim, initializer=self.rand_unif_init, 93 | state_is_tuple=True) 94 | cell_bw = tf.contrib.rnn.LSTMCell(self._hps.hidden_dim, initializer=self.rand_unif_init, 95 | state_is_tuple=True) 96 | (encoder_outputs, (fw_st, bw_st)) = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, encoder_inputs, 97 | dtype=tf.float32, 98 | sequence_length=seq_len, 99 | swap_memory=True) 100 | encoder_outputs = tf.concat(axis=2, values=encoder_outputs) # concatenate the forwards and backwards states 101 | return encoder_outputs, fw_st, bw_st 102 | 103 | def _reduce_states(self, fw_st, bw_st): 104 | """Add to the graph a linear layer to reduce the encoder's final FW and BW state into a single initial state for the decoder. This is needed because the encoder is bidirectional but the decoder is not. 105 | 106 | Args: 107 | fw_st: LSTMStateTuple with hidden_dim units. 108 | bw_st: LSTMStateTuple with hidden_dim units. 109 | 110 | Returns: 111 | state: LSTMStateTuple with hidden_dim units. 112 | """ 113 | hidden_dim = self._hps.hidden_dim 114 | with tf.variable_scope('reduce_final_st'): 115 | # Define weights and biases to reduce the cell and reduce the state 116 | w_reduce_c = tf.get_variable('w_reduce_c', [hidden_dim * 2, hidden_dim], dtype=tf.float32, 117 | initializer=self.trunc_norm_init) 118 | w_reduce_h = tf.get_variable('w_reduce_h', [hidden_dim * 2, hidden_dim], dtype=tf.float32, 119 | initializer=self.trunc_norm_init) 120 | bias_reduce_c = tf.get_variable('bias_reduce_c', [hidden_dim], dtype=tf.float32, 121 | initializer=self.trunc_norm_init) 122 | bias_reduce_h = tf.get_variable('bias_reduce_h', [hidden_dim], dtype=tf.float32, 123 | initializer=self.trunc_norm_init) 124 | 125 | # Apply linear layer 126 | old_c = tf.concat(axis=1, values=[fw_st.c, bw_st.c]) # Concatenation of fw and bw cell 127 | old_h = tf.concat(axis=1, values=[fw_st.h, bw_st.h]) # Concatenation of fw and bw state 128 | new_c = tf.nn.relu(tf.matmul(old_c, w_reduce_c) + bias_reduce_c) # Get new cell from old cell 129 | new_h = tf.nn.relu(tf.matmul(old_h, w_reduce_h) + bias_reduce_h) # Get new state from old state 130 | return tf.contrib.rnn.LSTMStateTuple(new_c, new_h) # Return new cell and state 131 | 132 | def _add_decoder(self, inputs): 133 | """Add attention decoder to the graph. In train or eval mode, you call this once to get output on ALL steps. In decode (beam search) mode, you call this once for EACH decoder step. 134 | 135 | Args: 136 | inputs: inputs to the decoder (word embeddings). 
A list of tensors shape (batch_size, emb_dim) 137 | 138 | Returns: 139 | outputs: List of tensors; the outputs of the decoder 140 | out_state: The final state of the decoder 141 | attn_dists: A list of tensors; the attention distributions 142 | p_gens: A list of scalar tensors; the generation probabilities 143 | coverage: A tensor, the current coverage vector 144 | """ 145 | hps = self._hps 146 | cell = tf.contrib.rnn.LSTMCell(hps.hidden_dim, state_is_tuple=True, initializer=self.rand_unif_init) 147 | 148 | prev_coverage = self.prev_coverage if hps.mode == "decode" and hps.coverage else None # In decode mode, we run attention_decoder one step at a time and so need to pass in the previous step's coverage vector each time 149 | 150 | outputs, out_state, attn_dists, p_gens, coverage = attention_decoder(inputs, self._dec_in_state, 151 | self._enc_states, self._enc_padding_mask, 152 | cell, initial_state_attention=( 153 | hps.mode == "decode"), pointer_gen=hps.pointer_gen, use_coverage=hps.coverage, 154 | prev_coverage=prev_coverage) 155 | 156 | return outputs, out_state, attn_dists, p_gens, coverage 157 | 158 | def _calc_final_dist(self, vocab_dists, attn_dists): 159 | """Calculate the final distribution, for the pointer-generator model 160 | 161 | Args: 162 | vocab_dists: The vocabulary distributions. List length max_dec_steps of (batch_size, vsize) arrays. The words are in the order they appear in the vocabulary file. 163 | attn_dists: The attention distributions. List length max_dec_steps of (batch_size, attn_len) arrays 164 | 165 | Returns: 166 | final_dists: The final distributions. List length max_dec_steps of (batch_size, extended_vsize) arrays. 167 | """ 168 | with tf.variable_scope('final_distribution'): 169 | # Multiply vocab dists by p_gen and attention dists by (1-p_gen) 170 | vocab_dists = [p_gen * dist for (p_gen, dist) in zip(self.p_gens, vocab_dists)] 171 | attn_dists = [(1 - p_gen) * dist for (p_gen, dist) in zip(self.p_gens, attn_dists)] 172 | 173 | # Concatenate some zeros to each vocabulary dist, to hold the probabilities for in-article OOV words 174 | extended_vsize = self._vocab.size() + self._max_art_oovs # the maximum (over the batch) size of the extended vocabulary 175 | extra_zeros = tf.zeros((self._hps.batch_size, self._max_art_oovs)) 176 | vocab_dists_extended = [tf.concat(axis=1, values=[dist, extra_zeros]) for dist in 177 | vocab_dists] # list length max_dec_steps of shape (batch_size, extended_vsize) 178 | 179 | # Project the values in the attention distributions onto the appropriate entries in the final distributions 180 | # This means that if a_i = 0.1 and the ith encoder word is w, and w has index 500 in the vocabulary, then we add 0.1 onto the 500th entry of the final distribution 181 | # This is done for each decoder timestep. 
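            # A toy illustration (hypothetical numbers): with batch_size=1, extended_vsize=6,
            # enc_batch_extend_vocab=[[4, 5, 4]] and an attention dist of [[0.5, 0.3, 0.2]],
            # the projection yields [[0, 0, 0, 0, 0.7, 0.3]] -- tf.scatter_nd sums the mass
            # of duplicate source positions that point at the same id (here id 4).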
182 | # This is fiddly; we use tf.scatter_nd to do the projection 183 | batch_nums = tf.range(0, limit=self._hps.batch_size) # shape (batch_size) 184 | batch_nums = tf.expand_dims(batch_nums, 1) # shape (batch_size, 1) 185 | attn_len = tf.shape(self._enc_batch_extend_vocab)[1] # number of states we attend over 186 | batch_nums = tf.tile(batch_nums, [1, attn_len]) # shape (batch_size, attn_len) 187 | indices = tf.stack((batch_nums, self._enc_batch_extend_vocab), axis=2) # shape (batch_size, enc_t, 2) 188 | shape = [self._hps.batch_size, extended_vsize] 189 | attn_dists_projected = [tf.scatter_nd(indices, copy_dist, shape) for copy_dist in 190 | attn_dists] # list length max_dec_steps (batch_size, extended_vsize) 191 | 192 | # Add the vocab distributions and the copy distributions together to get the final distributions 193 | # final_dists is a list length max_dec_steps; each entry is a tensor shape (batch_size, extended_vsize) giving the final distribution for that decoder timestep 194 | # Note that for decoder timesteps and examples corresponding to a [PAD] token, this is junk - ignore. 195 | final_dists = [vocab_dist + copy_dist for (vocab_dist, copy_dist) in 196 | zip(vocab_dists_extended, attn_dists_projected)] 197 | 198 | return final_dists 199 | 200 | def _add_emb_vis(self, embedding_var): 201 | """Do setup so that we can view word embedding visualization in Tensorboard, as described here: 202 | https://www.tensorflow.org/get_started/embedding_viz 203 | Make the vocab metadata file, then make the projector config file pointing to it.""" 204 | train_dir = os.path.join(FLAGS.log_root, "train") 205 | vocab_metadata_path = os.path.join(train_dir, "vocab_metadata.tsv") 206 | self._vocab.write_metadata(vocab_metadata_path) # write metadata file 207 | summary_writer = tf.summary.FileWriter(train_dir) 208 | config = projector.ProjectorConfig() 209 | embedding = config.embeddings.add() 210 | embedding.tensor_name = embedding_var.name 211 | embedding.metadata_path = vocab_metadata_path 212 | projector.visualize_embeddings(summary_writer, config) 213 | 214 | def _add_seq2seq(self): 215 | """Add the whole sequence-to-sequence model to the graph.""" 216 | hps = self._hps 217 | vsize = self._vocab.size() # size of the vocabulary 218 | 219 | with tf.variable_scope('seq2seq'): 220 | # Some initializers 221 | self.rand_unif_init = tf.random_uniform_initializer(-hps.rand_unif_init_mag, hps.rand_unif_init_mag, 222 | seed=123) 223 | self.trunc_norm_init = tf.truncated_normal_initializer(stddev=hps.trunc_norm_init_std) 224 | 225 | # Add embedding matrix (shared by the encoder and decoder inputs) 226 | with tf.variable_scope('embedding'): 227 | embedding = tf.get_variable('embedding', [vsize, hps.emb_dim], dtype=tf.float32, 228 | initializer=self.trunc_norm_init) 229 | if hps.mode == "train": self._add_emb_vis(embedding) # add to tensorboard 230 | emb_enc_inputs = tf.nn.embedding_lookup(embedding, 231 | self._enc_batch) # tensor with shape (batch_size, max_enc_steps, emb_size) 232 | emb_dec_inputs = [tf.nn.embedding_lookup(embedding, x) for x in tf.unstack(self._dec_batch, 233 | axis=1)] # list length max_dec_steps containing shape (batch_size, emb_size) 234 | 235 | # Add the encoder. 
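            # (A shape sketch, per the _add_encoder docstring: emb_enc_inputs is
            # (batch_size, <=max_enc_steps, emb_dim), and enc_outputs comes back as
            # (batch_size, <=max_enc_steps, 2*hidden_dim) with fw/bw states concatenated.)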
236 | enc_outputs, fw_st, bw_st = self._add_encoder(emb_enc_inputs, self._enc_lens) 237 | self._enc_states = enc_outputs 238 | 239 | # Our encoder is bidirectional and our decoder is unidirectional so we need to reduce the final encoder hidden state to the right size to be the initial decoder hidden state 240 | self._dec_in_state = self._reduce_states(fw_st, bw_st) 241 | 242 | # Add the decoder. 243 | with tf.variable_scope('decoder'): 244 | decoder_outputs, self._dec_out_state, self.attn_dists, self.p_gens, self.coverage = self._add_decoder( 245 | emb_dec_inputs) 246 | 247 | # Add the output projection to obtain the vocabulary distribution 248 | with tf.variable_scope('output_projection'): 249 | w = tf.get_variable('w', [hps.hidden_dim, vsize], dtype=tf.float32, initializer=self.trunc_norm_init) 250 | w_t = tf.transpose(w) 251 | v = tf.get_variable('v', [vsize], dtype=tf.float32, initializer=self.trunc_norm_init) 252 | vocab_scores = [] # vocab_scores is the vocabulary distribution before applying softmax. Each entry on the list corresponds to one decoder step 253 | for i, output in enumerate(decoder_outputs): 254 | if i > 0: 255 | tf.get_variable_scope().reuse_variables() 256 | vocab_scores.append(tf.nn.xw_plus_b(output, w, v)) # apply the linear layer 257 | 258 | vocab_dists = [tf.nn.softmax(s) for s in 259 | vocab_scores] # The vocabulary distributions. List length max_dec_steps of (batch_size, vsize) arrays. The words are in the order they appear in the vocabulary file. 260 | 261 | # For pointer-generator model, calc final distribution from copy distribution and vocabulary distribution 262 | if FLAGS.pointer_gen: 263 | final_dists = self._calc_final_dist(vocab_dists, self.attn_dists) 264 | else: # final distribution is just vocabulary distribution 265 | final_dists = vocab_dists 266 | 267 | if hps.mode in ['train', 'eval']: 268 | # Calculate the loss 269 | with tf.variable_scope('loss'): 270 | if FLAGS.pointer_gen: 271 | # Calculate the loss per step 272 | # This is fiddly; we use tf.gather_nd to pick out the probabilities of the gold target words 273 | loss_per_step = [] # will be list length max_dec_steps containing shape (batch_size) 274 | batch_nums = tf.range(0, limit=hps.batch_size) # shape (batch_size) 275 | for dec_step, dist in enumerate(final_dists): 276 | targets = self._target_batch[:, 277 | dec_step] # The indices of the target words. shape (batch_size) 278 | indices = tf.stack((batch_nums, targets), axis=1) # shape (batch_size, 2) 279 | gold_probs = tf.gather_nd(dist, 280 | indices) # shape (batch_size). 
prob of correct words on this step
281 |                         losses = -tf.log(gold_probs)
282 |                         loss_per_step.append(losses)
283 | 
284 |                     # Apply dec_padding_mask and get loss
285 |                     self._loss = _mask_and_avg(loss_per_step, self._dec_padding_mask)
286 | 
287 |                 else: # baseline model
288 |                     self._loss = tf.contrib.seq2seq.sequence_loss(tf.stack(vocab_scores, axis=1),
289 |                                                                   self._target_batch,
290 |                                                                   self._dec_padding_mask) # this applies softmax internally
291 | 
292 |                 tf.summary.scalar('loss', self._loss)
293 | 
294 |                 # Calculate coverage loss from the attention distributions
295 |                 if hps.coverage:
296 |                     with tf.variable_scope('coverage_loss'):
297 |                         self._coverage_loss = _coverage_loss(self.attn_dists, self._dec_padding_mask)
298 |                         tf.summary.scalar('coverage_loss', self._coverage_loss)
299 |                     self._total_loss = self._loss + hps.cov_loss_wt * self._coverage_loss
300 |                     tf.summary.scalar('total_loss', self._total_loss)
301 | 
302 |             if hps.mode == "decode":
303 |                 # We run decode beam search mode one decoder step at a time
304 |                 if len(final_dists) != 1: # final_dists is a singleton list containing shape (batch_size, extended_vsize)
305 |                     raise ValueError("final_dists should be a singleton list containing a (batch_size, extended_vsize) tensor")
306 |                 final_dists = final_dists[0]
307 |                 topk_probs, self._topk_ids = tf.nn.top_k(final_dists,
308 |                                                          hps.batch_size * 2) # take the k largest probs. note batch_size=beam_size in decode mode
309 |                 self._topk_log_probs = tf.log(topk_probs)
310 | 
311 |     def _add_train_op(self):
312 |         """Sets self._train_op, the op to run for training."""
313 |         # Take gradients of the trainable variables w.r.t. the loss function to minimize
314 |         loss_to_minimize = self._total_loss if self._hps.coverage else self._loss
315 |         tvars = tf.trainable_variables()
316 |         gradients = tf.gradients(loss_to_minimize, tvars, aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE)
317 | 
318 |         # Clip the gradients
319 |         with tf.device("/gpu:0"):
320 |             grads, global_norm = tf.clip_by_global_norm(gradients, self._hps.max_grad_norm)
321 | 
322 |         # Add a summary
323 |         tf.summary.scalar('global_norm', global_norm)
324 | 
325 |         # Apply adagrad optimizer
326 |         optimizer = tf.train.AdagradOptimizer(self._hps.lr, initial_accumulator_value=self._hps.adagrad_init_acc)
327 |         with tf.device("/gpu:0"):
328 |             self._train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step,
329 |                                                        name='train_step')
330 | 
331 |     def build_graph(self):
332 |         """Add the placeholders, model, global step, train_op and summaries to the graph"""
333 |         tf.logging.info('Building graph...')
334 |         t0 = time.time()
335 |         self._add_placeholders()
336 |         with tf.device("/gpu:0"):
337 |             self._add_seq2seq()
338 |         self.global_step = tf.Variable(0, name='global_step', trainable=False)
339 |         if self._hps.mode == 'train':
340 |             self._add_train_op()
341 |         self._summaries = tf.summary.merge_all()
342 |         t1 = time.time()
343 |         tf.logging.info('Time to build graph: %i seconds', t1 - t0)
344 | 
345 |     def run_train_step(self, sess, batch):
346 |         """Runs one training iteration. 
Returns a dictionary containing train op, summaries, loss, global_step and (optionally) coverage loss.""" 347 | feed_dict = self._make_feed_dict(batch) 348 | to_return = { 349 | 'train_op': self._train_op, 350 | 'summaries': self._summaries, 351 | 'loss': self._loss, 352 | 'global_step': self.global_step, 353 | } 354 | if self._hps.coverage: 355 | to_return['coverage_loss'] = self._coverage_loss 356 | return sess.run(to_return, feed_dict) 357 | 358 | def run_eval_step(self, sess, batch): 359 | """Runs one evaluation iteration. Returns a dictionary containing summaries, loss, global_step and (optionally) coverage loss.""" 360 | feed_dict = self._make_feed_dict(batch) 361 | to_return = { 362 | 'summaries': self._summaries, 363 | 'loss': self._loss, 364 | 'global_step': self.global_step, 365 | } 366 | if self._hps.coverage: 367 | to_return['coverage_loss'] = self._coverage_loss 368 | return sess.run(to_return, feed_dict) 369 | 370 | def run_encoder(self, sess, batch): 371 | """For beam search decoding. Run the encoder on the batch and return the encoder states and decoder initial state. 372 | 373 | Args: 374 | sess: Tensorflow session. 375 | batch: Batch object that is the same example repeated across the batch (for beam search) 376 | 377 | Returns: 378 | enc_states: The encoder states. A tensor of shape [batch_size, <=max_enc_steps, 2*hidden_dim]. 379 | dec_in_state: A LSTMStateTuple of shape ([1,hidden_dim],[1,hidden_dim]) 380 | """ 381 | feed_dict = self._make_feed_dict(batch, just_enc=True) # feed the batch into the placeholders 382 | (enc_states, dec_in_state, global_step) = sess.run([self._enc_states, self._dec_in_state, self.global_step], 383 | feed_dict) # run the encoder 384 | 385 | # dec_in_state is LSTMStateTuple shape ([batch_size,hidden_dim],[batch_size,hidden_dim]) 386 | # Given that the batch is a single example repeated, dec_in_state is identical across the batch so we just take the top row. 387 | dec_in_state = tf.contrib.rnn.LSTMStateTuple(dec_in_state.c[0], dec_in_state.h[0]) 388 | return enc_states, dec_in_state 389 | 390 | def decode_onestep(self, sess, batch, latest_tokens, enc_states, dec_init_states, prev_coverage): 391 | """For beam search decoding. Run the decoder for one step. 392 | 393 | Args: 394 | sess: Tensorflow session. 395 | batch: Batch object containing single example repeated across the batch 396 | latest_tokens: Tokens to be fed as input into the decoder for this timestep 397 | enc_states: The encoder states. 398 | dec_init_states: List of beam_size LSTMStateTuples; the decoder states from the previous timestep 399 | prev_coverage: List of np arrays. The coverage vectors from the previous timestep. List of None if not using coverage. 400 | 401 | Returns: 402 | ids: top 2k ids. shape [beam_size, 2*beam_size] 403 | probs: top 2k log probabilities. shape [beam_size, 2*beam_size] 404 | new_states: new states of the decoder. a list length beam_size containing 405 | LSTMStateTuples each of shape ([hidden_dim,],[hidden_dim,]) 406 | attn_dists: List length beam_size containing lists length attn_length. 407 | p_gens: Generation probabilities for this step. A list length beam_size. List of None if in baseline mode. 408 | new_coverage: Coverage vectors for this step. A list of arrays. List of None if coverage is not turned on. 
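
        Example (illustrative only; matches the call made in beam_search.run_beam_search):
            (ids, probs, new_states, attn_dists, p_gens, new_coverage) = model.decode_onestep(
                sess=sess, batch=batch, latest_tokens=latest_tokens, enc_states=enc_states,
                dec_init_states=states, prev_coverage=prev_coverage)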
409 | """ 410 | 411 | beam_size = len(dec_init_states) 412 | 413 | # Turn dec_init_states (a list of LSTMStateTuples) into a single LSTMStateTuple for the batch 414 | cells = [np.expand_dims(state.c, axis=0) for state in dec_init_states] 415 | hiddens = [np.expand_dims(state.h, axis=0) for state in dec_init_states] 416 | new_c = np.concatenate(cells, axis=0) # shape [batch_size,hidden_dim] 417 | new_h = np.concatenate(hiddens, axis=0) # shape [batch_size,hidden_dim] 418 | new_dec_in_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h) 419 | 420 | feed = { 421 | self._enc_states: enc_states, 422 | self._enc_padding_mask: batch.enc_padding_mask, 423 | self._dec_in_state: new_dec_in_state, 424 | self._dec_batch: np.transpose(np.array([latest_tokens])), 425 | } 426 | 427 | to_return = { 428 | "ids": self._topk_ids, 429 | "probs": self._topk_log_probs, 430 | "states": self._dec_out_state, 431 | "attn_dists": self.attn_dists 432 | } 433 | 434 | if FLAGS.pointer_gen: 435 | feed[self._enc_batch_extend_vocab] = batch.enc_batch_extend_vocab 436 | feed[self._max_art_oovs] = batch.max_art_oovs 437 | to_return['p_gens'] = self.p_gens 438 | 439 | if self._hps.coverage: 440 | feed[self.prev_coverage] = np.stack(prev_coverage, axis=0) 441 | to_return['coverage'] = self.coverage 442 | 443 | results = sess.run(to_return, feed_dict=feed) # run the decoder step 444 | 445 | # Convert results['states'] (a single LSTMStateTuple) into a list of LSTMStateTuple -- one for each hypothesis 446 | new_states = [tf.contrib.rnn.LSTMStateTuple(results['states'].c[i, :], results['states'].h[i, :]) for i in 447 | range(beam_size)] 448 | 449 | # Convert singleton list containing a tensor to a list of k arrays 450 | if len(results['attn_dists']) != 1: 451 | raise ValueError("len(results['attn_dists']) != 1") 452 | attn_dists = results['attn_dists'][0].tolist() 453 | 454 | if FLAGS.pointer_gen: 455 | # Convert singleton list containing a tensor to a list of k arrays 456 | if len(results['p_gens']) != 1: 457 | raise ValueError("len(results['p_gens']) != 1") 458 | p_gens = results['p_gens'][0].tolist() 459 | else: 460 | p_gens = [None for _ in range(beam_size)] 461 | 462 | # Convert the coverage tensor to a list length k containing the coverage vector for each hypothesis 463 | if FLAGS.coverage: 464 | new_coverage = results['coverage'].tolist() 465 | if len(new_coverage) != beam_size: 466 | raise ValueError("len(new_coverage) != beam_size") 467 | else: 468 | new_coverage = [None for _ in range(beam_size)] 469 | 470 | return results['ids'], results['probs'], new_states, attn_dists, p_gens, new_coverage 471 | 472 | 473 | def _mask_and_avg(values, padding_mask): 474 | """Applies mask to values then returns overall average (a scalar) 475 | 476 | Args: 477 | values: a list length max_dec_steps containing arrays shape (batch_size). 478 | padding_mask: tensor shape (batch_size, max_dec_steps) containing 1s and 0s. 479 | 480 | Returns: 481 | a scalar 482 | """ 483 | 484 | dec_lens = tf.reduce_sum(padding_mask, axis=1) # shape batch_size. float32 485 | values_per_step = [v * padding_mask[:, dec_step] for dec_step, v in enumerate(values)] 486 | values_per_ex = sum(values_per_step) / dec_lens # shape (batch_size); normalized value for each batch member 487 | return tf.reduce_mean(values_per_ex) # overall average 488 | 489 | 490 | def _coverage_loss(attn_dists, padding_mask): 491 | """Calculates the coverage loss from the attention distributions. 492 | 493 | Args: 494 | attn_dists: The attention distributions for each decoder timestep. 
A list length max_dec_steps containing shape (batch_size, attn_length) 495 | padding_mask: shape (batch_size, max_dec_steps). 496 | 497 | Returns: 498 | coverage_loss: scalar 499 | """ 500 | coverage = tf.zeros_like(attn_dists[0]) # shape (batch_size, attn_length). Initial coverage is zero. 501 | covlosses = [] # Coverage loss per decoder timestep. Will be list length max_dec_steps containing shape (batch_size). 502 | for a in attn_dists: 503 | covloss = tf.reduce_sum(tf.minimum(a, coverage), [1]) # calculate the coverage loss for this step 504 | covlosses.append(covloss) 505 | coverage += a # update the coverage vector 506 | coverage_loss = _mask_and_avg(covlosses, padding_mask) 507 | return coverage_loss 508 | -------------------------------------------------------------------------------- /core/getpoint/run_summarization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """This is the top-level file to train, evaluate or test your summarization model""" 18 | 19 | import sys 20 | import time 21 | import os 22 | import tensorflow as tf 23 | import numpy as np 24 | from collections import namedtuple 25 | from data import Vocab 26 | from batcher import Batcher 27 | from model import SummarizationModel 28 | from decode import BeamSearchDecoder 29 | import util 30 | from tensorflow.python import debug as tf_debug 31 | 32 | FLAGS = tf.app.flags.FLAGS 33 | 34 | # Where to find data 35 | tf.app.flags.DEFINE_string('data_path', '', 36 | 'Path expression to tf.Example datafiles. Can include wildcards to access multiple datafiles.') 37 | tf.app.flags.DEFINE_string('ckpt_dir', '', 'Directory that contains the ckpt files.') 38 | tf.app.flags.DEFINE_string('vocab_path', '', 'Path expression to text vocabulary file.') 39 | 40 | # Important settings 41 | tf.app.flags.DEFINE_string('mode', 'train', 'must be one of train/eval/decode') 42 | tf.app.flags.DEFINE_boolean('single_pass', False, 43 | 'For decode mode only. If True, run eval on the full dataset using a fixed checkpoint, i.e. take the current checkpoint, and use it to produce one summary for each example in the dataset, write the summaries to file and then get ROUGE scores for the whole dataset. If False (default), run concurrent decoding, i.e. repeatedly load latest checkpoint, use it to produce summaries for randomly-chosen examples and log the results to screen, indefinitely.') 44 | 45 | # Where to save output 46 | tf.app.flags.DEFINE_string('log_root', '', 'Root directory for all logging.') 47 | tf.app.flags.DEFINE_string('exp_name', '', 48 | 'Name for experiment. 
Logs will be saved in a directory with this name, under log_root.')
49 |
50 | # Hyperparameters
51 | tf.app.flags.DEFINE_integer('hidden_dim', 256, 'dimension of RNN hidden states')
52 | tf.app.flags.DEFINE_integer('emb_dim', 128, 'dimension of word embeddings')
53 | tf.app.flags.DEFINE_integer('batch_size', 16, 'minibatch size')
54 | tf.app.flags.DEFINE_integer('max_enc_steps', 400, 'max timesteps of encoder (max source text tokens)')
55 | tf.app.flags.DEFINE_integer('max_dec_steps', 100, 'max timesteps of decoder (max summary tokens)')
56 | tf.app.flags.DEFINE_integer('beam_size', 4, 'beam size for beam search decoding.')
57 | tf.app.flags.DEFINE_integer('min_dec_steps', 35,
58 | 'Minimum sequence length of generated summary. Applies only for beam search decoding mode')
59 | tf.app.flags.DEFINE_integer('vocab_size', 50000,
60 | 'Size of vocabulary. These will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, will take all words in the vocabulary file.')
61 | tf.app.flags.DEFINE_float('lr', 0.15, 'learning rate')
62 | tf.app.flags.DEFINE_float('adagrad_init_acc', 0.1, 'initial accumulator value for Adagrad')
63 | tf.app.flags.DEFINE_float('rand_unif_init_mag', 0.02, 'magnitude for lstm cells random uniform initialization')
64 | tf.app.flags.DEFINE_float('trunc_norm_init_std', 1e-4, 'std of trunc norm init, used for initializing everything else')
65 | tf.app.flags.DEFINE_float('max_grad_norm', 2.0, 'for gradient clipping')
66 |
67 | # Pointer-generator or baseline model
68 | tf.app.flags.DEFINE_boolean('pointer_gen', True, 'If True, use pointer-generator model. If False, use baseline model.')
69 |
70 | # Coverage hyperparameters
71 | tf.app.flags.DEFINE_boolean('coverage', False,
72 | 'Use coverage mechanism. Note, the experiments reported in the ACL paper train WITHOUT coverage until converged, and then train for a short phase WITH coverage afterwards. i.e. to reproduce the results in the ACL paper, turn this off for most of training then turn on for a short phase at the end.')
73 | tf.app.flags.DEFINE_float('cov_loss_wt', 1.0,
74 | 'Weight of coverage loss (lambda in the paper). If zero, then no incentive to minimize coverage loss.')
75 |
76 | # Utility flags, for restoring and changing checkpoints
77 | tf.app.flags.DEFINE_boolean('convert_to_coverage_model', False,
78 | 'Convert a non-coverage model to a coverage model. Turn this on and run in train mode. Your current training model will be copied to a new version (same name with _cov_init appended) that will be ready to run with coverage flag turned on, for the coverage training stage.')
79 | tf.app.flags.DEFINE_boolean('restore_best_model', False,
80 | 'Restore the best model in the eval/ dir and save it in the train/ dir, ready to be used for further training. Useful for early stopping, or if your training checkpoint has become corrupted with e.g. NaN values.')
81 |
82 | # Debugging. See https://www.tensorflow.org/programmers_guide/debugger
83 | tf.app.flags.DEFINE_boolean('debug', False, "Run in tensorflow's debug mode (watches for NaN/inf values)")
84 |
85 |
86 | def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99):
87 | """Calculate the running average loss via exponential decay.
88 | This is used to implement early stopping w.r.t. a smoother loss curve than the raw loss curve.
89 |
90 | Args:
91 | loss: loss on the most recent eval step
92 | running_avg_loss: running_avg_loss so far
93 | summary_writer: FileWriter object to write for tensorboard
94 | step: training iteration step
95 | decay: rate of exponential decay, a float between 0 and 1. Larger is smoother.
96 |
97 | Returns:
98 | running_avg_loss: new running average loss
99 | """
100 | if running_avg_loss == 0: # on the first iteration just take the loss
101 | running_avg_loss = loss
102 | else:
103 | running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
104 | running_avg_loss = min(running_avg_loss, 12) # clip
105 | loss_sum = tf.Summary()
106 | tag_name = 'running_avg_loss/decay=%f' % (decay)
107 | loss_sum.value.add(tag=tag_name, simple_value=running_avg_loss)
108 | summary_writer.add_summary(loss_sum, step)
109 | tf.logging.info('running_avg_loss: %f', running_avg_loss)
110 | return running_avg_loss
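# Worked example (illustrative): with the default decay=0.99, if
# running_avg_loss is 4.0 and the newest eval loss spikes to 6.0, the update
# gives 0.99 * 4.0 + 0.01 * 6.0 = 4.02, so a single noisy eval batch barely
# moves the smoothed curve that drives early stopping.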
111 |
112 |
113 | def restore_best_model():
114 | """Load bestmodel file from eval directory, add variables for adagrad, and save to train directory"""
115 | tf.logging.info("Restoring best model for training...")
116 |
117 | # Initialize all vars in the model
118 | sess = tf.Session(config=util.get_config())
119 | print("Initializing all variables...")
120 | sess.run(tf.initialize_all_variables())
121 |
122 | # Restore the best model from eval dir
123 | saver = tf.train.Saver([v for v in tf.all_variables() if "Adagrad" not in v.name])
124 | print("Restoring all non-adagrad variables from best model in eval dir...")
125 | curr_ckpt = util.load_ckpt(saver, sess, "eval")
126 | print("Restored %s." % curr_ckpt)
127 |
128 | # Save this model to train dir and quit
129 | new_model_name = curr_ckpt.split("/")[-1].replace("bestmodel", "model")
130 | new_fname = os.path.join(FLAGS.log_root, "train", new_model_name)
131 | print("Saving model to %s..." % new_fname)
132 | new_saver = tf.train.Saver() # this saver saves all variables that now exist, including Adagrad variables
133 | new_saver.save(sess, new_fname)
134 | print("Saved.")
135 | exit()
136 |
137 |
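# Note (illustrative): tf.train.Saver accepts an explicit var_list, which is
# why restore_best_model() above filters out variables whose names contain
# "Adagrad". Only the model weights are restored from the best checkpoint;
# fresh optimizer accumulators are then written out by the second,
# unrestricted Saver.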
138 | def convert_to_coverage_model():
139 | """Load non-coverage checkpoint, add initialized extra variables for coverage, and save as new checkpoint"""
140 | tf.logging.info("converting non-coverage model to coverage model...")
141 |
142 | # initialize an entire coverage model from scratch
143 | sess = tf.Session(config=util.get_config())
144 | print("initializing everything...")
145 | sess.run(tf.global_variables_initializer())
146 |
147 | # load all non-coverage weights from checkpoint
148 | saver = tf.train.Saver([v for v in tf.global_variables() if "coverage" not in v.name and "Adagrad" not in v.name])
149 | print("restoring non-coverage variables...")
150 | curr_ckpt = util.load_ckpt(saver, sess, FLAGS.ckpt_dir)
151 | print("restored.")
152 |
153 | # save this model and quit
154 | new_fname = curr_ckpt + '_cov_init'
155 | print("saving model to %s..." % new_fname)
156 | new_saver = tf.train.Saver() # this one will save all variables that now exist
157 | new_saver.save(sess, new_fname)
158 | print("saved.")
159 | exit()
160 |
161 |
162 | def setup_training(model, batcher):
163 | """Does setup before starting training (run_training)"""
164 | train_dir = os.path.join(FLAGS.log_root, "train")
165 | if not os.path.exists(train_dir):
166 | os.makedirs(train_dir)
167 |
168 | model.build_graph() # build the graph
169 | if FLAGS.convert_to_coverage_model:
170 | if not FLAGS.coverage:
171 | raise ValueError(
172 | "To convert your non-coverage model to a coverage model, run with convert_to_coverage_model=True "
173 | "and coverage=True")
174 | convert_to_coverage_model()
175 | if FLAGS.restore_best_model:
176 | restore_best_model()
177 | saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time
178 |
179 | sv = tf.train.Supervisor(logdir=train_dir,
180 | is_chief=True,
181 | saver=saver,
182 | summary_op=None,
183 | save_summaries_secs=60, # save summaries for tensorboard every 60 secs
184 | save_model_secs=60, # checkpoint every 60 secs
185 | global_step=model.global_step)
186 | summary_writer = sv.summary_writer
187 | tf.logging.info("Preparing or waiting for session...")
188 | sess_context_manager = sv.prepare_or_wait_for_session(config=util.get_config())
189 | tf.logging.info("Created session.")
190 | try:
191 | run_training(model, batcher, sess_context_manager, sv,
192 | summary_writer) # this is an infinite loop until interrupted
193 | except KeyboardInterrupt:
194 | tf.logging.info("Caught keyboard interrupt on worker. Stopping supervisor...")
195 | sv.stop()
196 |
197 |
198 | def run_training(model, batcher, sess_context_manager, sv, summary_writer):
199 | """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
200 | tf.logging.info("starting run_training")
201 | with sess_context_manager as sess:
202 | if FLAGS.debug: # start the tensorflow debugger
203 | sess = tf_debug.LocalCLIDebugWrapperSession(sess)
204 | sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
205 | while True: # repeats until interrupted
206 | batch = batcher.next_batch()
207 |
208 | tf.logging.info('running training step...')
209 | t0 = time.time()
210 | results = model.run_train_step(sess, batch)
211 | t1 = time.time()
212 | tf.logging.info('seconds for training step: %.3f', t1 - t0)
213 |
214 | loss = results['loss']
215 | tf.logging.info('loss: %f', loss) # print the loss to screen
216 |
217 | if loss < 5: # NOTE: training is stopped early once the loss drops below 5
218 | break
219 |
220 | if not np.isfinite(loss):
221 | raise Exception("Loss is not finite. Stopping.")
222 |
223 | if FLAGS.coverage:
224 | coverage_loss = results['coverage_loss']
225 | tf.logging.info("coverage_loss: %f", coverage_loss) # print the coverage loss to screen
226 |
227 | # get the summaries and iteration number so we can write summaries to tensorboard
228 | summaries = results['summaries'] # we will write these summaries to tensorboard using summary_writer
229 | train_step = results['global_step'] # we need this to update our running average loss
230 |
231 | summary_writer.add_summary(summaries, train_step) # write the summaries
232 | if train_step % 100 == 0: # flush the summary writer every so often
233 | summary_writer.flush()
234 |
235 |
236 | def run_eval(model, batcher, vocab):
237 | """Repeatedly runs eval iterations, logging to screen and writing summaries.
Saves the model with the best loss seen so far.""" 238 | model.build_graph() # build the graph 239 | saver = tf.train.Saver(max_to_keep=3) # we will keep 3 best checkpoints at a time 240 | sess = tf.Session(config=util.get_config()) 241 | eval_dir = os.path.join(FLAGS.log_root, "eval") # make a subdir of the root dir for eval data 242 | bestmodel_save_path = os.path.join(eval_dir, 'bestmodel') # this is where checkpoints of best models are saved 243 | summary_writer = tf.summary.FileWriter(eval_dir) 244 | running_avg_loss = 0 # the eval job keeps a smoother, running average loss to tell it when to implement early stopping 245 | best_loss = None # will hold the best loss achieved so far 246 | 247 | while True: 248 | _ = util.load_ckpt(saver, sess, FLAGS.ckpt_dir) # load a new checkpoint 249 | batch = batcher.next_batch() # get the next batch 250 | 251 | # run eval on the batch 252 | t0 = time.time() 253 | results = model.run_eval_step(sess, batch) 254 | t1 = time.time() 255 | tf.logging.info('seconds for batch: %.2f', t1 - t0) 256 | 257 | # print the loss and coverage loss to screen 258 | loss = results['loss'] 259 | tf.logging.info('loss: %f', loss) 260 | if FLAGS.coverage: 261 | coverage_loss = results['coverage_loss'] 262 | tf.logging.info("coverage_loss: %f", coverage_loss) 263 | 264 | # add summaries 265 | summaries = results['summaries'] 266 | train_step = results['global_step'] 267 | summary_writer.add_summary(summaries, train_step) 268 | 269 | # calculate running avg loss 270 | running_avg_loss = calc_running_avg_loss(np.asscalar(loss), running_avg_loss, summary_writer, train_step) 271 | 272 | # If running_avg_loss is best so far, save this checkpoint (early stopping). 273 | # These checkpoints will appear as bestmodel- in the eval dir 274 | if best_loss is None or running_avg_loss < best_loss: 275 | tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s', running_avg_loss, 276 | bestmodel_save_path) 277 | saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best') 278 | best_loss = running_avg_loss 279 | 280 | # flush the summary writer every so often 281 | if train_step % 100 == 0: 282 | summary_writer.flush() 283 | 284 | 285 | def main(unused_argv): 286 | if len(unused_argv) != 1: # prints a message if you've entered flags incorrectly 287 | raise Exception("Problem with flags: %s" % unused_argv) 288 | 289 | tf.logging.set_verbosity(tf.logging.INFO) # choose what level of logging you want 290 | tf.logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode)) 291 | 292 | # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary 293 | FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name) 294 | if not os.path.exists(FLAGS.log_root): 295 | if FLAGS.mode == "train": 296 | os.makedirs(FLAGS.log_root) 297 | else: 298 | raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root)) 299 | 300 | vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary 301 | 302 | # If in decode mode, set batch_size = beam_size 303 | # Reason: in decode mode, we decode one example at a time. 304 | # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses. 
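# Illustrative example: with beam_size=4, each decoder step scores the 4 live
# hypotheses as a single batch of size 4, so in decode mode "batch_size"
# counts hypotheses rather than independent articles.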
305 | if FLAGS.mode == 'decode': 306 | FLAGS.batch_size = FLAGS.beam_size 307 | 308 | # If single_pass=True, check we're in decode mode 309 | if FLAGS.single_pass and FLAGS.mode != 'decode': 310 | raise Exception("The single_pass flag should only be True in decode mode") 311 | 312 | # Make a namedtuple hps, containing the values of the hyperparameters that the model needs 313 | hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 314 | 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps', 'max_enc_steps', 'coverage', 'cov_loss_wt', 315 | 'pointer_gen'] 316 | hps_dict = {} 317 | for key, val in FLAGS.__flags.items(): # for each flag 318 | if key in hparam_list: # if it's in the list 319 | hps_dict[key] = val # add it to the dict 320 | hps = namedtuple("HParams", hps_dict.keys())(**hps_dict) 321 | 322 | # Create a batcher object that will create minibatches of data 323 | # MAX: This line is commented because we only need to run decode and a batcher will be run again there. Might have to uncomment once we need to train 324 | # batcher = Batcher(FLAGS.data_path, vocab, hps, single_pass=FLAGS.single_pass) 325 | 326 | tf.set_random_seed(111) # a seed value for randomness 327 | 328 | if hps.mode == 'train': 329 | print("creating model...") 330 | model = SummarizationModel(hps, vocab) 331 | setup_training(model, batcher) 332 | elif hps.mode == 'eval': 333 | model = SummarizationModel(hps, vocab) 334 | run_eval(model, batcher, vocab) 335 | elif hps.mode == 'decode': 336 | decode_model_hps = hps # This will be the hyperparameters for the decoder model 337 | decode_model_hps = hps._replace( 338 | max_dec_steps=1) # The model is configured with max_dec_steps=1 because we only ever run one step of the decoder at a time (to do beam search). Note that the batcher is initialized with max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries 339 | for line in sys.stdin: # Each line is a data file path 340 | tf.reset_default_graph() 341 | model = SummarizationModel(decode_model_hps, vocab) 342 | # Create a batcher object that will create minibatches of data 343 | batcher = Batcher(line.strip(), vocab, hps, single_pass=True) 344 | decoder = BeamSearchDecoder(model, batcher, vocab) 345 | decoder.decode() 346 | print('', flush=True) 347 | else: 348 | raise ValueError("The 'mode' flag must be one of train/eval/decode") 349 | 350 | 351 | if __name__ == '__main__': 352 | tf.app.run() 353 | -------------------------------------------------------------------------------- /core/getpoint/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ==============================================================================
16 |
17 | """This file contains some utility functions"""
18 |
19 | import tensorflow as tf
20 | import time
21 | import os
22 | FLAGS = tf.app.flags.FLAGS
23 |
24 | def get_config():
25 | """Returns config for tf.session"""
26 | config = tf.ConfigProto(allow_soft_placement=True)
27 | config.gpu_options.allow_growth = True
28 | return config
29 |
30 | def load_ckpt(saver, sess, ckpt_dir=""):
31 | """Load checkpoint from the ckpt_dir (if unspecified, this is train dir) and restore it to saver and sess, waiting 10 secs in the case of failure. Also returns checkpoint name."""
32 | latest_filename = "checkpoint_best" if ckpt_dir == "eval" else None
33 | ckpt_dir = os.path.join(FLAGS.ckpt_dir, ckpt_dir) # join once, outside the retry loop, so the path does not grow on every retry
34 | while True:
35 | try:
36 | ckpt_state = tf.train.get_checkpoint_state(ckpt_dir, latest_filename=latest_filename)
37 | tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
38 | saver.restore(sess, ckpt_state.model_checkpoint_path)
39 | return ckpt_state.model_checkpoint_path
40 | except Exception: # keep retrying until a checkpoint appears
41 | tf.logging.info("Failed to load checkpoint from %s. Sleeping for %i secs...", ckpt_dir, 10)
42 | time.sleep(10)
43 |
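# Example usage (illustrative; assumes a built graph, a Session named sess,
# and FLAGS.ckpt_dir pointing at the checkpoint root):
#   saver = tf.train.Saver()
#   ckpt_path = load_ckpt(saver, sess, "eval")
# This blocks, retrying every 10 seconds, until the best eval checkpoint can
# be restored.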
--------------------------------------------------------------------------------
/core/model.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import logging
18 | import os
19 | from pathlib import Path
20 | from subprocess import Popen, PIPE # nosec - B404:blacklist
21 | from tempfile import NamedTemporaryFile, TemporaryDirectory
22 | from threading import Lock
23 | from flask import abort
24 |
25 | from maxfw.model import MAXModelWrapper
26 | from config import ASSET_DIR, DEFAULT_MODEL_PATH, DEFAULT_VOCAB_PATH, MODEL_NAME
27 |
28 | from .getpoint.convert import convert_to_bin
29 | from .util import process_punctuation
30 |
31 | logger = logging.getLogger()
32 |
33 |
34 | class ModelWrapper(MAXModelWrapper):
35 |
36 | MODEL_META_DATA = {
37 | 'id': 'max-text-summarizer',
38 | 'name': 'MAX Text Summarizer',
39 | 'description': '{} TensorFlow model trained on CNN/Daily Mail Data'.format(MODEL_NAME),
40 | 'type': 'Text Analysis',
41 | 'source': 'https://developer.ibm.com/exchanges/models/all/max-text-summarizer/',
42 | 'license': 'Apache V2'
43 | }
44 |
45 | _tempfile_mutex = Lock()
46 |
47 | def __init__(self, path=DEFAULT_MODEL_PATH):
48 | logger.info('Loading model from: %s...', path)
49 |
50 | self.log_dir = TemporaryDirectory()
51 | self.p_summarize = Popen(['python', 'core/getpoint/run_summarization.py', '--mode=decode', # nosec - B603
52 | '--ckpt_dir={}'.format(ASSET_DIR),
53 | '--vocab_path={}'.format(DEFAULT_VOCAB_PATH),
54 | '--log_root={}'.format(self.log_dir.name)],
55 | stdin=PIPE, stdout=PIPE)
56 |
57 | def __del__(self):
58 | self.p_summarize.stdin.close()
59 | self.log_dir.cleanup()
60 |
61 | def _pre_process(self, x):
62 | return process_punctuation(x)
63 |
64 | def _predict(self, x):
65 | if all(not c.isalpha() for c in x):
66 | abort(400, 'Input text contains no alphabetic characters.')
67 |
68 | with __class__._tempfile_mutex:
69 | # Create temporary file for inter-process communication. This
70 | # procedure must not be executed by two threads at the same time,
71 | # to avoid file name conflicts.
72 | try:
73 | # Make use of tmpfs on Linux if available.
74 | directory = Path("/dev/shm/max-ts-{}".format(os.getpid())) # nosec - B108:hardcoded_tmp_directory
75 | # The following two lines may also raise IOError
76 | directory.mkdir(parents=True, exist_ok=True)
77 | bin_file = NamedTemporaryFile(
78 | prefix='generated_sample_', suffix='.bin',
79 | dir=directory.absolute())
80 | except IOError as e:
81 | logger.warning('Failed to create temporary file in RAM. '
82 | 'Falling back to disk files: %s', e)
83 | directory = Path("./assets/max-ts-{}".format(os.getpid()))
84 | directory.mkdir(parents=True, exist_ok=True)
85 | bin_file = NamedTemporaryFile(
86 | prefix='generated_sample_', suffix='.bin',
87 | dir=directory.absolute())
88 |
89 | with bin_file:
90 |
91 | bin_file_path = bin_file.name
92 |
93 | convert_to_bin(x, bin_file_path)
94 |
95 | try:
96 | self.p_summarize.stdin.write(bin_file_path.encode('utf8'))
97 | self.p_summarize.stdin.write(b'\n')
98 | self.p_summarize.stdin.flush()
99 | # One paragraph at a time under our usage.
100 | summary = self.p_summarize.stdout.readline()
101 | except (IOError, BrokenPipeError) as e:
102 | err_msg = 'Failed to communicate with the summarizer.'
103 | logger.error(err_msg + ' %s', e)
104 | abort(400, err_msg)
105 |
106 | summary = summary.decode('utf-8')
107 |
108 | if len(summary) <= len(x):
109 | return summary
110 |
111 | # Truncate the summary length to be no longer than x. Note that x already has its punctuation processed.
112 | if not summary[len(x)].isspace(): # We are truncating the middle of a word; also remove the last word.
113 | return summary[:len(x)].rsplit(maxsplit=1)[0]
114 | else:
115 | return summary[:len(x)]
116 |
--------------------------------------------------------------------------------
/core/util.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import re
18 |
19 |
20 | def process_punctuation(x):
21 | "Pad punctuation with spaces; this model does not like punctuation touching characters."
22 | return re.sub('([.,!?()])', r' \1 ', x) # https://stackoverflow.com/a/3645946/
23 |
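# Worked example (illustrative):
#   process_punctuation('Hello, world!') -> 'Hello ,  world ! '
# Each matched mark gains a space on both sides; existing whitespace is kept,
# hence the double space before 'world'.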
--------------------------------------------------------------------------------
/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "pip"
4 | directory: "/"
5 | schedule:
6 | interval: "daily"
7 |
--------------------------------------------------------------------------------
/docs/deploy-max-to-ibm-cloud-with-kubernetes-button.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/MAX-Text-Summarizer/d775ac8ad2bdc5c7054685ba2e4cdf96ffaf3aae/docs/deploy-max-to-ibm-cloud-with-kubernetes-button.png
--------------------------------------------------------------------------------
/docs/swagger-screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/MAX-Text-Summarizer/d775ac8ad2bdc5c7054685ba2e4cdf96ffaf3aae/docs/swagger-screenshot.png
--------------------------------------------------------------------------------
/max-text-summarizer.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: max-text-summarizer
5 | spec:
6 | selector:
7 | app: max-text-summarizer
8 | ports:
9 | - port: 5000
10 | type: NodePort
11 | ---
12 | apiVersion: apps/v1
13 | kind: Deployment
14 | metadata:
15 | name: max-text-summarizer
16 | labels:
17 | app: max-text-summarizer
18 | spec:
19 | selector:
20 | matchLabels:
21 | app: max-text-summarizer
22 | replicas: 1
23 | template:
24 | metadata:
25 | labels:
26 | app: max-text-summarizer
27 | spec:
28 | containers:
29 | - name: max-text-summarizer
30 | image: quay.io/codait/max-text-summarizer:latest
31 | ports:
32 | - containerPort: 5000
33 |
--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
1 | pytest==6.2.0
2 | requests==2.25.0
3 | flake8==3.8.4
4 | bandit==1.6.2
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==1.0.0 # Do not upgrade.
We use features specifically in this version (e.g., the format of the training data). See https://github.com/becxer/pointer-generator/#about-this-code 2 | pyrouge==0.1.3 3 | -------------------------------------------------------------------------------- /samples/README.md: -------------------------------------------------------------------------------- 1 | # Sample Details 2 | 3 | ## Test examples 4 | 5 | This directory contains test samples for the model. These test samples are part of the [CNN / Daily Mail](https://github.com/JafferWilson/Process-Data-of-CNN-DailyMail) dataset and are licensed under the [MIT](https://opensource.org/licenses/MIT) license. 6 | -------------------------------------------------------------------------------- /samples/sample1.json: -------------------------------------------------------------------------------- 1 | { 2 | "text": [ 3 | "nick gordon 's stepfather has revealed in a new interview that he fears for the life of bobbi kristina 's troubled fiance . it has been reported that gordon , 25 , has threatened suicide and has been taking xanax since whitney houston 's daughter was found unconscious in a bathtub in late january . on wednesday , access hollywood spoke exclusively to gordon 's stepfather about his son 's state of mind , asking him if he fears for nick 's life . scroll down for video . speaking out : nick gordon 's father -lrb- left and right -rrb- gave an interview about the 25-year-old fiance of bobbi kristina brown , who reportedly had threatened suicide . access hollywood will air the complete interview thursday . ' i fear ... that if something happens to bobbi kristina , like , if she does n't pull it through , then i will fear for my son 's life , ' the stepfather stated . access hollywood will air the complete interview thursday . speaking to dr phil mcgraw gordon , 25 , said : ` i lost the most legendary singer ever and i 'm scared to lose __krissy__ . i want to let all you guys know i did everything possible in the world to protect them . ' gordon 's impassioned assertion was made during a dramatic intervention staged by dr phil last week and due to air today . weeping and wailing and at times incoherent , gordon had believed that he was to be interviewed by dr phil . gordon had told the eminent psychologist that he wanted to tell his side of the story . he has felt ` vilified ' and ` depicted as a monster ' since january 31 when bobby kristina , 22 , was found face down and unresponsive in the __bath-tub__ at the roswell , georgia home he and bobby kristina shared . breakdown : with his mother , michelle , by his side nick gordon struggles to stay coherent . troubling preview : a trailer for the interview of bobbi kristina brown 's boyfriend nick gordon by dr phil mcgraw was released sunday night . swinging from crying to flying into a rage , the disturbing 30 second video shows a very troubled young man in a lot of pain as he is questioned by dr phil . but on learning of gordon 's emotional , mental and physical state dr phil could not ` in good conscience ' go on with the interview . according to gordon 's mother , michelle , who was by her son 's side during the highly emotional intervention , ` nicholas is at a breaking point . he can not bear not being without __krissy__ , by her side . he 's dealing with it by drinking . i 've begged him to stop . ' dr phil told gordon , who admitted to having taken xanax and to have drunk heavily prior to the meeting : ` nick , you 're out of control . you 've threatened suicide . 
you deserved to get some help because if you do n't , you know you 're going to wind up dead . ' he added : ` you 've got to get yourself cleaned up . you got to man up and straighten up . ' yesterday dailymail.com revealed that , barely 12 hours before the dramatic encounter with dr phil , gordon was so drunk and high he was unable to walk . the disturbing scenes captured on video showed the true extent of a deterioration described by dr phil as ` exponential . ' in the show , due to broadcast wednesday , a weeping and wailing gordon admitted that he has twice tried to kill himself and confessed : ` i 'm so sorry for everything . ' asked if he still intended to kill himself he said : ` if anything happens to __krissi__ i will . ' he said : ` my pain is horrible . my heart hurts . i have panic attacks . ' in recent weeks gordon has twice overdosed on a mixture of xanax , alcohol and prescription sleeping pills . gordon had agreed to meet with dr phil believing that he was there to be interviewed . according to dr phil : ` he felt like he was being vilified and presented as some sort of monster . ' instead , on learning of gordon 's rapidly deteriorating mental , emotional and physical state , the eminent psychologist staged an intervention and he is now in rehab . dr phil stated : ' i do n't think he has any chance of turning this round . left to his own devices he will be dead within the week . ' gordon 's mother , michelle , was by her son 's side as he __alternated__ between compliant and aggressive -- at one point threatening to attack camera men as they filmed . she described her son as ` at breaking point . ' she said : ` he can not take too much more of not being able to see __kriss__ . he blames himself . he 's torn up and carrying guilt . ` he 's dealing with it by drinking . i 've begged him to stop . i 've tried to help him but he ca n't let go of the guilt . ' leaning towards gordon 's mother , dr phil said : ` you and i have one mission with one possible outcome and that 's for him to agree to go to rehab to deal with his depression , his guilt ... and to get clean and sober . ' he added : ` his life absolutely hangs in the balance . ' questions still rage regarding just what happened in the early hours of january 31st that led to bobbi kristina , 22 , ending up face down in her tub . just two days ago bobbi kristina 's aunt , __leolah__ brown , made a series of explosive facebook posts in which she alleged that the family had ` strong __evidene__ of foul play ' relating to gordon 's role in events . she posted : ` nick gordon is very disrespectful and inconsiderate ! especially to my family . moreover , he has done things to my niece that i never thought he had in him to do ! ' __leolah__ made her claims in response to being invited onto dr phil 's show . tough to watch : the young man sobs and shakes at times through the trailer . in her message she wrote : ` with all due respect , nick gordon is under investigation for the attempted murder of my niece ... . we have strong evidence of foul play . ' marks were found on her face and arms , marks that gordon has explained as the result of cpr which he administered to her for 15 minutes before emergency services arrived . and speaking to dr phil , gordon insisted : ' i did everything possible in the world to protect them -lsb- whitney and bobbi kristina -rsb- . 
' railing against the decision by bobbi 's father , bobby brown , to ban him from his fiancée 's bedside at atlanta 's emory hospital , gordon said : ` my name will be the first she calls . ' according to his mother , michelle , gordon can not forgive himself for his ` failure ' to revive bobbi kristina . his guilt is compounded by the eerily similar situation in which he found himself almost exactly three years earlier . ' i hate bobby brown ! ' his tears quickly turn to anger when he brings up his girlfriend 's father bobby brown , who has blocked him from seeing the 22-year-old in the hospital . ` are you drinking ? ' the emotional man 's anger only increases as dr. phil asks about nick 's sobriety . refuses to go : the 25-year-old aspiring rapper storms out when told he needs help . resigned : ultimately , gordon concedes that he needs help and leaves for a rehab facility . then , on 11 february 2012 , it was gordon who tried to resuscitate whitney when she was found unresponsive in her __bath-tub__ just hours before she was due to attend a grammy awards party . speaking to dr phil , michelle said : ` nicholas just continually expresses how much he has failed whitney . ` he administered cpr -lsb- to whitney -rsb- and he called me when he was standing over her body . he could n't understand why he could n't revive her . he said , `` mommy why ? i could n't get the air into her lungs . '' gordon has now entered a rehab facility in atlanta . meanwhile his fiancée continues to fight for her life in a medically-induced coma . it is now six weeks since bobbi kristina -- __krissi__ as she was known to friends - was found face down and unresponsive in the __bath-tub__ of the home she shared with gordon in roswell , georgia . unlike whitney 's death that was ruled accidental , police are treating bobbi kristina 's near drowning and injuries allegedly sustained by the 22-year-old as a criminal investigation . while bobbi kristina fights for her life a troubling picture of her relationship with nick , of drug use and domestic turbulence in the weeks leading up to the incident , has emerged . in an interview with the sun a friend of the couple , steven __stepho__ , claimed they were using various drugs daily including heroin , xanax , pot and heroin substitute __roxicodone__ . sounding eerily similar to bobbi 's mother and father 's __drug-dependent__ relationship , the friend said the pair used whatever they could get their hands on . ` bobbi and nick would spend a lot on drugs every day , it just depended on how much money they had . it was n't unusual for them to spend $ 1,000 a day on drugs . ` there were times when it got really bad - they would be completely passed out for hours , just lying there on the bed . there were times when she would be so knocked out she would burn herself with a cigarette and not even notice . she was always covered in cigarette burns . ' their relationship was also very volatile . ` when whitney died nick was left with nothing , so he knew he had to control __krissi__ to get access to the money . she 'd do whatever he told her . ` he was very manipulative and would even use the drugs to control her . they would argue a lot and there were times when he would be violent with her and push her around . ' according to __stepho__ : ` but __krissi__ really loved him because he was there to fill the gap left by her mother . she was not close to her father and did not have anyone else close to her . nick knew this and took advantage of it . 
' troubled relationship : a friend of the couple , steven __stepho__ , claimed the pair were using various drugs daily including heroin , xanax , pot and heroin substitute __roxicodone__ ." 4 | ] 5 | } 6 | -------------------------------------------------------------------------------- /samples/sample2.json: -------------------------------------------------------------------------------- 1 | { 2 | "text": [ 3 | "The Korean Air executive who kicked up a fuss over a bag of nuts will resign from her remaining posts with the airline , the company chairman -- who is also her father -- said Friday . The executive , Heather Cho , found herself at the center of a media storm after she ordered that a plane turn back to the gate and that a flight attendant be removed -- all because she was served nuts in a bag instead of on a plate in first class . Although her role put her in charge of in-flight service , she was only a passenger on the flight and was not flying in an official capacity . The incident , which took place last week at New York 's JFK airport , stirred anger among the South Korean public over Cho 's behavior . Cho , whose Korean name is Cho Hyun-ah , resigned Tuesday from the airline 's catering and in-flight sales business , and from its cabin service and hotel business divisions , the company said . But the 40-year-old kept her title as a vice president of the national carrier , according to company spokesman . That 's going to change , her father , Cho Yang-ho , said Friday as he made a public apology for what happened . She will be resigning from the vice president job and positions held in affiliate companies , he said . Asked by reporters how the incident could have happened , the company chairman blamed himself , saying he 'd raised her badly . ` Outburst of anger ' A local English-language newspaper , The Korea Times , said her behavior has deepened public resentment of South Korea 's large family-owned corporations , known as chaebol . `` Through her outburst of anger , she not only caused inconvenience to KAL passengers , but also to those on other flights , '' the newspaper said in an editorial Tuesday . The most annoying type of airline passenger is ... South Korean authorities are now investigating the incident , which occurred on a flight due to take off for Incheon International Airport near Seoul . Cho arrived at the Ministry of Land , Infrastructure and Transport on Friday as part of the investigation , according to local TV coverage . She spoke in such a low voice that it was inaudible from the TV footage . ` An excessive act ' Korean Air apologized for any inconvenience to those on the flight and said there had been no safety issues involved . The plane arrived at its destination 11 minutes behind schedule , according to the South Korean news agency Yonhap . `` Even though it was not an emergency situation , backing up the plane to order an employee to deplane was an excessive act , '' the airline said earlier this week . `` We will re-educate all our employees to make sure service within the plane meets high standards . '' The airline also issued an apology on Heather Cho 's behalf , Yonhap reported , in which she asked for forgiveness and said she would take `` full responsibility '' for the incident . According to her biography on the website of Nanyang Technological University , Heather Cho joined the airline in 1999 and has since been `` actively involved in establishing a new corporate identity for Korean Air . 
'' She studied at Cornell University and the University of Southern California . @highlight Airline chairman : Heather Cho will be resigning from all posts with the company @highlight Korean Air said previously she had resigned from some roles but was keeping her VP title @highlight The chairman , who is also her father , blames himself , saying he raised her badly @highlight Cho ordered a plane back to the gate after a flight attendant served nuts in a bag" 4 | ] 5 | } 6 | -------------------------------------------------------------------------------- /samples/sample3.json: -------------------------------------------------------------------------------- 1 | { 2 | "text": [ 3 | "Attorney General Eric Holder urged other countries to enact new criminal laws to help prevent possible terrorist attacks from returning Syrian fighters . In a speech in Oslo , Norway , on Tuesday , Holder highlighted the threat from domestic extremists such as a lone terrorist who three years ago bombed government buildings in the Norwegian capital and gunned down people at a youth camp , killing 77 people in all . Holder said the United States , along with Norway and other European countries , face similar danger from `` violent extremists fighting today in Syria , Iraq or other locations '' who `` may seek to commit acts of terror tomorrow in our countries as well . '' U.S. intelligence estimates that nearly 7,000 foreign fighters have traveled to Syria , including dozens from the United States . The issue of Syrian fighters from Western countries is dominating Holder 's trip this week to Europe , which includes a meeting in London of attorneys general from the U.S. , UK , Canada , Australia and New Zealand . Open visa access among European countries and the U.S. means `` the problem of fighters in Syria returning to any of our countries is a problem for all of our countries , '' Holder said in his speech , excerpts of which were provided in advance by the Justice Department . `` The Syrian conflict has turned that region into a cradle of violent extremism , '' he said . `` But the world can not simply sit back and let it become a training ground from which our nationals can return and launch attacks . '' To combat the threat , Holder is calling on countries to pass laws to criminalize the preparatory steps that suspects often take before an attack , and to allow police to conduct undercover investigations . The U.S. has a law that makes it a crime to provide `` material support '' to terrorists , including supplying money or weapons , or helping to plot an attack . Similar laws are now on the books in Norway and France . Holder also cited the FBI 's success in using undercover sting operations , which have drawn controversy in the U.S. but have been successful in prosecuting dozens of suspects who had admitted to plans to commit terrorism . Many countries do n't allow such operations . But Holder said the use of such operations could help countries thwart attacks . He said the U.S. has already used these tools to carry out prosecutions of people who sought to travel to join the fight in Syria . `` These operations are conducted with extraordinary care and precision , ensuring that law enforcement officials are accountable for the steps they take and that suspects are neither entrapped nor denied legal protections , '' Holder said . He also called on countries to share information with Interpol and one another about their nationals who try to travel to Syria to fight and those who return . 
And he urged countries to come up with counter-radicalization programs that try to reach communities where young people may be exposed to extremists . `` We must seek to stop individuals from becoming radicalized in the first place by putting in place strong programs to counter violent extremism in its earliest stages , '' Holder said . @highlight Eric Holder urges countries to enact new criminal laws that clamp down on suspects @highlight He cites `` violent extremists fighting today in Syria , Iraq or other locations '' @highlight U.S. intelligence estimates that nearly 7,000 foreign fighters have traveled to Syria @highlight Holder urges nations to share information about nationals who try to travel to Syria to fight" 4 | ] 5 | } 6 | -------------------------------------------------------------------------------- /sha512sums.txt: -------------------------------------------------------------------------------- 1 | 3fbf53cc7c32402a88d4b33762dc84e9ed1036a473cbf64d426805f421686e4c97db31cdc8cba4e043f54b5f1610dbac2529d49952277a562759dc748ab2cbf6 assets/checkpoint 2 | 07aaa28da583bdae86c3557f6563653dccc9b579bbeaf3ab7e907b764170579fe1718aeb21e9281dcbf85905ee21e10c7422cabfcdd85ea9a750bb8bfe5bdf9c assets/graph.pbtxt 3 | 732c621da1f2d5fe26c00cbf224a82e99302f55db98c2359bd10c07e14022b609440afbeba089451d4f766f287d35abaced5e0f7673540afff84f510c3d28e88 assets/model-238410.data-00000-of-00001 4 | b88a63d72d4a35dba130cc2b0b7094ab7fae2da058892b68503dab821c5eca6c4ab89c71481a71a7e497e686069396ba57041dd1903a89d29e2f0cd91f9236ff assets/model-238410.index 5 | 9d62308623b4833a519d0b547bbda6387cb6ce9bfaf32f6233d34b79f13788d798d68603eae48f3d56c0042584e917c7378027203aabbb470289771aeeaae548 assets/model-238410.meta 6 | cb5558e40557be772442c4079575ac41c69c991a238fd54c1de44b9bff47823279766aadb5f5107c260627a25129c7628cda84fec6f775fb68285047a599977b assets/projector_config.pbtxt 7 | 5976b96ab0d2a5ea306e30f2e3bbbc866440a0fa0a5e13e040a1c43e223e6c0dc4fb7d7284822759a9c34b651b8fa61087b5ea5add0ebfdd61e46d2107597269 assets/vocab 8 | 4b404c4f7de95547b0d82bc150f705ed61878fa34ac2941a5559e99f59bf2e231d20b71e09a3aadb2586992dce5ff5043db3eff24dec1655eaf4862685dc483e assets/vocab_metadata.tsv 9 | -------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | import glob 18 | import json 19 | import pytest 20 | import requests 21 | 22 | from core.util import process_punctuation 23 | 24 | 25 | MODEL_PREDICT_ENDPOINT = 'http://localhost:5000/model/predict' 26 | 27 | 28 | def test_swagger(): 29 | 30 | model_endpoint = 'http://localhost:5000/swagger.json' 31 | 32 | r = requests.get(url=model_endpoint) 33 | assert r.status_code == 200 34 | assert r.headers['Content-Type'] == 'application/json' 35 | 36 | json = r.json() 37 | assert 'swagger' in json 38 | assert json.get('info').get('title') == 'MAX Text Summarizer' 39 | assert json.get('info').get('description') == 'Generate a summarized description of a body of text.' 40 | 41 | 42 | def test_metadata(): 43 | 44 | model_endpoint = 'http://localhost:5000/model/metadata' 45 | 46 | r = requests.get(url=model_endpoint) 47 | assert r.status_code == 200 48 | 49 | metadata = r.json() 50 | assert metadata['id'] == 'max-text-summarizer' 51 | assert metadata['name'] == 'MAX Text Summarizer' 52 | assert metadata['description'] == 'get_to_the_point TensorFlow model trained on CNN/Daily Mail Data' 53 | assert metadata['license'] == 'Apache V2' 54 | 55 | 56 | def test_predict_valid(): 57 | "Test prediction for valid input." 58 | short_text = 'Why not summarize?' 59 | with open('samples/sample1.json') as f: 60 | long_text = json.load(f)['text'][0] 61 | 62 | json_data = { 63 | 'text': [short_text, long_text] 64 | } 65 | 66 | r = requests.post(url=MODEL_PREDICT_ENDPOINT, json=json_data) 67 | 68 | assert r.status_code == 200 69 | response = r.json() 70 | assert response['status'] == 'ok' 71 | assert len(response['summary_text']) == 2 72 | for i, text in enumerate(json_data['text']): 73 | assert len(response['summary_text'][i]) <= len(process_punctuation(text)) 74 | 75 | 76 | def test_predict_sample(): 77 | "Test prediction for sample inputs." 78 | for sample in glob.iglob('samples/*.json'): 79 | with open(sample) as f: 80 | json_data = json.load(f) 81 | 82 | r = requests.post(url=MODEL_PREDICT_ENDPOINT, json=json_data) 83 | 84 | assert r.status_code == 200 85 | response = r.json() 86 | assert response['status'] == 'ok' 87 | assert len(response['summary_text']) == len(json_data['text']) 88 | for i, text in enumerate(json_data['text']): 89 | assert len(response['summary_text'][i]) <= len(process_punctuation(text)) 90 | 91 | 92 | def test_predict_invalid_input_no_string(): 93 | "Test invalid input: no string." 94 | json_data = {} 95 | r = requests.post(url=MODEL_PREDICT_ENDPOINT, json=json_data) 96 | assert r.status_code == 400 97 | 98 | 99 | def test_predict_invalid_empty_string(): 100 | "Test invalid input: empty and blank strings." 101 | for s in ('', ' \t\f\r\n'): 102 | for text in ([s], ['some text', s]): 103 | json_data = {'text': text} 104 | r = requests.post(url=MODEL_PREDICT_ENDPOINT, json=json_data) 105 | assert r.status_code == 400 106 | 107 | 108 | if __name__ == '__main__': 109 | pytest.main([__file__]) 110 | --------------------------------------------------------------------------------