├── .bandit
├── .dockerignore
├── .gitignore
├── .travis.yml
├── Dockerfile
├── LICENSE
├── README.md
├── api
│   ├── __init__.py
│   ├── metadata.py
│   └── predict.py
├── app.py
├── config.py
├── core
│   ├── __init__.py
│   ├── getpoint
│   │   ├── attention_decoder.py
│   │   ├── batcher.py
│   │   ├── beam_search.py
│   │   ├── convert.py
│   │   ├── data.py
│   │   ├── decode.py
│   │   ├── inspect_checkpoint.py
│   │   ├── model.py
│   │   ├── run_summarization.py
│   │   └── util.py
│   ├── model.py
│   └── util.py
├── dependabot.yml
├── docs
│   ├── deploy-max-to-ibm-cloud-with-kubernetes-button.png
│   └── swagger-screenshot.png
├── max-text-summarizer.yaml
├── requirements-test.txt
├── requirements.txt
├── samples
│   ├── README.md
│   ├── sample1.json
│   ├── sample2.json
│   └── sample3.json
├── sha512sums.txt
└── tests
    └── test.py
/.bandit:
--------------------------------------------------------------------------------
1 | [bandit]
2 | exclude: /tests,/training
3 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | README.*
2 | __pycache__
3 | .gitignore
4 | .git/
5 | tests/
6 | samples/
7 | docs/
8 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | .idea
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
106 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 | - 3.6
4 | services:
5 | - docker
6 | install:
7 | - docker build -t max-text-summarizer .
8 | - docker run -it -d --rm -p 5000:5000 max-text-summarizer
9 | - pip install -r requirements-test.txt
10 | script:
11 | - flake8 . --max-line-length=127 --exclude=./core/getpoint
12 | - bandit -r .
13 | - sleep 30
14 | # "python -m" will add current working directory to sys.path.
15 | # See https://docs.pytest.org/en/latest/usage.html#calling-pytest-through-python-m-pytest
16 | - python -m pytest tests/test.py
17 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | FROM quay.io/codait/max-base:v1.1.3
18 |
19 | # Upgrade packages to meet security criteria
20 | RUN apt-get update && apt-get upgrade -y && rm -rf /var/lib/apt/lists/*
21 |
22 | # Fill in these with a link to the bucket containing the model and the model file name
23 | ARG model_bucket=https://max-cdn.cdn.appdomain.cloud/max-text-summarizer/1.0.0
24 | ARG model_file=assets.tar.gz
25 |
26 | RUN useradd --create-home max
27 | RUN chown -R max:max /opt/conda
28 | USER max
29 | WORKDIR /home/max
30 | RUN mkdir assets
31 |
32 | RUN wget -nv --show-progress --progress=bar:force:noscroll ${model_bucket}/${model_file} --output-document=assets/${model_file} && \
33 | tar -x -C assets/ -f assets/${model_file} -v && rm assets/${model_file}
34 |
35 | COPY requirements.txt .
36 | RUN pip install -r requirements.txt
37 |
38 | COPY . .
39 |
40 | # check file integrity
41 | RUN sha512sum -c sha512sums.txt
42 |
43 | EXPOSE 5000
44 |
45 | CMD python app.py
46 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [Build Status](https://travis-ci.com/IBM/MAX-Text-Summarizer)
2 | [API demo](http://max-text-summarizer.codait-prod-41208c73af8fca213512856c7a09db52-0000.us-east.containers.appdomain.cloud)
3 |
4 | [![Deploy on IBM Cloud with Kubernetes](docs/deploy-max-to-ibm-cloud-with-kubernetes-button.png)](http://ibm.biz/max-to-ibm-cloud-tutorial)
5 |
6 | # IBM Developer Model Asset Exchange: Text Summarizer
7 |
8 | This repository contains code to instantiate and deploy a text summarization model. The model takes a JSON input that encapsulates some text snippets and returns a text summary that represents the key information or message in the input text.
9 | The model was trained on the [CNN / Daily Mail](https://github.com/JafferWilson/Process-Data-of-CNN-DailyMail) dataset.
10 | The model has a vocabulary of approximately 200k words. The model is based on the ACL 2017 paper, [Get To The Point: Summarization with Pointer-Generator Networks](https://arxiv.org/abs/1704.04368).
11 |
12 | The model files are hosted on [IBM Cloud Object Storage](https://max-cdn.cdn.appdomain.cloud/max-text-summarizer/1.0.0/assets.tar.gz). The code in this repository deploys the model as a web service in a Docker container. This repository was developed as part of the [IBM Developer Model Asset Exchange](https://developer.ibm.com/code/exchanges/models/) and the public API is powered by [IBM Cloud](https://ibm.biz/Bdz2XM).
13 |
14 | ## Model Metadata
15 | | Domain | Application | Industry | Framework | Training Data | Input Data Format |
16 | | ------------- | -------- | -------- | --------- | --------- | -------------- |
17 | | NLP | Text Summarization | General | TensorFlow | CNN / Daily Mail | Text |
18 |
19 |
20 | ## References:
21 |
22 | * _A. See, P. J. Liu, C. D. Manning_, [Get To The Point: Summarization with Pointer-Generator Networks](https://arxiv.org/abs/1704.04368), ACL, 2017.
23 |
24 | * [The text summarization repository](https://github.com/becxer/pointer-generator/)
25 |
26 | ## Licenses
27 |
28 | | Component | License | Link |
29 | | ------------- | -------- | -------- |
30 | | This Repository | [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) | [LICENSE](https://github.com/IBM/MAX-Text-Summarizer/blob/master/LICENSE) |
31 | | Third Party Code | [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) | [LICENSE](https://github.com/becxer/pointer-generator/blob/master/LICENSE.txt) |
32 | | Pre-Trained Model Weights | [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) | [LICENSE](https://github.com/becxer/pointer-generator/) |
33 | | Training Data ([CNN / Daily Mail](https://github.com/abisee/cnn-dailymail)) | [MIT](https://opensource.org/licenses/MIT) | [LICENSE](https://github.com/abisee/cnn-dailymail/blob/master/LICENSE.md) |
34 |
35 | ## Pre-requisites:
36 |
37 | * `docker`: The [Docker](https://www.docker.com/) command-line interface. Follow the [installation instructions](https://docs.docker.com/install/) for your system.
38 | * The minimum recommended resources for this model are 4 GB of memory and 4 CPUs.
39 | * If you are on x86-64/AMD64, your CPU must support AVX at the minimum (a quick check is sketched below).
40 |
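On Linux, one way to check for AVX is to look for the `avx` flag in `/proc/cpuinfo`. A minimal sketch (Linux-only; assumes `/proc/cpuinfo` is readable):

```python
# Linux-only: CPU feature flags are listed in /proc/cpuinfo.
with open('/proc/cpuinfo') as f:
    cpuinfo = f.read()
print('AVX supported' if 'avx' in cpuinfo else 'AVX not supported')
```
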
41 | # Steps
42 |
43 | 1. [Deploy from Quay](#deploy-from-quay)
44 | 2. [Deploy on Kubernetes](#deploy-on-kubernetes)
45 | 3. [Run Locally](#run-locally)
46 |
47 | ## Deploy from Quay
48 |
49 | To run the docker image, which automatically starts the model serving API, run:
50 |
51 | ```
52 | $ docker run -it -p 5000:5000 quay.io/codait/max-text-summarizer
53 | ```
54 |
55 | This will pull a pre-built image from the Quay.io container registry (or use an existing image if already cached locally) and run it.
56 | If you'd rather check out and build the model locally, you can follow the [run locally](#run-locally) steps below.
57 |
58 | ## Deploy on Kubernetes
59 |
60 | You can also deploy the model on Kubernetes using the latest docker image on Quay.
61 |
62 | On your Kubernetes cluster, run the following command:
63 |
64 | ```
65 | $ kubectl apply -f https://github.com/IBM/MAX-Text-Summarizer/raw/master/max-text-summarizer.yaml
66 | ```
67 |
68 | The model will be available internally at port `5000`, but can also be accessed externally through the `NodePort`.
69 |
70 | A more elaborate tutorial on how to deploy this MAX model to production on [IBM Cloud](https://ibm.biz/Bdz2XM) can be found [here](http://ibm.biz/max-to-ibm-cloud-tutorial).
71 |
72 | ## Run Locally
73 |
74 | 1. [Build the Model](#1-build-the-model)
75 | 2. [Deploy the Model](#2-deploy-the-model)
76 | 3. [Use the Model](#3-use-the-model)
77 | 4. [Development](#4-development)
78 | 5. [Cleanup](#5-cleanup)
79 |
80 |
81 | ### 1. Build the Model
82 |
83 | Clone this repository locally. In a terminal, run the following command:
84 |
85 | ```
86 | $ git clone https://github.com/IBM/MAX-Text-Summarizer.git
87 | ```
88 |
89 | Change directory into the repository base folder:
90 |
91 | ```
92 | $ cd MAX-Text-Summarizer
93 | ```
94 |
95 | To build the docker image locally, run:
96 |
97 | ```
98 | $ docker build -t max-text-summarizer .
99 | ```
100 |
101 | All required model assets will be downloaded during the build process. _Note_ that currently this docker image is CPU only (we will add support for GPU images later).
102 |
103 |
104 | ### 2. Deploy the Model
105 |
106 | To run the docker image, which automatically starts the model serving API, run:
107 |
108 | ```
109 | $ docker run -it -p 5000:5000 max-text-summarizer
110 | ```
111 |
112 | ### 3. Use the Model
113 |
114 | The API server automatically generates an interactive Swagger documentation page. Go to `http://localhost:5000` to load it. From there you can explore the API and also create test requests.
115 |
116 | Use the `model/predict` endpoint to submit input text (you can use one of the test files from the `samples` folder) and get the predicted summary from the API.
117 |
118 | ![Swagger UI screenshot](docs/swagger-screenshot.png)
119 |
120 | You can also test it on the command line, for example:
121 |
122 | ```bash
123 | $ curl -d @samples/sample1.json -H "Content-Type: application/json" -XPOST http://localhost:5000/model/predict
124 | ```
125 |
126 | You should see a JSON response like the one below:
127 |
128 | ```json
129 | {
130 | "status": "ok",
131 | "summary_text": [
132 | ["nick gordon 's father -lrb- left and right -rrb- gave an interview about the 25-year-old fiance of bobbi kristina brown . it has been reported that gordon , 25 , has threatened suicide and has been taking xanax since . whitney houston 's daughter was found unconscious in a bathtub in january . on wednesday , access hollywood spoke exclusively to gordon 's stepfather about his son 's state of mind . "]
133 | ]
134 | }
135 | ```
136 |
137 | The text summarizer preserves special tokens in the output summary, such as `-lrb-` (representing `(`) and `-rrb-` (representing `)`), which appear in the input [sample](samples/sample1.json), an excerpt from the [Daily Mail](https://github.com/abisee/cnn-dailymail) dataset.
138 |
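As a programmatic alternative to `curl`, the request can also be made from Python. Below is a minimal sketch using the third-party `requests` package (not part of this repository's requirements), assuming the server is running locally on port 5000; it also shows a simple post-processing step that maps the special tokens back to parentheses:

```python
import requests

# The /model/predict endpoint expects a JSON object with a "text" field
# holding a list of strings; one summary is returned per entry (see api/predict.py).
payload = {'text': ['text of the article to be summarized ...']}

response = requests.post('http://localhost:5000/model/predict', json=payload)
response.raise_for_status()

for entry in response.json()['summary_text']:
    # In the sample response above, each entry is itself a single-element list.
    summary = entry[0] if isinstance(entry, list) else entry
    print(summary.replace('-lrb-', '(').replace('-rrb-', ')'))
```
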
139 | ### 4. Development
140 |
141 | To run the Flask API app in debug mode, edit `config.py` to set `DEBUG = True` under the application settings. You will then need to rebuild the docker image (see [step 1](#1-build-the-model)).
142 |
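For reference, the setting lives under "Flask settings" in `config.py` (shown here flipped from its default of `False`):

```python
# Flask settings (in config.py)
DEBUG = True  # enable Flask debug mode; the default is False
```
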
143 | ### 5. Cleanup
144 |
145 | To stop the Docker container, press `CTRL` + `C` in your terminal.
146 |
147 | ## Resources and Contributions
148 |
149 | If you are interested in contributing to the Model Asset Exchange project or have any queries, please follow the instructions [here](https://github.com/CODAIT/max-central-repo).
150 |
--------------------------------------------------------------------------------
/api/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from .metadata import ModelMetadataAPI # noqa
18 | from .predict import ModelPredictAPI # noqa
19 |
--------------------------------------------------------------------------------
/api/metadata.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from core.model import ModelWrapper
18 | from maxfw.core import MAX_API, MetadataAPI, METADATA_SCHEMA
19 |
20 |
21 | class ModelMetadataAPI(MetadataAPI):
22 |
23 | @MAX_API.marshal_with(METADATA_SCHEMA)
24 | def get(self):
25 | """Return the metadata associated with the model"""
26 | return ModelWrapper.MODEL_META_DATA
27 |
--------------------------------------------------------------------------------
/api/predict.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import logging
18 | from core.model import ModelWrapper
19 | from maxfw.core import MAX_API, PredictAPI
20 | from flask_restplus import fields
21 |
22 | input_parser = MAX_API.model('ModelInput', {
23 | 'text': fields.List(fields.String, required=True, description=(
24 | 'A list of input text to be summarized. '
25 | 'Each entry in the list is treated as a separate input text and so a summary result will be returned for each entry.'))
26 | })
27 |
28 | predict_response = MAX_API.model('ModelPredictResponse', {
29 | 'status': fields.String(required=True, description='Response status message.'),
30 | 'summary_text': fields.List(fields.String, required=True, description=(
31 | 'Generated summary text. Each entry in the list is the summary result for the corresponding entry in the input list.'))
32 | })
33 |
34 | logger = logging.getLogger()
35 |
36 |
37 | class ModelPredictAPI(PredictAPI):
38 |
39 | model_wrapper = ModelWrapper()
40 |
41 | @MAX_API.doc('predict')
42 | @MAX_API.expect(input_parser, validate=True)
43 | @MAX_API.marshal_with(predict_response)
44 | def post(self):
45 | """Make a prediction given input data"""
46 | result = {'status': 'error'}
47 | result['summary_text'] = []
48 |
49 | input_json = MAX_API.payload
50 | texts = input_json['text']
51 | for text in texts:
52 | preds = self.model_wrapper.predict(text)
53 | result['summary_text'].append(preds)
54 |
55 | result['status'] = 'ok'
56 |
57 | return result
58 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from maxfw.core import MAXApp
18 | from api import ModelMetadataAPI, ModelPredictAPI
19 | from config import API_TITLE, API_DESC, API_VERSION
20 |
21 | max = MAXApp(API_TITLE, API_DESC, API_VERSION)
22 | max.add_api(ModelMetadataAPI, '/metadata')
23 | max.add_api(ModelPredictAPI, '/predict')
24 | max.run()
25 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import os
18 | import pathlib
19 |
20 | # Flask settings
21 | DEBUG = False
22 |
23 | # Flask-restplus settings
24 | RESTPLUS_MASK_SWAGGER = False
25 | SWAGGER_UI_DOC_EXPANSION = 'none'
26 |
27 | # API metadata
28 | API_TITLE = 'MAX Text Summarizer'
29 | API_DESC = 'Generate a summarized description of a body of text.'
30 | API_VERSION = '1.0.0'
31 |
32 | # default model
33 | MODEL_NAME = 'get_to_the_point'
34 | ASSET_DIR = pathlib.Path('./assets').absolute()
35 | DEFAULT_MODEL_PATH = os.path.join(ASSET_DIR, MODEL_NAME)
36 | DEFAULT_VOCAB_PATH = os.path.join(ASSET_DIR, 'vocab')
37 |
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/MAX-Text-Summarizer/d775ac8ad2bdc5c7054685ba2e4cdf96ffaf3aae/core/__init__.py
--------------------------------------------------------------------------------
/core/getpoint/attention_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file defines the decoder"""
18 |
19 | import tensorflow as tf
20 | from tensorflow.python.ops import variable_scope
21 | from tensorflow.python.ops import array_ops
22 | from tensorflow.python.ops import nn_ops
23 | from tensorflow.python.ops import math_ops
24 |
25 | # Note: this function is based on tf.contrib.legacy_seq2seq.attention_decoder, which is now outdated.
26 | # In the future, it would make more sense to write variants on the attention mechanism using the new seq2seq library for tensorflow 1.0: https://www.tensorflow.org/api_guides/python/contrib.seq2seq#Attention
27 | def attention_decoder(decoder_inputs, initial_state, encoder_states, enc_padding_mask, cell, initial_state_attention=False, pointer_gen=True, use_coverage=False, prev_coverage=None):
28 | """
29 | Args:
30 | decoder_inputs: A list of 2D Tensors [batch_size x input_size].
31 | initial_state: 2D Tensor [batch_size x cell.state_size].
32 | encoder_states: 3D Tensor [batch_size x attn_length x attn_size].
33 | enc_padding_mask: 2D Tensor [batch_size x attn_length] containing 1s and 0s; indicates which of the encoder locations are padding (0) or a real token (1).
34 | cell: rnn_cell.RNNCell defining the cell function and size.
35 | initial_state_attention:
36 | Note that this attention decoder passes each decoder input through a linear layer with the previous step's context vector to get a modified version of the input. If initial_state_attention is False, on the first decoder step the "previous context vector" is just a zero vector. If initial_state_attention is True, we use initial_state to (re)calculate the previous step's context vector. We set this to False for train/eval mode (because we call attention_decoder once for all decoder steps) and True for decode mode (because we call attention_decoder once for each decoder step).
37 | pointer_gen: boolean. If True, calculate the generation probability p_gen for each decoder step.
38 | use_coverage: boolean. If True, use coverage mechanism.
39 | prev_coverage:
40 | If not None, a tensor with shape (batch_size, attn_length). The previous step's coverage vector. This is only not None in decode mode when using coverage.
41 |
42 | Returns:
43 | outputs: A list of the same length as decoder_inputs of 2D Tensors of
44 | shape [batch_size x cell.output_size]. The output vectors.
45 | state: The final state of the decoder. A tensor shape [batch_size x cell.state_size].
46 | attn_dists: A list containing tensors of shape (batch_size,attn_length).
47 | The attention distributions for each decoder step.
48 | p_gens: List of scalars. The values of p_gen for each decoder step. Empty list if pointer_gen=False.
49 | coverage: Coverage vector on the last step computed. None if use_coverage=False.
50 | """
51 | with variable_scope.variable_scope("attention_decoder") as scope:
52 | batch_size = encoder_states.get_shape()[0].value # if this line fails, it's because the batch size isn't defined
53 | attn_size = encoder_states.get_shape()[2].value # if this line fails, it's because the attention length isn't defined
54 |
55 | # Reshape encoder_states (need to insert a dim)
56 | encoder_states = tf.expand_dims(encoder_states, axis=2) # now is shape (batch_size, attn_len, 1, attn_size)
57 |
58 | # To calculate attention, we calculate
59 | # v^T tanh(W_h h_i + W_s s_t + b_attn)
60 | # where h_i is an encoder state, and s_t a decoder state.
61 | # attn_vec_size is the length of the vectors v, b_attn, (W_h h_i) and (W_s s_t).
62 | # We set it to be equal to the size of the encoder states.
63 | attention_vec_size = attn_size
64 |
65 | # Get the weight matrix W_h and apply it to each encoder state to get (W_h h_i), the encoder features
66 | W_h = variable_scope.get_variable("W_h", [1, 1, attn_size, attention_vec_size])
67 | encoder_features = nn_ops.conv2d(encoder_states, W_h, [1, 1, 1, 1], "SAME") # shape (batch_size,attn_length,1,attention_vec_size)
68 |
69 | # Get the weight vectors v and w_c (w_c is for coverage)
70 | v = variable_scope.get_variable("v", [attention_vec_size])
71 | if use_coverage:
72 | with variable_scope.variable_scope("coverage"):
73 | w_c = variable_scope.get_variable("w_c", [1, 1, 1, attention_vec_size])
74 |
75 | if prev_coverage is not None: # for beam search mode with coverage
76 | # reshape from (batch_size, attn_length) to (batch_size, attn_len, 1, 1)
77 | prev_coverage = tf.expand_dims(tf.expand_dims(prev_coverage,2),3)
78 |
79 | def attention(decoder_state, coverage=None):
80 | """Calculate the context vector and attention distribution from the decoder state.
81 |
82 | Args:
83 | decoder_state: state of the decoder
84 | coverage: Optional. Previous timestep's coverage vector, shape (batch_size, attn_len, 1, 1).
85 |
86 | Returns:
87 | context_vector: weighted sum of encoder_states
88 | attn_dist: attention distribution
89 | coverage: new coverage vector. shape (batch_size, attn_len, 1, 1)
90 | """
91 | with variable_scope.variable_scope("Attention"):
92 | # Pass the decoder state through a linear layer (this is W_s s_t + b_attn in the paper)
93 | decoder_features = linear(decoder_state, attention_vec_size, True) # shape (batch_size, attention_vec_size)
94 | decoder_features = tf.expand_dims(tf.expand_dims(decoder_features, 1), 1) # reshape to (batch_size, 1, 1, attention_vec_size)
95 |
96 | def masked_attention(e):
97 | """Take softmax of e then apply enc_padding_mask and re-normalize"""
98 | attn_dist = nn_ops.softmax(e) # take softmax. shape (batch_size, attn_length)
99 | attn_dist *= enc_padding_mask # apply mask
100 | masked_sums = tf.reduce_sum(attn_dist, axis=1) # shape (batch_size)
101 | return attn_dist / tf.reshape(masked_sums, [-1, 1]) # re-normalize
102 |
103 | if use_coverage and coverage is not None: # non-first step of coverage
104 | # Multiply coverage vector by w_c to get coverage_features.
105 | coverage_features = nn_ops.conv2d(coverage, w_c, [1, 1, 1, 1], "SAME") # c has shape (batch_size, attn_length, 1, attention_vec_size)
106 |
107 | # Calculate v^T tanh(W_h h_i + W_s s_t + w_c c_i^t + b_attn)
108 | e = math_ops.reduce_sum(v * math_ops.tanh(encoder_features + decoder_features + coverage_features), [2, 3]) # shape (batch_size,attn_length)
109 |
110 | # Calculate attention distribution
111 | attn_dist = masked_attention(e)
112 |
113 | # Update coverage vector
114 | coverage += array_ops.reshape(attn_dist, [batch_size, -1, 1, 1])
115 | else:
116 | # Calculate v^T tanh(W_h h_i + W_s s_t + b_attn)
117 | e = math_ops.reduce_sum(v * math_ops.tanh(encoder_features + decoder_features), [2, 3]) # calculate e
118 |
119 | # Calculate attention distribution
120 | attn_dist = masked_attention(e)
121 |
122 | if use_coverage: # first step of training
123 | coverage = tf.expand_dims(tf.expand_dims(attn_dist,2),2) # initialize coverage
124 |
125 | # Calculate the context vector from attn_dist and encoder_states
126 | context_vector = math_ops.reduce_sum(array_ops.reshape(attn_dist, [batch_size, -1, 1, 1]) * encoder_states, [1, 2]) # shape (batch_size, attn_size).
127 | context_vector = array_ops.reshape(context_vector, [-1, attn_size])
128 |
129 | return context_vector, attn_dist, coverage
130 |
131 | outputs = []
132 | attn_dists = []
133 | p_gens = []
134 | state = initial_state
135 | coverage = prev_coverage # initialize coverage to None or whatever was passed in
136 | context_vector = array_ops.zeros([batch_size, attn_size])
137 | context_vector.set_shape([None, attn_size]) # Ensure the second shape of attention vectors is set.
138 | if initial_state_attention: # true in decode mode
139 | # Re-calculate the context vector from the previous step so that we can pass it through a linear layer with this step's input to get a modified version of the input
140 | context_vector, _, coverage = attention(initial_state, coverage) # in decode mode, this is what updates the coverage vector
141 | for i, inp in enumerate(decoder_inputs):
142 | tf.logging.info("Adding attention_decoder timestep %i of %i", i, len(decoder_inputs))
143 | if i > 0:
144 | variable_scope.get_variable_scope().reuse_variables()
145 |
146 | # Merge input and previous attentions into one vector x of the same size as inp
147 | input_size = inp.get_shape().with_rank(2)[1]
148 | if input_size.value is None:
149 | raise ValueError("Could not infer input size from input: %s" % inp.name)
150 | x = linear([inp] + [context_vector], input_size, True)
151 |
152 | # Run the decoder RNN cell. cell_output = decoder state
153 | cell_output, state = cell(x, state)
154 |
155 | # Run the attention mechanism.
156 | if i == 0 and initial_state_attention: # always true in decode mode
157 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), reuse=True): # you need this because you've already run the initial attention(...) call
158 | context_vector, attn_dist, _ = attention(state, coverage) # don't allow coverage to update
159 | else:
160 | context_vector, attn_dist, coverage = attention(state, coverage)
161 | attn_dists.append(attn_dist)
162 |
163 | # Calculate p_gen
164 | if pointer_gen:
165 | with tf.variable_scope('calculate_pgen'):
166 | p_gen = linear([context_vector, state.c, state.h, x], 1, True) # a scalar
167 | p_gen = tf.sigmoid(p_gen)
168 | p_gens.append(p_gen)
169 |
170 | # Concatenate the cell_output (= decoder state) and the context vector, and pass them through a linear layer
171 | # This is V[s_t, h*_t] + b in the paper
172 | with variable_scope.variable_scope("AttnOutputProjection"):
173 | output = linear([cell_output] + [context_vector], cell.output_size, True)
174 | outputs.append(output)
175 |
176 | # If using coverage, reshape it
177 | if coverage is not None:
178 | coverage = array_ops.reshape(coverage, [batch_size, -1])
179 |
180 | return outputs, state, attn_dists, p_gens, coverage
181 |
182 |
183 |
184 | def linear(args, output_size, bias, bias_start=0.0, scope=None):
185 | """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
186 |
187 | Args:
188 | args: a 2D Tensor or a list of 2D, batch x n, Tensors.
189 | output_size: int, second dimension of W[i].
190 | bias: boolean, whether to add a bias term or not.
191 | bias_start: starting value to initialize the bias; 0 by default.
192 | scope: VariableScope for the created subgraph; defaults to "Linear".
193 |
194 | Returns:
195 | A 2D Tensor with shape [batch x output_size] equal to
196 | sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
197 |
198 | Raises:
199 | ValueError: if some of the arguments has unspecified or wrong shape.
200 | """
201 | if args is None or (isinstance(args, (list, tuple)) and not args):
202 | raise ValueError("`args` must be specified")
203 | if not isinstance(args, (list, tuple)):
204 | args = [args]
205 |
206 | # Calculate the total size of arguments on dimension 1.
207 | total_arg_size = 0
208 | shapes = [a.get_shape().as_list() for a in args]
209 | for shape in shapes:
210 | if len(shape) != 2:
211 | raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes))
212 | if not shape[1]:
213 | raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes))
214 | else:
215 | total_arg_size += shape[1]
216 |
217 | # Now the computation.
218 | with tf.variable_scope(scope or "Linear"):
219 | matrix = tf.get_variable("Matrix", [total_arg_size, output_size])
220 | if len(args) == 1:
221 | res = tf.matmul(args[0], matrix)
222 | else:
223 | res = tf.matmul(tf.concat(axis=1, values=args), matrix)
224 | if not bias:
225 | return res
226 | bias_term = tf.get_variable(
227 | "Bias", [output_size], initializer=tf.constant_initializer(bias_start))
228 | return res + bias_term
229 |
--------------------------------------------------------------------------------
/core/getpoint/batcher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file contains code to process data into batches"""
18 |
19 | import queue as Queue
20 | from random import shuffle
21 | from threading import Thread
22 | import time
23 | import numpy as np
24 | import tensorflow as tf
25 | import data
26 |
27 |
28 | class Example(object):
29 | """Class representing a train/val/test example for text summarization."""
30 |
31 | def __init__(self, article, abstract_sentences, vocab, hps):
32 | """Initializes the Example, performing tokenization and truncation to produce the encoder, decoder and target sequences, which are stored in self.
33 |
34 | Args:
35 | article: source text; a string. each token is separated by a single space.
36 | abstract_sentences: list of strings, one per abstract sentence. In each sentence, each token is separated by a single space.
37 | vocab: Vocabulary object
38 | hps: hyperparameters
39 | """
40 | self.hps = hps
41 |
42 | # Get ids of special tokens
43 | start_decoding = vocab.word2id(data.START_DECODING)
44 | stop_decoding = vocab.word2id(data.STOP_DECODING)
45 |
46 | # Process the article
47 | article_words = article.split()
48 | if len(article_words) > hps.max_enc_steps:
49 | article_words = article_words[:hps.max_enc_steps]
50 | self.enc_len = len(article_words) # store the length after truncation but before padding
51 | self.enc_input = [vocab.word2id(w) for w in
52 | article_words] # list of word ids; OOVs are represented by the id for UNK token
53 |
54 | # Process the abstract
55 | abstract = ' '.join(abstract_sentences) # string
56 | abstract_words = abstract.split() # list of strings
57 | abs_ids = [vocab.word2id(w) for w in
58 | abstract_words] # list of word ids; OOVs are represented by the id for UNK token
59 |
60 | # Get the decoder input sequence and target sequence
61 | self.dec_input, self.target = self.get_dec_inp_targ_seqs(abs_ids, hps.max_dec_steps, start_decoding,
62 | stop_decoding)
63 | self.dec_len = len(self.dec_input)
64 |
65 | # If using pointer-generator mode, we need to store some extra info
66 | if hps.pointer_gen:
67 | # Store a version of the enc_input where in-article OOVs are represented by their temporary OOV id; also store the in-article OOVs words themselves
68 | self.enc_input_extend_vocab, self.article_oovs = data.article2ids(article_words, vocab)
69 |
70 |             # Get a version of the reference summary where in-article OOVs are represented by their temporary article OOV id
71 | abs_ids_extend_vocab = data.abstract2ids(abstract_words, vocab, self.article_oovs)
72 |
73 | # Overwrite decoder target sequence so it uses the temp article OOV ids
74 | _, self.target = self.get_dec_inp_targ_seqs(abs_ids_extend_vocab, hps.max_dec_steps, start_decoding,
75 | stop_decoding)
76 |
77 | # Store the original strings
78 | self.original_article = article
79 | self.original_abstract = abstract
80 | self.original_abstract_sents = abstract_sentences
81 |
82 | def get_dec_inp_targ_seqs(self, sequence, max_len, start_id, stop_id):
83 | """Given the reference summary as a sequence of tokens, return the input sequence for the decoder, and the target sequence which we will use to calculate loss. The sequence will be truncated if it is longer than max_len. The input sequence must start with the start_id and the target sequence must end with the stop_id (but not if it's been truncated).
84 |
85 | Args:
86 | sequence: List of ids (integers)
87 | max_len: integer
88 | start_id: integer
89 | stop_id: integer
90 |
91 | Returns:
92 | inp: sequence length <=max_len starting with start_id
93 | target: sequence same length as input, ending with stop_id only if there was no truncation
94 | """
95 | inp = [start_id] + sequence[:]
96 | target = sequence[:]
97 | if len(inp) > max_len: # truncate
98 | inp = inp[:max_len]
99 | target = target[:max_len] # no end_token
100 | else: # no truncation
101 | target.append(stop_id) # end token
102 | if len(inp) != len(target):
103 | raise ValueError("len(inp) != len(target)")
104 | return inp, target
105 |
106 | def pad_decoder_inp_targ(self, max_len, pad_id):
107 | """Pad decoder input and target sequences with pad_id up to max_len."""
108 | while len(self.dec_input) < max_len:
109 | self.dec_input.append(pad_id)
110 | while len(self.target) < max_len:
111 | self.target.append(pad_id)
112 |
113 | def pad_encoder_input(self, max_len, pad_id):
114 | """Pad the encoder input sequence with pad_id up to max_len."""
115 | while len(self.enc_input) < max_len:
116 | self.enc_input.append(pad_id)
117 | if self.hps.pointer_gen:
118 | while len(self.enc_input_extend_vocab) < max_len:
119 | self.enc_input_extend_vocab.append(pad_id)
120 |
121 |
122 | class Batch(object):
123 | """Class representing a minibatch of train/val/test examples for text summarization."""
124 |
125 | def __init__(self, example_list, hps, vocab):
126 | """Turns the example_list into a Batch object.
127 |
128 | Args:
129 | example_list: List of Example objects
130 | hps: hyperparameters
131 | vocab: Vocabulary object
132 | """
133 | self.pad_id = vocab.word2id(data.PAD_TOKEN) # id of the PAD token used to pad sequences
134 | self.init_encoder_seq(example_list, hps) # initialize the input to the encoder
135 | self.init_decoder_seq(example_list, hps) # initialize the input and targets for the decoder
136 | self.store_orig_strings(example_list) # store the original strings
137 |
138 | def init_encoder_seq(self, example_list, hps):
139 | """Initializes the following:
140 | self.enc_batch:
141 | numpy array of shape (batch_size, <=max_enc_steps) containing integer ids (all OOVs represented by UNK id), padded to length of longest sequence in the batch
142 | self.enc_lens:
143 | numpy array of shape (batch_size) containing integers. The (truncated) length of each encoder input sequence (pre-padding).
144 | self.enc_padding_mask:
145 | numpy array of shape (batch_size, <=max_enc_steps), containing 1s and 0s. 1s correspond to real tokens in enc_batch and target_batch; 0s correspond to padding.
146 |
147 | If hps.pointer_gen, additionally initializes the following:
148 | self.max_art_oovs:
149 | maximum number of in-article OOVs in the batch
150 | self.art_oovs:
151 | list of list of in-article OOVs (strings), for each example in the batch
152 | self.enc_batch_extend_vocab:
153 | Same as self.enc_batch, but in-article OOVs are represented by their temporary article OOV number.
154 | """
155 | # Determine the maximum length of the encoder input sequence in this batch
156 | max_enc_seq_len = max([ex.enc_len for ex in example_list])
157 |
158 | # Pad the encoder input sequences up to the length of the longest sequence
159 | for ex in example_list:
160 | ex.pad_encoder_input(max_enc_seq_len, self.pad_id)
161 |
162 | # Initialize the numpy arrays
163 | # Note: our enc_batch can have different length (second dimension) for each batch because we use dynamic_rnn for the encoder.
164 | self.enc_batch = np.zeros((hps.batch_size, max_enc_seq_len), dtype=np.int32)
165 | self.enc_lens = np.zeros((hps.batch_size), dtype=np.int32)
166 | self.enc_padding_mask = np.zeros((hps.batch_size, max_enc_seq_len), dtype=np.float32)
167 |
168 | # Fill in the numpy arrays
169 | for i, ex in enumerate(example_list):
170 | self.enc_batch[i, :] = ex.enc_input[:]
171 | self.enc_lens[i] = ex.enc_len
172 | for j in range(ex.enc_len):
173 | self.enc_padding_mask[i][j] = 1
174 |
175 | # For pointer-generator mode, need to store some extra info
176 | if hps.pointer_gen:
177 | # Determine the max number of in-article OOVs in this batch
178 | self.max_art_oovs = max([len(ex.article_oovs) for ex in example_list])
179 | # Store the in-article OOVs themselves
180 | self.art_oovs = [ex.article_oovs for ex in example_list]
181 | # Store the version of the enc_batch that uses the article OOV ids
182 | self.enc_batch_extend_vocab = np.zeros((hps.batch_size, max_enc_seq_len), dtype=np.int32)
183 | for i, ex in enumerate(example_list):
184 | self.enc_batch_extend_vocab[i, :] = ex.enc_input_extend_vocab[:]
185 |
186 | def init_decoder_seq(self, example_list, hps):
187 | """Initializes the following:
188 | self.dec_batch:
189 | numpy array of shape (batch_size, max_dec_steps), containing integer ids as input for the decoder, padded to max_dec_steps length.
190 | self.target_batch:
191 | numpy array of shape (batch_size, max_dec_steps), containing integer ids for the target sequence, padded to max_dec_steps length.
192 | self.dec_padding_mask:
193 | numpy array of shape (batch_size, max_dec_steps), containing 1s and 0s. 1s correspond to real tokens in dec_batch and target_batch; 0s correspond to padding.
194 | """
195 | # Pad the inputs and targets
196 | for ex in example_list:
197 | ex.pad_decoder_inp_targ(hps.max_dec_steps, self.pad_id)
198 |
199 | # Initialize the numpy arrays.
200 | # Note: our decoder inputs and targets must be the same length for each batch (second dimension = max_dec_steps) because we do not use a dynamic_rnn for decoding. However I believe this is possible, or will soon be possible, with Tensorflow 1.0, in which case it may be best to upgrade to that.
201 | self.dec_batch = np.zeros((hps.batch_size, hps.max_dec_steps), dtype=np.int32)
202 | self.target_batch = np.zeros((hps.batch_size, hps.max_dec_steps), dtype=np.int32)
203 | self.dec_padding_mask = np.zeros((hps.batch_size, hps.max_dec_steps), dtype=np.float32)
204 |
205 | # Fill in the numpy arrays
206 | for i, ex in enumerate(example_list):
207 | self.dec_batch[i, :] = ex.dec_input[:]
208 | self.target_batch[i, :] = ex.target[:]
209 | for j in range(ex.dec_len):
210 | self.dec_padding_mask[i][j] = 1
211 |
212 | def store_orig_strings(self, example_list):
213 | """Store the original article and abstract strings in the Batch object"""
214 |         self.original_articles = [ex.original_article for ex in example_list] # list of strings
215 |         self.original_abstracts = [ex.original_abstract for ex in example_list] # list of strings
216 |         self.original_abstracts_sents = [ex.original_abstract_sents for ex in example_list] # list of lists of strings
217 |
218 |
219 | class Batcher(object):
220 | """A class to generate minibatches of data. Buckets examples together based on length of the encoder sequence."""
221 |
222 | BATCH_QUEUE_MAX = 100 # max number of batches the batch_queue can hold
223 |
224 | def __init__(self, data_path, vocab, hps, single_pass):
225 | """Initialize the batcher. Start threads that process the data into batches.
226 |
227 | Args:
228 | data_path: tf.Example filepattern.
229 | vocab: Vocabulary object
230 | hps: hyperparameters
231 | single_pass: If True, run through the dataset exactly once (useful for when you want to run evaluation on the dev or test set). Otherwise generate random batches indefinitely (useful for training).
232 | """
233 | self._data_path = data_path
234 | self._vocab = vocab
235 | self._hps = hps
236 | self._single_pass = single_pass
237 |
238 | # Initialize a queue of Batches waiting to be used, and a queue of Examples waiting to be batched
239 | self._batch_queue = Queue.Queue(self.BATCH_QUEUE_MAX)
240 | self._example_queue = Queue.Queue(self.BATCH_QUEUE_MAX * self._hps.batch_size)
241 |
242 | # Different settings depending on whether we're in single_pass mode or not
243 | if single_pass:
244 | self._num_example_q_threads = 1 # just one thread, so we read through the dataset just once
245 | self._num_batch_q_threads = 1 # just one thread to batch examples
246 | self._bucketing_cache_size = 1 # only load one batch's worth of examples before bucketing; this essentially means no bucketing
247 | self._finished_reading = False # this will tell us when we're finished reading the dataset
248 | else:
249 | self._num_example_q_threads = 16 # num threads to fill example queue
250 | self._num_batch_q_threads = 4 # num threads to fill batch queue
251 | self._bucketing_cache_size = 100 # how many batches-worth of examples to load into cache before bucketing
252 |
253 | # Start the threads that load the queues
254 | self._example_q_threads = []
255 | for _ in range(self._num_example_q_threads):
256 | self._example_q_threads.append(Thread(target=self.fill_example_queue))
257 | self._example_q_threads[-1].daemon = True
258 | self._example_q_threads[-1].start()
259 | self._batch_q_threads = []
260 | for _ in range(self._num_batch_q_threads):
261 | self._batch_q_threads.append(Thread(target=self.fill_batch_queue))
262 | self._batch_q_threads[-1].daemon = True
263 | self._batch_q_threads[-1].start()
264 |
265 | # Start a thread that watches the other threads and restarts them if they're dead
266 | if not single_pass: # We don't want a watcher in single_pass mode because the threads shouldn't run forever
267 | self._watch_thread = Thread(target=self.watch_threads)
268 | self._watch_thread.daemon = True
269 | self._watch_thread.start()
270 |
271 | def next_batch(self):
272 | """Return a Batch from the batch queue.
273 |
274 | If mode='decode' then each batch contains a single example repeated beam_size-many times; this is necessary for beam search.
275 |
276 | Returns:
277 | batch: a Batch object, or None if we're in single_pass mode and we've exhausted the dataset.
278 | """
279 | # If the batch queue is empty, print a warning
280 | if self._batch_queue.qsize() == 0:
281 | tf.logging.warning(
282 | 'Bucket input queue is empty when calling next_batch. Bucket queue size: %i, Input queue size: %i',
283 | self._batch_queue.qsize(), self._example_queue.qsize())
284 | if self._single_pass and self._finished_reading:
285 | tf.logging.info("Finished reading dataset in single_pass mode.")
286 | return None
287 |
288 | batch = self._batch_queue.get() # get the next Batch
289 | return batch
290 |
291 | def fill_example_queue(self):
292 | """Reads data from file and processes into Examples which are then placed into the example queue."""
293 |
294 | input_gen = self.text_generator(data.example_generator(self._data_path, self._single_pass))
295 |
296 | while True:
297 | try:
298 | # (article, abstract) = next(input_gen) # read the next example from file. article and abstract are both strings.
299 | article = next(input_gen) # read the next example from file; article is a string (no abstract in this setup).
300 | except StopIteration: # if there are no more examples:
301 | tf.logging.info("The example generator for this example queue filling thread has exhausted data.")
302 | if self._single_pass:
303 | tf.logging.info(
304 | "single_pass mode is on, so we've finished reading dataset. This thread is stopping.")
305 | self._finished_reading = True
306 | break
307 | else:
308 | raise Exception("single_pass mode is off but the example generator is out of data; error.")
309 |
310 | # abstract_sentences = [sent.strip() for sent in data.abstract2sents(abstract)] # Use the and tags in abstract to get a list of sentences.
311 | example = Example(article, article, self._vocab, self._hps) # Process into an Example; the article also stands in for the unused abstract.
312 | self._example_queue.put(example) # place the Example in the example queue.
313 |
314 | def fill_batch_queue(self):
315 | """Takes Examples out of example queue, sorts them by encoder sequence length, processes into Batches and places them in the batch queue.
316 |
317 | In decode mode, makes batches that each contain a single example repeated.
318 | """
319 | while True:
320 | if self._hps.mode != 'decode':
321 | # Get bucketing_cache_size-many batches of Examples into a list, then sort
322 | inputs = []
323 | for _ in range(self._hps.batch_size * self._bucketing_cache_size):
324 | inputs.append(self._example_queue.get())
325 | inputs = sorted(inputs, key=lambda inp: inp.enc_len) # sort by length of encoder sequence
326 |
327 | # Group the sorted Examples into batches, optionally shuffle the batches, and place in the batch queue.
328 | batches = []
329 | for i in range(0, len(inputs), self._hps.batch_size):
330 | batches.append(inputs[i:i + self._hps.batch_size])
331 | if not self._single_pass:
332 | shuffle(batches)
333 | for b in batches: # each b is a list of Example objects
334 | self._batch_queue.put(Batch(b, self._hps, self._vocab))
335 |
336 | else: # beam search decode mode
337 | ex = self._example_queue.get()
338 | b = [ex for _ in range(self._hps.batch_size)]
339 | self._batch_queue.put(Batch(b, self._hps, self._vocab))
340 |
341 | def watch_threads(self):
342 | """Watch example queue and batch queue threads and restart if dead."""
343 | while True:
344 | time.sleep(60)
345 | for idx, t in enumerate(self._example_q_threads):
346 | if not t.is_alive(): # if the thread is dead
347 | tf.logging.error('Found example queue thread dead. Restarting.')
348 | new_t = Thread(target=self.fill_example_queue)
349 | self._example_q_threads[idx] = new_t
350 | new_t.daemon = True
351 | new_t.start()
352 | for idx, t in enumerate(self._batch_q_threads):
353 | if not t.is_alive(): # if the thread is dead
354 | tf.logging.error('Found batch queue thread dead. Restarting.')
355 | new_t = Thread(target=self.fill_batch_queue)
356 | self._batch_q_threads[idx] = new_t
357 | new_t.daemon = True
358 | new_t.start()
359 |
360 | def text_generator(self, example_generator):
361 | """Generates article and abstract text from tf.Example.
362 |
363 | Args:
364 | example_generator: a generator of tf.Examples from file. See data.example_generator"""
365 | while True:
366 | e = next(example_generator) # e is a tf.Example
367 | try:
368 | article_text = e.features.feature['article'].bytes_list.value[
369 | 0].decode() # the article text was saved under the key 'article' in the data files
370 | # abstract_text = e.features.feature['abstract'].bytes_list.value[0].decode() # the abstract text was saved under the key 'abstract' in the data files
371 | except ValueError:
372 | tf.logging.error('Failed to get article or abstract from example')
373 | continue
374 | if len(article_text) == 0: # See https://github.com/abisee/pointer-generator/issues/1
375 | tf.logging.warning('Found an example with empty article text. Skipping it.')
376 | else:
377 | # yield (article_text, abstract_text)
378 | yield article_text
379 |
--------------------------------------------------------------------------------
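A minimal, self-contained sketch of the bucketing strategy fill_batch_queue implements: sort a cache of examples by encoder length, chunk into batches, then shuffle the batch order (not the contents) so each batch stays length-homogeneous and padding is minimized. The Example stand-in below is hypothetical, not the class above.

    from collections import namedtuple
    from random import shuffle

    Example = namedtuple('Example', ['enc_len'])  # hypothetical stand-in for batcher.Example

    def bucket(examples, batch_size):
        inputs = sorted(examples, key=lambda ex: ex.enc_len)  # sort by encoder length
        batches = [inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)]
        shuffle(batches)  # randomize batch order; each batch keeps similar lengths
        return batches

    print([[ex.enc_len for ex in b]
           for b in bucket([Example(n) for n in (40, 7, 23, 90, 15, 61, 3, 77)], 4)])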
/core/getpoint/beam_search.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file contains code to run beam search decoding"""
18 |
19 | import tensorflow as tf
20 | import numpy as np
21 | import data
22 |
23 | FLAGS = tf.app.flags.FLAGS
24 |
25 | class Hypothesis(object):
26 | """Class to represent a hypothesis during beam search. Holds all the information needed for the hypothesis."""
27 |
28 | def __init__(self, tokens, log_probs, state, attn_dists, p_gens, coverage):
29 | """Hypothesis constructor.
30 |
31 | Args:
32 | tokens: List of integers. The ids of the tokens that form the summary so far.
33 | log_probs: List, same length as tokens, of floats, giving the log probabilities of the tokens so far.
34 | state: Current state of the decoder, a LSTMStateTuple.
35 | attn_dists: List, same length as tokens, of numpy arrays with shape (attn_length). These are the attention distributions so far.
36 | p_gens: List, same length as tokens, of floats, or None if not using pointer-generator model. The values of the generation probability so far.
37 | coverage: Numpy array of shape (attn_length), or None if not using coverage. The current coverage vector.
38 | """
39 | self.tokens = tokens
40 | self.log_probs = log_probs
41 | self.state = state
42 | self.attn_dists = attn_dists
43 | self.p_gens = p_gens
44 | self.coverage = coverage
45 |
46 | def extend(self, token, log_prob, state, attn_dist, p_gen, coverage):
47 | """Return a NEW hypothesis, extended with the information from the latest step of beam search.
48 |
49 | Args:
50 | token: Integer. Latest token produced by beam search.
51 | log_prob: Float. Log prob of the latest token.
52 | state: Current decoder state, a LSTMStateTuple.
53 | attn_dist: Attention distribution from latest step. Numpy array shape (attn_length).
54 | p_gen: Generation probability on latest step. Float.
55 | coverage: Latest coverage vector. Numpy array shape (attn_length), or None if not using coverage.
56 | Returns:
57 | New Hypothesis for next step.
58 | """
59 | return Hypothesis(tokens = self.tokens + [token],
60 | log_probs = self.log_probs + [log_prob],
61 | state = state,
62 | attn_dists = self.attn_dists + [attn_dist],
63 | p_gens = self.p_gens + [p_gen],
64 | coverage = coverage)
65 |
66 | @property
67 | def latest_token(self):
68 | return self.tokens[-1]
69 |
70 | @property
71 | def log_prob(self):
72 | # the log probability of the hypothesis so far is the sum of the log probabilities of the tokens so far
73 | return sum(self.log_probs)
74 |
75 | @property
76 | def avg_log_prob(self):
77 | # normalize log probability by number of tokens (otherwise longer sequences always have lower probability)
78 | return self.log_prob / len(self.tokens)
79 |
80 |
81 | def run_beam_search(sess, model, vocab, batch):
82 | """Performs beam search decoding on the given example.
83 |
84 | Args:
85 | sess: a tf.Session
86 | model: a seq2seq model
87 | vocab: Vocabulary object
88 | batch: Batch object that is the same example repeated across the batch
89 |
90 | Returns:
91 | best_hyp: Hypothesis object; the best hypothesis found by beam search.
92 | """
93 | # Run the encoder to get the encoder hidden states and decoder initial state
94 | enc_states, dec_in_state = model.run_encoder(sess, batch)
95 | # dec_in_state is a LSTMStateTuple
96 | # enc_states has shape [batch_size, <=max_enc_steps, 2*hidden_dim].
97 |
98 | # Initialize beam_size-many hypotheses
99 | hyps = [Hypothesis(tokens=[vocab.word2id(data.START_DECODING)],
100 | log_probs=[0.0],
101 | state=dec_in_state,
102 | attn_dists=[],
103 | p_gens=[],
104 | coverage=np.zeros([batch.enc_batch.shape[1]]) # zero vector of length attention_length
105 | ) for _ in range(FLAGS.beam_size)]
106 | results = [] # this will contain finished hypotheses (those that have emitted the [STOP] token)
107 |
108 | steps = 0
109 | while steps < FLAGS.max_dec_steps and len(results) < FLAGS.beam_size:
110 | latest_tokens = [h.latest_token for h in hyps] # latest token produced by each hypothesis
111 | latest_tokens = [t if t in range(vocab.size()) else vocab.word2id(data.UNKNOWN_TOKEN) for t in latest_tokens] # change any in-article temporary OOV ids to [UNK] id, so that we can lookup word embeddings
112 | states = [h.state for h in hyps] # list of current decoder states of the hypotheses
113 | prev_coverage = [h.coverage for h in hyps] # list of coverage vectors (or None)
114 |
115 | # Run one step of the decoder to get the new info
116 | (topk_ids, topk_log_probs, new_states, attn_dists, p_gens, new_coverage) = model.decode_onestep(sess=sess,
117 | batch=batch,
118 | latest_tokens=latest_tokens,
119 | enc_states=enc_states,
120 | dec_init_states=states,
121 | prev_coverage=prev_coverage)
122 |
123 | # Extend each hypothesis and collect them all in all_hyps
124 | all_hyps = []
125 | num_orig_hyps = 1 if steps == 0 else len(hyps) # On the first step, we only had one original hypothesis (the initial hypothesis). On subsequent steps, all original hypotheses are distinct.
126 | for i in range(num_orig_hyps):
127 | h, new_state, attn_dist, p_gen, new_coverage_i = hyps[i], new_states[i], attn_dists[i], p_gens[i], new_coverage[i] # take the ith hypothesis and new decoder state info
128 | for j in range(FLAGS.beam_size * 2): # for each of the top 2*beam_size hyps:
129 | # Extend the ith hypothesis with the jth option
130 | new_hyp = h.extend(token=topk_ids[i, j],
131 | log_prob=topk_log_probs[i, j],
132 | state=new_state,
133 | attn_dist=attn_dist,
134 | p_gen=p_gen,
135 | coverage=new_coverage_i)
136 | all_hyps.append(new_hyp)
137 |
138 | # Filter and collect any hypotheses that have produced the end token.
139 | hyps = [] # will contain hypotheses for the next step
140 | for h in sort_hyps(all_hyps): # in order of most likely h
141 | if h.latest_token == vocab.word2id(data.STOP_DECODING): # if stop token is reached...
142 | # If this hypothesis is sufficiently long, put in results. Otherwise discard.
143 | if steps >= FLAGS.min_dec_steps:
144 | results.append(h)
145 | else: # hasn't reached stop token, so continue to extend this hypothesis
146 | hyps.append(h)
147 | if len(hyps) == FLAGS.beam_size or len(results) == FLAGS.beam_size:
148 | # Once we've collected beam_size-many hypotheses for the next step, or beam_size-many complete hypotheses, stop.
149 | break
150 |
151 | steps += 1
152 |
153 | # At this point, either we've got beam_size results, or we've reached maximum decoder steps
154 |
155 | if len(results) == 0: # if we don't have any complete results, add all current hypotheses (incomplete summaries) to results
156 | results = hyps
157 |
158 | # Sort hypotheses by average log probability
159 | hyps_sorted = sort_hyps(results)
160 |
161 | # Return the hypothesis with highest average log prob
162 | return hyps_sorted[0]
163 |
164 | def sort_hyps(hyps):
165 | """Return a list of Hypothesis objects, sorted by descending average log probability"""
166 | return sorted(hyps, key=lambda h: h.avg_log_prob, reverse=True)
167 |
--------------------------------------------------------------------------------
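A standalone illustration (not part of the repo) of why sort_hyps ranks by avg_log_prob rather than the raw sum: every extra token adds a negative log term, so summed log probabilities systematically favor shorter hypotheses, while the per-token average does not.

    import math

    short = [math.log(0.5)] * 2  # 2 tokens at p=0.5 each
    long_ = [math.log(0.7)] * 5  # 5 tokens at p=0.7 each

    print(sum(short) > sum(long_))                            # True: the raw sum prefers the short hypothesis
    print(sum(long_) / len(long_) > sum(short) / len(short))  # True: the average prefers the longer one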
/core/getpoint/convert.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import struct
5 | import tensorflow as tf
6 | from tensorflow.core.example import example_pb2
7 |
8 |
9 | END_TOKENS = frozenset(['.', '!', '?', '...', "'", "`", '"', ")"]) # acceptable ways to end a sentence
10 |
11 | def fix_missing_period(line):
12 | """Adds a period to a line that is missing a period"""
13 | if "@highlight" in line: return line
14 | if line == "": return line
15 | if line[-1] in END_TOKENS: return line
16 | return line + " ."
17 |
18 | def get_art_abs(input_string):
19 | lines = input_string.splitlines()
20 | lines = [line.lower() for line in lines]
21 | lines = [fix_missing_period(line) for line in lines]
22 |
23 | article_lines = []
24 | highlights = []
25 | next_is_highlight = False
26 | for line in lines:
27 | if line == "":
28 | continue # skip empty lines
29 | elif line.startswith("@highlight"):
30 | next_is_highlight = True
31 | elif next_is_highlight:
32 | highlights.append(line)
33 | else:
34 | article_lines.append(line)
35 |
36 | # Join the article lines into a single string; the parsed highlights are not returned since only the article is needed here
37 | article = ' '.join(article_lines)
38 |
39 | return article
40 |
41 | def convert_to_bin(input_string, out_file):
42 |
43 | with open(out_file, 'wb') as writer:
44 | # start to write .bin file
45 | article = get_art_abs(input_string)
46 |
47 | article = tf.compat.as_bytes(article, encoding='utf-8')
48 | # tf.Example write
49 | tf_example = example_pb2.Example()
50 | tf_example.features.feature['article'].bytes_list.value.extend([article])
51 | tf_example_str = tf_example.SerializeToString()
52 | str_len = len(tf_example_str)
53 | writer.write(struct.pack('q', str_len))
54 | writer.write(struct.pack('%ds' % str_len, tf_example_str))
55 |
--------------------------------------------------------------------------------
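A hedged round-trip sketch of the <length><blob> framing that convert_to_bin writes and data.example_generator later reads: an 8-byte length prefix followed by the serialized tf.Example. The /tmp path and article text are illustrative only.

    import struct
    from tensorflow.core.example import example_pb2

    ex = example_pb2.Example()
    ex.features.feature['article'].bytes_list.value.append(b'some article text .')
    blob = ex.SerializeToString()

    with open('/tmp/sample.bin', 'wb') as writer:   # same framing as convert_to_bin
        writer.write(struct.pack('q', len(blob)))
        writer.write(struct.pack('%ds' % len(blob), blob))

    with open('/tmp/sample.bin', 'rb') as reader:   # same parsing as example_generator
        str_len = struct.unpack('q', reader.read(8))[0]
        example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
        restored = example_pb2.Example.FromString(example_str)
        print(restored.features.feature['article'].bytes_list.value[0].decode())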
/core/getpoint/data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file contains code to read the train/eval/test data from file and process it, and read the vocab data from file and process it"""
18 |
19 | import glob
20 | import random
21 | import struct
22 | import csv
23 | from tensorflow.core.example import example_pb2
24 |
25 | # <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids.
26 | SENTENCE_START = '<s>'
27 | SENTENCE_END = '</s>'
28 |
29 | PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence # nosec - B105:hardcoded_password_string
30 | UNKNOWN_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words # nosec - B105:hardcoded_password_string
31 | START_DECODING = '[START]' # This has a vocab id, which is used at the start of every decoder input sequence
32 | STOP_DECODING = '[STOP]' # This has a vocab id, which is used at the end of untruncated target sequences
33 |
34 |
35 | # Note: none of <s>, </s>, [PAD], [UNK], [START], [STOP] should appear in the vocab file.
36 |
37 |
38 | class Vocab(object):
39 | """Vocabulary class for mapping between words and ids (integers)"""
40 |
41 | def __init__(self, vocab_file, max_size):
42 | """Creates a vocab of up to max_size words, reading from the vocab_file. If max_size is 0, reads the entire vocab file.
43 |
44 | Args:
45 | vocab_file: path to the vocab file, which is assumed to contain "<word> <frequency>" on each line, sorted with most frequent word first. This code doesn't actually use the frequencies, though.
46 | max_size: integer. The maximum size of the resulting Vocabulary."""
47 | self._word_to_id = {}
48 | self._id_to_word = {}
49 | self._count = 0 # keeps track of total number of words in the Vocab
50 |
51 | # [UNK], [PAD], [START] and [STOP] get the ids 0,1,2,3.
52 | for w in [UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
53 | self._word_to_id[w] = self._count
54 | self._id_to_word[self._count] = w
55 | self._count += 1
56 |
57 | # Read the vocab file and add words up to max_size
58 | with open(vocab_file, 'r') as vocab_f:
59 | for line in vocab_f:
60 | pieces = line.split()
61 | if len(pieces) != 2:
62 | # print('Warning: incorrectly formatted line in vocabulary file: %s\n' % line)
63 | continue
64 | w = pieces[0]
65 | if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
66 | raise Exception(
67 | '<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w)
68 | if w in self._word_to_id:
69 | raise Exception('Duplicated word in vocabulary file: %s' % w)
70 | self._word_to_id[w] = self._count
71 | self._id_to_word[self._count] = w
72 | self._count += 1
73 | if max_size != 0 and self._count >= max_size:
74 | # print("max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count))
75 | break
76 |
77 | # print("Finished constructing vocabulary of %i total words. Last word added: %s" % (self._count, self._id_to_word[self._count-1]))
78 |
79 | def word2id(self, word):
80 | """Returns the id (integer) of a word (string). Returns [UNK] id if word is OOV."""
81 | if word not in self._word_to_id:
82 | return self._word_to_id[UNKNOWN_TOKEN]
83 | return self._word_to_id[word]
84 |
85 | def id2word(self, word_id):
86 | """Returns the word (string) corresponding to an id (integer)."""
87 | if word_id not in self._id_to_word:
88 | raise ValueError('Id not found in vocab: %d' % word_id)
89 | return self._id_to_word[word_id]
90 |
91 | def size(self):
92 | """Returns the total size of the vocabulary"""
93 | return self._count
94 |
95 | def write_metadata(self, fpath):
96 | """Writes metadata file for Tensorboard word embedding visualizer as described here:
97 | https://www.tensorflow.org/get_started/embedding_viz
98 |
99 | Args:
100 | fpath: place to write the metadata file
101 | """
102 | # print("Writing word embedding metadata file to %s..." % (fpath))
103 | with open(fpath, "w") as f:
104 | fieldnames = ['word']
105 | writer = csv.DictWriter(f, delimiter="\t", fieldnames=fieldnames)
106 | for i in range(self.size()):
107 | writer.writerow({"word": self._id_to_word[i]})
108 |
109 |
110 | def example_generator(data_path, single_pass):
111 | """Generates tf.Examples from data files.
112 |
113 | Binary data format: <length><blob>. <length> represents the byte size
114 | of <blob>. <blob> is serialized tf.Example proto. The tf.Example contains
115 | the tokenized article text and summary.
116 |
117 | Args:
118 | data_path:
119 | Path to tf.Example data files. Can include wildcards, e.g. if you have several training data chunk files train_001.bin, train_002.bin, etc, then pass data_path=train_* to access them all.
120 | single_pass:
121 | Boolean. If True, go through the dataset exactly once, generating examples in the order they appear, then return. Otherwise, generate random examples indefinitely.
122 |
123 | Yields:
124 | Deserialized tf.Example.
125 | """
126 | while True:
127 | filelist = glob.glob(data_path) # get the list of datafiles
128 | if not filelist:
129 | raise ValueError('Error: Empty filelist at %s' % data_path) # check filelist isn't empty
130 | if single_pass:
131 | filelist = sorted(filelist)
132 | else:
133 | random.shuffle(filelist)
134 | for f in filelist:
135 | reader = open(f, 'rb')
136 | while True:
137 | len_bytes = reader.read(8)
138 | if not len_bytes: break # finished reading this file
139 | str_len = struct.unpack('q', len_bytes)[0]
140 | example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
141 | yield example_pb2.Example.FromString(example_str)
142 | if single_pass:
143 | # print("example_generator completed reading all datafiles. No more data.")
144 | break
145 |
146 |
147 | def article2ids(article_words, vocab):
148 | """Map the article words to their ids. Also return a list of OOVs in the article.
149 |
150 | Args:
151 | article_words: list of words (strings)
152 | vocab: Vocabulary object
153 |
154 | Returns:
155 | ids:
156 | A list of word ids (integers); OOVs are represented by their temporary article OOV number. If the vocabulary size is 50k and the article has 3 OOVs, then these temporary OOV numbers will be 50000, 50001, 50002.
157 | oovs:
158 | A list of the OOV words in the article (strings), in the order corresponding to their temporary article OOV numbers."""
159 | ids = []
160 | oovs = []
161 | unk_id = vocab.word2id(UNKNOWN_TOKEN)
162 | for w in article_words:
163 | i = vocab.word2id(w)
164 | if i == unk_id: # If w is OOV
165 | if w not in oovs: # Add to list of OOVs
166 | oovs.append(w)
167 | oov_num = oovs.index(w) # This is 0 for the first article OOV, 1 for the second article OOV...
168 | ids.append(vocab.size() + oov_num) # This is e.g. 50000 for the first article OOV, 50001 for the second...
169 | else:
170 | ids.append(i)
171 | return ids, oovs
172 |
173 |
174 | def abstract2ids(abstract_words, vocab, article_oovs):
175 | """Map the abstract words to their ids. In-article OOVs are mapped to their temporary OOV numbers.
176 |
177 | Args:
178 | abstract_words: list of words (strings)
179 | vocab: Vocabulary object
180 | article_oovs: list of in-article OOV words (strings), in the order corresponding to their temporary article OOV numbers
181 |
182 | Returns:
183 | ids: List of ids (integers). In-article OOV words are mapped to their temporary OOV numbers. Out-of-article OOV words are mapped to the UNK token id."""
184 | ids = []
185 | unk_id = vocab.word2id(UNKNOWN_TOKEN)
186 | for w in abstract_words:
187 | i = vocab.word2id(w)
188 | if i == unk_id: # If w is an OOV word
189 | if w in article_oovs: # If w is an in-article OOV
190 | vocab_idx = vocab.size() + article_oovs.index(w) # Map to its temporary article OOV number
191 | ids.append(vocab_idx)
192 | else: # If w is an out-of-article OOV
193 | ids.append(unk_id) # Map to the UNK token id
194 | else:
195 | ids.append(i)
196 | return ids
197 |
198 |
199 | def outputids2words(id_list, vocab, article_oovs):
200 | """Maps output ids to words, including mapping in-article OOVs from their temporary ids to the original OOV string (applicable in pointer-generator mode).
201 |
202 | Args:
203 | id_list: list of ids (integers)
204 | vocab: Vocabulary object
205 | article_oovs: list of OOV words (strings) in the order corresponding to their temporary article OOV ids (that have been assigned in pointer-generator mode), or None (in baseline mode)
206 |
207 | Returns:
208 | words: list of words (strings)
209 | """
210 | words = []
211 | for i in id_list:
212 | try:
213 | w = vocab.id2word(i) # might be [UNK]
214 | except ValueError as e: # w is OOV
215 | if article_oovs is None:
216 | raise ValueError(
217 | "Error: model produced a word ID that isn't in the vocabulary. This should not happen in"
218 | "baseline (no pointer-generator) mode")
219 | article_oov_idx = i - vocab.size()
220 | try:
221 | w = article_oovs[article_oov_idx]
222 | except IndexError: # i doesn't correspond to an article oov
223 | raise ValueError(
224 | 'Error: model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (
225 | i, article_oov_idx, len(article_oovs)))
226 | words.append(w)
227 | return words
228 |
229 |
230 | def abstract2sents(abstract):
231 | """Splits abstract text from datafile into list of sentences.
232 |
233 | Args:
234 | abstract: string containing <s> and </s> tags for starts and ends of sentences
235 |
236 | Returns:
237 | sents: List of sentence strings (no tags)"""
238 | cur = 0
239 | sents = []
240 | while True:
241 | try:
242 | start_p = abstract.index(SENTENCE_START, cur)
243 | end_p = abstract.index(SENTENCE_END, start_p + 1)
244 | cur = end_p + len(SENTENCE_END)
245 | sents.append(abstract[start_p + len(SENTENCE_START):end_p])
246 | except ValueError as e: # no more sentences
247 | return sents
248 |
249 |
250 | def show_art_oovs(article, vocab):
251 | """Returns the article string, highlighting the OOVs by placing __underscores__ around them"""
252 | unk_token = vocab.word2id(UNKNOWN_TOKEN)
253 | words = article.split(' ')
254 | words = [("__%s__" % w) if vocab.word2id(w) == unk_token else w for w in words]
255 | out_str = ' '.join(words)
256 | return out_str
257 |
258 |
259 | def show_abs_oovs(abstract, vocab, article_oovs):
260 | """Returns the abstract string, highlighting the article OOVs with __underscores__.
261 |
262 | If a list of article_oovs is provided, non-article OOVs are differentiated like !!__this__!!.
263 |
264 | Args:
265 | abstract: string
266 | vocab: Vocabulary object
267 | article_oovs: list of words (strings), or None (in baseline mode)
268 | """
269 | unk_token = vocab.word2id(UNKNOWN_TOKEN)
270 | words = abstract.split(' ')
271 | new_words = []
272 | for w in words:
273 | if vocab.word2id(w) == unk_token: # w is oov
274 | if article_oovs is None: # baseline mode
275 | new_words.append("__%s__" % w)
276 | else: # pointer-generator mode
277 | if w in article_oovs:
278 | new_words.append("__%s__" % w)
279 | else:
280 | new_words.append("!!__%s__!!" % w)
281 | else: # w is in-vocab word
282 | new_words.append(w)
283 | out_str = ' '.join(new_words)
284 | return out_str
285 |
--------------------------------------------------------------------------------
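A small usage sketch of the temporary-OOV numbering that article2ids and outputids2words implement, assuming this module is importable as data (e.g. run from core/getpoint/); the vocab contents and temp-file path are illustrative. With four special tokens plus three words, vocab.size() is 7, so article OOVs receive ids 7, 8, 9 in order of first appearance.

    import os
    import tempfile
    import data

    with tempfile.NamedTemporaryFile('w', suffix='.vocab', delete=False) as f:
        f.write('the 100\ncat 50\nsat 40\n')   # "<word> <frequency>" per line
        vocab_path = f.name

    vocab = data.Vocab(vocab_path, max_size=0)
    ids, oovs = data.article2ids('the zork sat on the yonder'.split(), vocab)
    print(ids)                                     # [4, 7, 6, 8, 4, 9]; 7-9 are temporary OOV ids
    print(data.outputids2words(ids, vocab, oovs))  # temporary ids map back to the original strings
    os.unlink(vocab_path)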
/core/getpoint/decode.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file contains code to run beam search decoding, including running ROUGE evaluation and producing JSON datafiles for the in-browser attention visualizer, which can be found here https://github.com/abisee/attn_vis"""
18 |
19 | import os
20 | import sys
21 | import time
22 | import tensorflow as tf
23 | import beam_search
24 | import data
25 | import json
26 | import pyrouge
27 | import util
28 | import logging
29 | import numpy as np
30 |
31 | FLAGS = tf.app.flags.FLAGS
32 |
33 | SECS_UNTIL_NEW_CKPT = 60 # max number of seconds before loading new checkpoint
34 |
35 |
36 | class BeamSearchDecoder(object):
37 | """Beam search decoder."""
38 |
39 | def __init__(self, model, batcher, vocab):
40 | """Initialize decoder.
41 |
42 | Args:
43 | model: a Seq2SeqAttentionModel object.
44 | batcher: a Batcher object.
45 | vocab: Vocabulary object
46 | """
47 | self._model = model
48 | self._model.build_graph()
49 | self._batcher = batcher
50 | self._vocab = vocab
51 | self._saver = tf.train.Saver() # we use this to load checkpoints for decoding
52 | self._sess = tf.Session(config=util.get_config())
53 |
54 | # Load an initial checkpoint to use for decoding
55 | ckpt_path = util.load_ckpt(self._saver, self._sess)
56 |
57 |
58 | # if FLAGS.single_pass:
59 | # # Make a descriptive decode directory name
60 | # ckpt_name = "ckpt-" + ckpt_path.split('-')[-1] # this is something of the form "ckpt-123456"
61 | # self._decode_dir = os.path.join(FLAGS.log_root, get_decode_dir_name(ckpt_name))
62 | # if os.path.exists(self._decode_dir):
63 | # raise Exception("single_pass decode directory %s should not already exist" % self._decode_dir)
64 | #
65 | # else: # Generic decode dir name
66 | self._decode_dir = os.path.join(FLAGS.log_root, "decode")
67 |
68 | # Make the decode dir if necessary
69 | if not os.path.exists(self._decode_dir): os.mkdir(self._decode_dir)
70 |
71 | # if FLAGS.single_pass:
72 | # # Make the dirs to contain output written in the correct format for pyrouge
73 | # self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
74 | # if not os.path.exists(self._rouge_ref_dir): os.mkdir(self._rouge_ref_dir)
75 | # self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded")
76 | # if not os.path.exists(self._rouge_dec_dir): os.mkdir(self._rouge_dec_dir)
77 |
78 |
79 | def decode(self):
80 | """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
81 | # t0 = time.time()
82 | batch = self._batcher.next_batch() # 1 example repeated across batch
83 |
84 | original_article = batch.original_articles[0] # string
85 | original_abstract = batch.original_abstracts[0] # string
86 |
87 | # input data
88 | article_withunks = data.show_art_oovs(original_article, self._vocab) # string
89 | abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string
90 |
91 | # Run beam search to get best Hypothesis
92 | best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)
93 |
94 | # Extract the output ids from the hypothesis and convert back to words
95 | output_ids = [int(t) for t in best_hyp.tokens[1:]]
96 | decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))
97 |
98 | # Remove the [STOP] token from decoded_words, if necessary
99 | try:
100 | fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
101 | decoded_words = decoded_words[:fst_stop_idx]
102 | except ValueError:
103 | pass # no [STOP] token was produced; keep all decoded words
104 | decoded_output = ' '.join(decoded_words) # single string
105 |
106 | # tf.logging.info('ARTICLE: %s', article)
107 | # tf.logging.info('GENERATED SUMMARY: %s', decoded_output)
108 |
109 | sys.stdout.write(decoded_output)
110 |
--------------------------------------------------------------------------------
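A standalone sketch (mirroring the trimming step in decode() above, not imported from it) of how the output is cut at the stop token: everything from the first [STOP] onward is dropped, and if no stop token appears, e.g. because max_dec_steps was hit first, the words are kept as-is.

    STOP_DECODING = '[STOP]'

    def trim_at_stop(decoded_words):
        try:
            return decoded_words[:decoded_words.index(STOP_DECODING)]  # cut at first [STOP]
        except ValueError:                                             # no [STOP] produced
            return decoded_words

    print(trim_at_stop(['a', 'summary', '[STOP]', 'junk']))  # ['a', 'summary']
    print(trim_at_stop(['a', 'summary']))                    # unchanged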
/core/getpoint/inspect_checkpoint.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple script that checks if a checkpoint is corrupted with any inf/NaN values. Run like this:
3 | python inspect_checkpoint.py model.12345
4 | """
5 |
6 | import tensorflow as tf
7 | import sys
8 | import numpy as np
9 |
10 |
11 | if __name__ == '__main__':
12 | if len(sys.argv) != 2:
13 | raise Exception("Usage: python inspect_checkpoint.py \nNote: Do not include the .data .index or .meta part of the model checkpoint in file_name.")
14 | file_name = sys.argv[1]
15 | reader = tf.train.NewCheckpointReader(file_name)
16 | var_to_shape_map = reader.get_variable_to_shape_map()
17 |
18 | finite = []
19 | all_infnan = []
20 | some_infnan = []
21 |
22 | for key in sorted(var_to_shape_map.keys()):
23 | tensor = reader.get_tensor(key)
24 | if np.all(np.isfinite(tensor)):
25 | finite.append(key)
26 | else:
27 | if not np.any(np.isfinite(tensor)):
28 | all_infnan.append(key)
29 | else:
30 | some_infnan.append(key)
31 |
32 | print("\nFINITE VARIABLES:")
33 | for key in finite: print(key)
34 |
35 | print("\nVARIABLES THAT ARE ALL INF/NAN:")
36 | for key in all_infnan: print(key)
37 |
38 | print("\nVARIABLES THAT CONTAIN SOME FINITE, SOME INF/NAN VALUES:")
39 | for key in some_infnan: print(key)
40 |
41 | if not all_infnan and not some_infnan:
42 | print("CHECK PASSED: checkpoint contains no inf/NaN values")
43 | else:
44 | print("CHECK FAILED: checkpoint contains some inf/NaN values")
45 |
--------------------------------------------------------------------------------
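A toy illustration (assumed values, no checkpoint needed) of the three buckets the script sorts variables into, using the same numpy predicates:

    import numpy as np

    finite_t = np.array([0.1, -2.0])
    infnan_t = np.array([np.inf, np.nan])
    mixed_t = np.array([1.0, np.nan])

    print(np.all(np.isfinite(finite_t)))      # True  -> FINITE VARIABLES
    print(not np.any(np.isfinite(infnan_t)))  # True  -> ALL INF/NAN
    print(np.any(np.isfinite(mixed_t))
          and not np.all(np.isfinite(mixed_t)))  # True -> SOME FINITE, SOME INF/NAN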
/core/getpoint/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file contains code to build and run the tensorflow graph for the sequence-to-sequence model"""
18 |
19 | import os
20 | import time
21 | import numpy as np
22 | import tensorflow as tf
23 | from attention_decoder import attention_decoder
24 | from tensorflow.contrib.tensorboard.plugins import projector
25 |
26 | FLAGS = tf.app.flags.FLAGS
27 |
28 |
29 | class SummarizationModel(object):
30 | """A class to represent a sequence-to-sequence model for text summarization. Supports both baseline mode, pointer-generator mode, and coverage"""
31 |
32 | def __init__(self, hps, vocab):
33 | self._hps = hps
34 | self._vocab = vocab
35 |
36 | def _add_placeholders(self):
37 | """Add placeholders to the graph. These are entry points for any input data."""
38 | hps = self._hps
39 |
40 | # encoder part
41 | self._enc_batch = tf.placeholder(tf.int32, [hps.batch_size, None], name='enc_batch')
42 | self._enc_lens = tf.placeholder(tf.int32, [hps.batch_size], name='enc_lens')
43 | self._enc_padding_mask = tf.placeholder(tf.float32, [hps.batch_size, None], name='enc_padding_mask')
44 | if FLAGS.pointer_gen:
45 | self._enc_batch_extend_vocab = tf.placeholder(tf.int32, [hps.batch_size, None],
46 | name='enc_batch_extend_vocab')
47 | self._max_art_oovs = tf.placeholder(tf.int32, [], name='max_art_oovs')
48 |
49 | # decoder part
50 | self._dec_batch = tf.placeholder(tf.int32, [hps.batch_size, hps.max_dec_steps], name='dec_batch')
51 | self._target_batch = tf.placeholder(tf.int32, [hps.batch_size, hps.max_dec_steps], name='target_batch')
52 | self._dec_padding_mask = tf.placeholder(tf.float32, [hps.batch_size, hps.max_dec_steps],
53 | name='dec_padding_mask')
54 |
55 | if hps.mode == "decode" and hps.coverage:
56 | self.prev_coverage = tf.placeholder(tf.float32, [hps.batch_size, None], name='prev_coverage')
57 |
58 | def _make_feed_dict(self, batch, just_enc=False):
59 | """Make a feed dictionary mapping parts of the batch to the appropriate placeholders.
60 |
61 | Args:
62 | batch: Batch object
63 | just_enc: Boolean. If True, only feed the parts needed for the encoder.
64 | """
65 | feed_dict = {}
66 | feed_dict[self._enc_batch] = batch.enc_batch
67 | feed_dict[self._enc_lens] = batch.enc_lens
68 | feed_dict[self._enc_padding_mask] = batch.enc_padding_mask
69 | if FLAGS.pointer_gen:
70 | feed_dict[self._enc_batch_extend_vocab] = batch.enc_batch_extend_vocab
71 | feed_dict[self._max_art_oovs] = batch.max_art_oovs
72 | if not just_enc:
73 | feed_dict[self._dec_batch] = batch.dec_batch
74 | feed_dict[self._target_batch] = batch.target_batch
75 | feed_dict[self._dec_padding_mask] = batch.dec_padding_mask
76 | return feed_dict
77 |
78 | def _add_encoder(self, encoder_inputs, seq_len):
79 | """Add a single-layer bidirectional LSTM encoder to the graph.
80 |
81 | Args:
82 | encoder_inputs: A tensor of shape [batch_size, <=max_enc_steps, emb_size].
83 | seq_len: Lengths of encoder_inputs (before padding). A tensor of shape [batch_size].
84 |
85 | Returns:
86 | encoder_outputs:
87 | A tensor of shape [batch_size, <=max_enc_steps, 2*hidden_dim]. It's 2*hidden_dim because it's the concatenation of the forwards and backwards states.
88 | fw_state, bw_state:
89 | Each are LSTMStateTuples of shape ([batch_size,hidden_dim],[batch_size,hidden_dim])
90 | """
91 | with tf.variable_scope('encoder'):
92 | cell_fw = tf.contrib.rnn.LSTMCell(self._hps.hidden_dim, initializer=self.rand_unif_init,
93 | state_is_tuple=True)
94 | cell_bw = tf.contrib.rnn.LSTMCell(self._hps.hidden_dim, initializer=self.rand_unif_init,
95 | state_is_tuple=True)
96 | (encoder_outputs, (fw_st, bw_st)) = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, encoder_inputs,
97 | dtype=tf.float32,
98 | sequence_length=seq_len,
99 | swap_memory=True)
100 | encoder_outputs = tf.concat(axis=2, values=encoder_outputs) # concatenate the forwards and backwards states
101 | return encoder_outputs, fw_st, bw_st
102 |
103 | def _reduce_states(self, fw_st, bw_st):
104 | """Add to the graph a linear layer to reduce the encoder's final FW and BW state into a single initial state for the decoder. This is needed because the encoder is bidirectional but the decoder is not.
105 |
106 | Args:
107 | fw_st: LSTMStateTuple with hidden_dim units.
108 | bw_st: LSTMStateTuple with hidden_dim units.
109 |
110 | Returns:
111 | state: LSTMStateTuple with hidden_dim units.
112 | """
113 | hidden_dim = self._hps.hidden_dim
114 | with tf.variable_scope('reduce_final_st'):
115 | # Define weights and biases to reduce the cell and reduce the state
116 | w_reduce_c = tf.get_variable('w_reduce_c', [hidden_dim * 2, hidden_dim], dtype=tf.float32,
117 | initializer=self.trunc_norm_init)
118 | w_reduce_h = tf.get_variable('w_reduce_h', [hidden_dim * 2, hidden_dim], dtype=tf.float32,
119 | initializer=self.trunc_norm_init)
120 | bias_reduce_c = tf.get_variable('bias_reduce_c', [hidden_dim], dtype=tf.float32,
121 | initializer=self.trunc_norm_init)
122 | bias_reduce_h = tf.get_variable('bias_reduce_h', [hidden_dim], dtype=tf.float32,
123 | initializer=self.trunc_norm_init)
124 |
125 | # Apply linear layer
126 | old_c = tf.concat(axis=1, values=[fw_st.c, bw_st.c]) # Concatenation of fw and bw cell
127 | old_h = tf.concat(axis=1, values=[fw_st.h, bw_st.h]) # Concatenation of fw and bw state
128 | new_c = tf.nn.relu(tf.matmul(old_c, w_reduce_c) + bias_reduce_c) # Get new cell from old cell
129 | new_h = tf.nn.relu(tf.matmul(old_h, w_reduce_h) + bias_reduce_h) # Get new state from old state
130 | return tf.contrib.rnn.LSTMStateTuple(new_c, new_h) # Return new cell and state
131 |
132 | def _add_decoder(self, inputs):
133 | """Add attention decoder to the graph. In train or eval mode, you call this once to get output on ALL steps. In decode (beam search) mode, you call this once for EACH decoder step.
134 |
135 | Args:
136 | inputs: inputs to the decoder (word embeddings). A list of tensors shape (batch_size, emb_dim)
137 |
138 | Returns:
139 | outputs: List of tensors; the outputs of the decoder
140 | out_state: The final state of the decoder
141 | attn_dists: A list of tensors; the attention distributions
142 | p_gens: A list of scalar tensors; the generation probabilities
143 | coverage: A tensor, the current coverage vector
144 | """
145 | hps = self._hps
146 | cell = tf.contrib.rnn.LSTMCell(hps.hidden_dim, state_is_tuple=True, initializer=self.rand_unif_init)
147 |
148 | prev_coverage = self.prev_coverage if hps.mode == "decode" and hps.coverage else None # In decode mode, we run attention_decoder one step at a time and so need to pass in the previous step's coverage vector each time
149 |
150 | outputs, out_state, attn_dists, p_gens, coverage = attention_decoder(inputs, self._dec_in_state,
151 | self._enc_states, self._enc_padding_mask,
152 | cell, initial_state_attention=(
153 | hps.mode == "decode"), pointer_gen=hps.pointer_gen, use_coverage=hps.coverage,
154 | prev_coverage=prev_coverage)
155 |
156 | return outputs, out_state, attn_dists, p_gens, coverage
157 |
158 | def _calc_final_dist(self, vocab_dists, attn_dists):
159 | """Calculate the final distribution, for the pointer-generator model
160 |
161 | Args:
162 | vocab_dists: The vocabulary distributions. List length max_dec_steps of (batch_size, vsize) arrays. The words are in the order they appear in the vocabulary file.
163 | attn_dists: The attention distributions. List length max_dec_steps of (batch_size, attn_len) arrays
164 |
165 | Returns:
166 | final_dists: The final distributions. List length max_dec_steps of (batch_size, extended_vsize) arrays.
167 | """
168 | with tf.variable_scope('final_distribution'):
169 | # Multiply vocab dists by p_gen and attention dists by (1-p_gen)
170 | vocab_dists = [p_gen * dist for (p_gen, dist) in zip(self.p_gens, vocab_dists)]
171 | attn_dists = [(1 - p_gen) * dist for (p_gen, dist) in zip(self.p_gens, attn_dists)]
172 |
173 | # Concatenate some zeros to each vocabulary dist, to hold the probabilities for in-article OOV words
174 | extended_vsize = self._vocab.size() + self._max_art_oovs # the maximum (over the batch) size of the extended vocabulary
175 | extra_zeros = tf.zeros((self._hps.batch_size, self._max_art_oovs))
176 | vocab_dists_extended = [tf.concat(axis=1, values=[dist, extra_zeros]) for dist in
177 | vocab_dists] # list length max_dec_steps of shape (batch_size, extended_vsize)
178 |
179 | # Project the values in the attention distributions onto the appropriate entries in the final distributions
180 | # This means that if a_i = 0.1 and the ith encoder word is w, and w has index 500 in the vocabulary, then we add 0.1 onto the 500th entry of the final distribution
181 | # This is done for each decoder timestep.
182 | # This is fiddly; we use tf.scatter_nd to do the projection
183 | batch_nums = tf.range(0, limit=self._hps.batch_size) # shape (batch_size)
184 | batch_nums = tf.expand_dims(batch_nums, 1) # shape (batch_size, 1)
185 | attn_len = tf.shape(self._enc_batch_extend_vocab)[1] # number of states we attend over
186 | batch_nums = tf.tile(batch_nums, [1, attn_len]) # shape (batch_size, attn_len)
187 | indices = tf.stack((batch_nums, self._enc_batch_extend_vocab), axis=2) # shape (batch_size, enc_t, 2)
188 | shape = [self._hps.batch_size, extended_vsize]
189 | attn_dists_projected = [tf.scatter_nd(indices, copy_dist, shape) for copy_dist in
190 | attn_dists] # list length max_dec_steps (batch_size, extended_vsize)
191 |
192 | # Add the vocab distributions and the copy distributions together to get the final distributions
193 | # final_dists is a list length max_dec_steps; each entry is a tensor shape (batch_size, extended_vsize) giving the final distribution for that decoder timestep
194 | # Note that for decoder timesteps and examples corresponding to a [PAD] token, this is junk - ignore.
195 | final_dists = [vocab_dist + copy_dist for (vocab_dist, copy_dist) in
196 | zip(vocab_dists_extended, attn_dists_projected)]
197 |
198 | return final_dists
199 |
200 | def _add_emb_vis(self, embedding_var):
201 | """Do setup so that we can view word embedding visualization in Tensorboard, as described here:
202 | https://www.tensorflow.org/get_started/embedding_viz
203 | Make the vocab metadata file, then make the projector config file pointing to it."""
204 | train_dir = os.path.join(FLAGS.log_root, "train")
205 | vocab_metadata_path = os.path.join(train_dir, "vocab_metadata.tsv")
206 | self._vocab.write_metadata(vocab_metadata_path) # write metadata file
207 | summary_writer = tf.summary.FileWriter(train_dir)
208 | config = projector.ProjectorConfig()
209 | embedding = config.embeddings.add()
210 | embedding.tensor_name = embedding_var.name
211 | embedding.metadata_path = vocab_metadata_path
212 | projector.visualize_embeddings(summary_writer, config)
213 |
214 | def _add_seq2seq(self):
215 | """Add the whole sequence-to-sequence model to the graph."""
216 | hps = self._hps
217 | vsize = self._vocab.size() # size of the vocabulary
218 |
219 | with tf.variable_scope('seq2seq'):
220 | # Some initializers
221 | self.rand_unif_init = tf.random_uniform_initializer(-hps.rand_unif_init_mag, hps.rand_unif_init_mag,
222 | seed=123)
223 | self.trunc_norm_init = tf.truncated_normal_initializer(stddev=hps.trunc_norm_init_std)
224 |
225 | # Add embedding matrix (shared by the encoder and decoder inputs)
226 | with tf.variable_scope('embedding'):
227 | embedding = tf.get_variable('embedding', [vsize, hps.emb_dim], dtype=tf.float32,
228 | initializer=self.trunc_norm_init)
229 | if hps.mode == "train": self._add_emb_vis(embedding) # add to tensorboard
230 | emb_enc_inputs = tf.nn.embedding_lookup(embedding,
231 | self._enc_batch) # tensor with shape (batch_size, max_enc_steps, emb_size)
232 | emb_dec_inputs = [tf.nn.embedding_lookup(embedding, x) for x in tf.unstack(self._dec_batch,
233 | axis=1)] # list length max_dec_steps containing shape (batch_size, emb_size)
234 |
235 | # Add the encoder.
236 | enc_outputs, fw_st, bw_st = self._add_encoder(emb_enc_inputs, self._enc_lens)
237 | self._enc_states = enc_outputs
238 |
239 | # Our encoder is bidirectional and our decoder is unidirectional so we need to reduce the final encoder hidden state to the right size to be the initial decoder hidden state
240 | self._dec_in_state = self._reduce_states(fw_st, bw_st)
241 |
242 | # Add the decoder.
243 | with tf.variable_scope('decoder'):
244 | decoder_outputs, self._dec_out_state, self.attn_dists, self.p_gens, self.coverage = self._add_decoder(
245 | emb_dec_inputs)
246 |
247 | # Add the output projection to obtain the vocabulary distribution
248 | with tf.variable_scope('output_projection'):
249 | w = tf.get_variable('w', [hps.hidden_dim, vsize], dtype=tf.float32, initializer=self.trunc_norm_init)
250 | w_t = tf.transpose(w)
251 | v = tf.get_variable('v', [vsize], dtype=tf.float32, initializer=self.trunc_norm_init)
252 | vocab_scores = [] # vocab_scores is the vocabulary distribution before applying softmax. Each entry on the list corresponds to one decoder step
253 | for i, output in enumerate(decoder_outputs):
254 | if i > 0:
255 | tf.get_variable_scope().reuse_variables()
256 | vocab_scores.append(tf.nn.xw_plus_b(output, w, v)) # apply the linear layer
257 |
258 | vocab_dists = [tf.nn.softmax(s) for s in
259 | vocab_scores] # The vocabulary distributions. List length max_dec_steps of (batch_size, vsize) arrays. The words are in the order they appear in the vocabulary file.
260 |
261 | # For pointer-generator model, calc final distribution from copy distribution and vocabulary distribution
262 | if FLAGS.pointer_gen:
263 | final_dists = self._calc_final_dist(vocab_dists, self.attn_dists)
264 | else: # final distribution is just vocabulary distribution
265 | final_dists = vocab_dists
266 |
267 | if hps.mode in ['train', 'eval']:
268 | # Calculate the loss
269 | with tf.variable_scope('loss'):
270 | if FLAGS.pointer_gen:
271 | # Calculate the loss per step
272 | # This is fiddly; we use tf.gather_nd to pick out the probabilities of the gold target words
273 | loss_per_step = [] # will be list length max_dec_steps containing shape (batch_size)
274 | batch_nums = tf.range(0, limit=hps.batch_size) # shape (batch_size)
275 | for dec_step, dist in enumerate(final_dists):
276 | targets = self._target_batch[:,
277 | dec_step] # The indices of the target words. shape (batch_size)
278 | indices = tf.stack((batch_nums, targets), axis=1) # shape (batch_size, 2)
279 | gold_probs = tf.gather_nd(dist,
280 | indices) # shape (batch_size). prob of correct words on this step
281 | losses = -tf.log(gold_probs)
282 | loss_per_step.append(losses)
283 |
284 | # Apply dec_padding_mask and get loss
285 | self._loss = _mask_and_avg(loss_per_step, self._dec_padding_mask)
286 |
287 | else: # baseline model
288 | self._loss = tf.contrib.seq2seq.sequence_loss(tf.stack(vocab_scores, axis=1),
289 | self._target_batch,
290 | self._dec_padding_mask) # this applies softmax internally
291 |
292 | tf.summary.scalar('loss', self._loss)
293 |
294 | # Calculate coverage loss from the attention distributions
295 | if hps.coverage:
296 | with tf.variable_scope('coverage_loss'):
297 | self._coverage_loss = _coverage_loss(self.attn_dists, self._dec_padding_mask)
298 | tf.summary.scalar('coverage_loss', self._coverage_loss)
299 | self._total_loss = self._loss + hps.cov_loss_wt * self._coverage_loss
300 | tf.summary.scalar('total_loss', self._total_loss)
301 |
302 | if hps.mode == "decode":
303 | # We run decode beam search mode one decoder step at a time
304 | if len(final_dists) != 1: # final_dists is a singleton list containing shape (batch_size, extended_vsize)
305 | raise ValueError("final_dists should be a singleton list containing shape")
306 | final_dists = final_dists[0]
307 | topk_probs, self._topk_ids = tf.nn.top_k(final_dists,
308 | hps.batch_size * 2) # take the k largest probs. note batch_size=beam_size in decode mode
309 | self._topk_log_probs = tf.log(topk_probs)
310 |
311 | def _add_train_op(self):
312 | """Sets self._train_op, the op to run for training."""
313 | # Take gradients of the trainable variables w.r.t. the loss function to minimize
314 | loss_to_minimize = self._total_loss if self._hps.coverage else self._loss
315 | tvars = tf.trainable_variables()
316 | gradients = tf.gradients(loss_to_minimize, tvars, aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE)
317 |
318 | # Clip the gradients
319 | with tf.device("/gpu:0"):
320 | grads, global_norm = tf.clip_by_global_norm(gradients, self._hps.max_grad_norm)
321 |
322 | # Add a summary
323 | tf.summary.scalar('global_norm', global_norm)
324 |
325 | # Apply adagrad optimizer
326 | optimizer = tf.train.AdagradOptimizer(self._hps.lr, initial_accumulator_value=self._hps.adagrad_init_acc)
327 | with tf.device("/gpu:0"):
328 | self._train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step,
329 | name='train_step')
330 |
331 | def build_graph(self):
332 | """Add the placeholders, model, global step, train_op and summaries to the graph"""
333 | tf.logging.info('Building graph...')
334 | t0 = time.time()
335 | self._add_placeholders()
336 | with tf.device("/gpu:0"):
337 | self._add_seq2seq()
338 | self.global_step = tf.Variable(0, name='global_step', trainable=False)
339 | if self._hps.mode == 'train':
340 | self._add_train_op()
341 | self._summaries = tf.summary.merge_all()
342 | t1 = time.time()
343 | tf.logging.info('Time to build graph: %i seconds', t1 - t0)
344 |
345 | def run_train_step(self, sess, batch):
346 | """Runs one training iteration. Returns a dictionary containing train op, summaries, loss, global_step and (optionally) coverage loss."""
347 | feed_dict = self._make_feed_dict(batch)
348 | to_return = {
349 | 'train_op': self._train_op,
350 | 'summaries': self._summaries,
351 | 'loss': self._loss,
352 | 'global_step': self.global_step,
353 | }
354 | if self._hps.coverage:
355 | to_return['coverage_loss'] = self._coverage_loss
356 | return sess.run(to_return, feed_dict)
357 |
358 | def run_eval_step(self, sess, batch):
359 | """Runs one evaluation iteration. Returns a dictionary containing summaries, loss, global_step and (optionally) coverage loss."""
360 | feed_dict = self._make_feed_dict(batch)
361 | to_return = {
362 | 'summaries': self._summaries,
363 | 'loss': self._loss,
364 | 'global_step': self.global_step,
365 | }
366 | if self._hps.coverage:
367 | to_return['coverage_loss'] = self._coverage_loss
368 | return sess.run(to_return, feed_dict)
369 |
370 | def run_encoder(self, sess, batch):
371 | """For beam search decoding. Run the encoder on the batch and return the encoder states and decoder initial state.
372 |
373 | Args:
374 | sess: Tensorflow session.
375 | batch: Batch object that is the same example repeated across the batch (for beam search)
376 |
377 | Returns:
378 | enc_states: The encoder states. A tensor of shape [batch_size, <=max_enc_steps, 2*hidden_dim].
379 | dec_in_state: A LSTMStateTuple of shape ([1,hidden_dim],[1,hidden_dim])
380 | """
381 | feed_dict = self._make_feed_dict(batch, just_enc=True) # feed the batch into the placeholders
382 | (enc_states, dec_in_state, global_step) = sess.run([self._enc_states, self._dec_in_state, self.global_step],
383 | feed_dict) # run the encoder
384 |
385 | # dec_in_state is LSTMStateTuple shape ([batch_size,hidden_dim],[batch_size,hidden_dim])
386 | # Given that the batch is a single example repeated, dec_in_state is identical across the batch so we just take the top row.
387 | dec_in_state = tf.contrib.rnn.LSTMStateTuple(dec_in_state.c[0], dec_in_state.h[0])
388 | return enc_states, dec_in_state
389 |
390 | def decode_onestep(self, sess, batch, latest_tokens, enc_states, dec_init_states, prev_coverage):
391 | """For beam search decoding. Run the decoder for one step.
392 |
393 | Args:
394 | sess: Tensorflow session.
395 | batch: Batch object containing single example repeated across the batch
396 | latest_tokens: Tokens to be fed as input into the decoder for this timestep
397 | enc_states: The encoder states.
398 | dec_init_states: List of beam_size LSTMStateTuples; the decoder states from the previous timestep
399 | prev_coverage: List of np arrays. The coverage vectors from the previous timestep. List of None if not using coverage.
400 |
401 | Returns:
402 | ids: top 2k ids. shape [beam_size, 2*beam_size]
403 | probs: top 2k log probabilities. shape [beam_size, 2*beam_size]
404 | new_states: new states of the decoder. a list length beam_size containing
405 | LSTMStateTuples each of shape ([hidden_dim,],[hidden_dim,])
406 | attn_dists: List length beam_size containing lists length attn_length.
407 | p_gens: Generation probabilities for this step. A list length beam_size. List of None if in baseline mode.
408 | new_coverage: Coverage vectors for this step. A list of arrays. List of None if coverage is not turned on.
409 | """
410 |
411 | beam_size = len(dec_init_states)
412 |
413 | # Turn dec_init_states (a list of LSTMStateTuples) into a single LSTMStateTuple for the batch
414 | cells = [np.expand_dims(state.c, axis=0) for state in dec_init_states]
415 | hiddens = [np.expand_dims(state.h, axis=0) for state in dec_init_states]
416 | new_c = np.concatenate(cells, axis=0) # shape [batch_size,hidden_dim]
417 | new_h = np.concatenate(hiddens, axis=0) # shape [batch_size,hidden_dim]
418 | new_dec_in_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h)
419 |
420 | feed = {
421 | self._enc_states: enc_states,
422 | self._enc_padding_mask: batch.enc_padding_mask,
423 | self._dec_in_state: new_dec_in_state,
424 | self._dec_batch: np.transpose(np.array([latest_tokens])),
425 | }
426 |
427 | to_return = {
428 | "ids": self._topk_ids,
429 | "probs": self._topk_log_probs,
430 | "states": self._dec_out_state,
431 | "attn_dists": self.attn_dists
432 | }
433 |
434 | if FLAGS.pointer_gen:
435 | feed[self._enc_batch_extend_vocab] = batch.enc_batch_extend_vocab
436 | feed[self._max_art_oovs] = batch.max_art_oovs
437 | to_return['p_gens'] = self.p_gens
438 |
439 | if self._hps.coverage:
440 | feed[self.prev_coverage] = np.stack(prev_coverage, axis=0)
441 | to_return['coverage'] = self.coverage
442 |
443 | results = sess.run(to_return, feed_dict=feed) # run the decoder step
444 |
445 | # Convert results['states'] (a single LSTMStateTuple) into a list of LSTMStateTuple -- one for each hypothesis
446 | new_states = [tf.contrib.rnn.LSTMStateTuple(results['states'].c[i, :], results['states'].h[i, :]) for i in
447 | range(beam_size)]
448 |
449 | # Convert singleton list containing a tensor to a list of k arrays
450 | if len(results['attn_dists']) != 1:
451 | raise ValueError("len(results['attn_dists']) != 1")
452 | attn_dists = results['attn_dists'][0].tolist()
453 |
454 | if FLAGS.pointer_gen:
455 | # Convert singleton list containing a tensor to a list of k arrays
456 | if len(results['p_gens']) != 1:
457 | raise ValueError("len(results['p_gens']) != 1")
458 | p_gens = results['p_gens'][0].tolist()
459 | else:
460 | p_gens = [None for _ in range(beam_size)]
461 |
462 | # Convert the coverage tensor to a list length k containing the coverage vector for each hypothesis
463 | if FLAGS.coverage:
464 | new_coverage = results['coverage'].tolist()
465 | if len(new_coverage) != beam_size:
466 | raise ValueError("len(new_coverage) != beam_size")
467 | else:
468 | new_coverage = [None for _ in range(beam_size)]
469 |
470 | return results['ids'], results['probs'], new_states, attn_dists, p_gens, new_coverage
471 |
472 |
473 | def _mask_and_avg(values, padding_mask):
474 | """Applies mask to values then returns overall average (a scalar)
475 |
476 | Args:
477 | values: a list length max_dec_steps containing arrays shape (batch_size).
478 | padding_mask: tensor shape (batch_size, max_dec_steps) containing 1s and 0s.
479 |
480 | Returns:
481 | a scalar
482 | """
483 |
484 | dec_lens = tf.reduce_sum(padding_mask, axis=1) # shape batch_size. float32
485 | values_per_step = [v * padding_mask[:, dec_step] for dec_step, v in enumerate(values)]
486 | values_per_ex = sum(values_per_step) / dec_lens # shape (batch_size); normalized value for each batch member
487 | return tf.reduce_mean(values_per_ex) # overall average
488 |
489 |
490 | def _coverage_loss(attn_dists, padding_mask):
491 | """Calculates the coverage loss from the attention distributions.
492 |
493 | Args:
494 | attn_dists: The attention distributions for each decoder timestep. A list length max_dec_steps containing shape (batch_size, attn_length)
495 | padding_mask: shape (batch_size, max_dec_steps).
496 |
497 | Returns:
498 | coverage_loss: scalar
499 | """
500 | coverage = tf.zeros_like(attn_dists[0]) # shape (batch_size, attn_length). Initial coverage is zero.
501 | covlosses = [] # Coverage loss per decoder timestep. Will be list length max_dec_steps containing shape (batch_size).
502 | for a in attn_dists:
503 | covloss = tf.reduce_sum(tf.minimum(a, coverage), [1]) # calculate the coverage loss for this step
504 | covlosses.append(covloss)
505 | coverage += a # update the coverage vector
506 | coverage_loss = _mask_and_avg(covlosses, padding_mask)
507 | return coverage_loss
508 |
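509 | # A minimal worked example of the coverage-loss arithmetic above (illustrative sketch
510 | # only; the toy shapes and values below are assumptions, not values the model uses):
511 | #
512 | #   import numpy as np
513 | #   attn = [np.array([[0.6, 0.4]]), np.array([[0.5, 0.5]])]  # 2 decoder steps, batch 1
514 | #   mask = np.array([[1.0, 1.0]])                            # no padding
515 | #   coverage, covlosses = np.zeros_like(attn[0]), []
516 | #   for a in attn:
517 | #       covlosses.append(np.minimum(a, coverage).sum(axis=1))  # overlap with past attention
518 | #       coverage += a
519 | #   dec_lens = mask.sum(axis=1)
520 | #   print((sum(c * m for c, m in zip(covlosses, mask.T)) / dec_lens).mean())  # 0.45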
--------------------------------------------------------------------------------
/core/getpoint/run_summarization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This is the top-level file to train, evaluate or test your summarization model"""
18 |
19 | import sys
20 | import time
21 | import os
22 | import tensorflow as tf
23 | import numpy as np
24 | from collections import namedtuple
25 | from data import Vocab
26 | from batcher import Batcher
27 | from model import SummarizationModel
28 | from decode import BeamSearchDecoder
29 | import util
30 | from tensorflow.python import debug as tf_debug
31 |
32 | FLAGS = tf.app.flags.FLAGS
33 |
34 | # Where to find data
35 | tf.app.flags.DEFINE_string('data_path', '',
36 | 'Path expression to tf.Example datafiles. Can include wildcards to access multiple datafiles.')
37 | tf.app.flags.DEFINE_string('ckpt_dir', '', 'Directory that contains the ckpt files.')
38 | tf.app.flags.DEFINE_string('vocab_path', '', 'Path expression to text vocabulary file.')
39 |
40 | # Important settings
41 | tf.app.flags.DEFINE_string('mode', 'train', 'must be one of train/eval/decode')
42 | tf.app.flags.DEFINE_boolean('single_pass', False,
43 | 'For decode mode only. If True, run eval on the full dataset using a fixed checkpoint, i.e. take the current checkpoint, and use it to produce one summary for each example in the dataset, write the summaries to file and then get ROUGE scores for the whole dataset. If False (default), run concurrent decoding, i.e. repeatedly load latest checkpoint, use it to produce summaries for randomly-chosen examples and log the results to screen, indefinitely.')
44 |
45 | # Where to save output
46 | tf.app.flags.DEFINE_string('log_root', '', 'Root directory for all logging.')
47 | tf.app.flags.DEFINE_string('exp_name', '',
48 | 'Name for experiment. Logs will be saved in a directory with this name, under log_root.')
49 |
50 | # Hyperparameters
51 | tf.app.flags.DEFINE_integer('hidden_dim', 256, 'dimension of RNN hidden states')
52 | tf.app.flags.DEFINE_integer('emb_dim', 128, 'dimension of word embeddings')
53 | tf.app.flags.DEFINE_integer('batch_size', 16, 'minibatch size')
54 | tf.app.flags.DEFINE_integer('max_enc_steps', 400, 'max timesteps of encoder (max source text tokens)')
55 | tf.app.flags.DEFINE_integer('max_dec_steps', 100, 'max timesteps of decoder (max summary tokens)')
56 | tf.app.flags.DEFINE_integer('beam_size', 4, 'beam size for beam search decoding.')
57 | tf.app.flags.DEFINE_integer('min_dec_steps', 35,
58 | 'Minimum sequence length of generated summary. Applies only for beam search decoding mode')
59 | tf.app.flags.DEFINE_integer('vocab_size', 50000,
60 | 'Size of vocabulary. These will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, will take all words in the vocabulary file.')
61 | tf.app.flags.DEFINE_float('lr', 0.15, 'learning rate')
62 | tf.app.flags.DEFINE_float('adagrad_init_acc', 0.1, 'initial accumulator value for Adagrad')
63 | tf.app.flags.DEFINE_float('rand_unif_init_mag', 0.02, 'magnitude for lstm cells random uniform initialization')
64 | tf.app.flags.DEFINE_float('trunc_norm_init_std', 1e-4, 'std of trunc norm init, used for initializing everything else')
65 | tf.app.flags.DEFINE_float('max_grad_norm', 2.0, 'for gradient clipping')
66 |
67 | # Pointer-generator or baseline model
68 | tf.app.flags.DEFINE_boolean('pointer_gen', True, 'If True, use pointer-generator model. If False, use baseline model.')
69 |
70 | # Coverage hyperparameters
71 | tf.app.flags.DEFINE_boolean('coverage', False,
72 | 'Use coverage mechanism. Note, the experiments reported in the ACL paper train WITHOUT coverage until converged, and then train for a short phase WITH coverage afterwards. i.e. to reproduce the results in the ACL paper, turn this off for most of training then turn on for a short phase at the end.')
73 | tf.app.flags.DEFINE_float('cov_loss_wt', 1.0,
74 | 'Weight of coverage loss (lambda in the paper). If zero, then no incentive to minimize coverage loss.')
75 |
76 | # Utility flags, for restoring and changing checkpoints
77 | tf.app.flags.DEFINE_boolean('convert_to_coverage_model', False,
78 | 'Convert a non-coverage model to a coverage model. Turn this on and run in train mode. Your current training model will be copied to a new version (same name with _cov_init appended) that will be ready to run with coverage flag turned on, for the coverage training stage.')
79 | tf.app.flags.DEFINE_boolean('restore_best_model', False,
80 | 'Restore the best model in the eval/ dir and save it in the train/ dir, ready to be used for further training. Useful for early stopping, or if your training checkpoint has become corrupted with e.g. NaN values.')
81 |
82 | # Debugging. See https://www.tensorflow.org/programmers_guide/debugger
83 | tf.app.flags.DEFINE_boolean('debug', False, "Run in tensorflow's debug mode (watches for NaN/inf values)")
84 |
85 |
86 | def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99):
87 | """Calculate the running average loss via exponential decay.
88 | This is used to implement early stopping with respect to a smoother loss curve than the raw loss curve.
89 |
90 | Args:
91 | loss: loss on the most recent eval step
92 | running_avg_loss: running_avg_loss so far
93 | summary_writer: FileWriter object to write for tensorboard
94 | step: training iteration step
95 | decay: rate of exponential decay, a float between 0 and 1. Larger is smoother.
96 |
97 | Returns:
98 | running_avg_loss: new running average loss
99 | """
100 | if running_avg_loss == 0: # on the first iteration just take the loss
101 | running_avg_loss = loss
102 | else:
103 | running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
104 | running_avg_loss = min(running_avg_loss, 12) # clip
105 | loss_sum = tf.Summary()
106 | tag_name = 'running_avg_loss/decay=%f' % (decay)
107 | loss_sum.value.add(tag=tag_name, simple_value=running_avg_loss)
108 | summary_writer.add_summary(loss_sum, step)
109 | tf.logging.info('running_avg_loss: %f', running_avg_loss)
110 | return running_avg_loss
111 |
112 |
113 | def restore_best_model():
114 | """Load bestmodel file from eval directory, add variables for adagrad, and save to train directory"""
115 | tf.logging.info("Restoring best model for training...")
116 |
117 | # Initialize all vars in the model
118 | sess = tf.Session(config=util.get_config())
119 | print("Initializing all variables...")
120 | sess.run(tf.global_variables_initializer())
121 |
122 | # Restore the best model from eval dir
123 | saver = tf.train.Saver([v for v in tf.global_variables() if "Adagrad" not in v.name])
124 | print("Restoring all non-adagrad variables from best model in eval dir...")
125 | curr_ckpt = util.load_ckpt(saver, sess, "eval")
126 | print("Restored %s." % curr_ckpt)
127 |
128 | # Save this model to train dir and quit
129 | new_model_name = curr_ckpt.split("/")[-1].replace("bestmodel", "model")
130 | new_fname = os.path.join(FLAGS.log_root, "train", new_model_name)
131 | print("Saving model to %s..." % new_fname)
132 | new_saver = tf.train.Saver() # this saver saves all variables that now exist, including Adagrad variables
133 | new_saver.save(sess, new_fname)
134 | print("Saved.")
135 | exit()
136 |
137 |
138 | def convert_to_coverage_model():
139 | """Load non-coverage checkpoint, add initialized extra variables for coverage, and save as new checkpoint"""
140 | tf.logging.info("converting non-coverage model to coverage model..")
141 |
142 | # initialize an entire coverage model from scratch
143 | sess = tf.Session(config=util.get_config())
144 | print("initializing everything...")
145 | sess.run(tf.global_variables_initializer())
146 |
147 | # load all non-coverage weights from checkpoint
148 | saver = tf.train.Saver([v for v in tf.global_variables() if "coverage" not in v.name and "Adagrad" not in v.name])
149 | print("restoring non-coverage variables...")
150 | curr_ckpt = util.load_ckpt(saver, sess, FLAGS.ckpt_dir)
151 | print("restored.")
152 |
153 | # save this model and quit
154 | new_fname = curr_ckpt + '_cov_init'
155 | print("saving model to %s..." % new_fname)
156 | new_saver = tf.train.Saver() # this one will save all variables that now exist
157 | new_saver.save(sess, new_fname)
158 | print("saved.")
159 | exit()
160 |
161 |
162 | def setup_training(model, batcher):
163 | """Does setup before starting training (run_training)"""
164 | train_dir = os.path.join(FLAGS.log_root, "train")
165 | if not os.path.exists(train_dir):
166 | os.makedirs(train_dir)
167 |
168 | model.build_graph() # build the graph
169 | if FLAGS.convert_to_coverage_model:
170 | if not FLAGS.coverage:
171 | raise ValueError(
172 | "To convert your non-coverage model to a coverage model, run with convert_to_coverage_model=True "
173 | "and coverage=True")
174 | convert_to_coverage_model()
175 | if FLAGS.restore_best_model:
176 | restore_best_model()
177 | saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time
178 |
179 | sv = tf.train.Supervisor(logdir=train_dir,
180 | is_chief=True,
181 | saver=saver,
182 | summary_op=None,
183 | save_summaries_secs=60, # save summaries for tensorboard every 60 secs
184 | save_model_secs=60, # checkpoint every 60 secs
185 | global_step=model.global_step)
186 | summary_writer = sv.summary_writer
187 | tf.logging.info("Preparing or waiting for session...")
188 | sess_context_manager = sv.prepare_or_wait_for_session(config=util.get_config())
189 | tf.logging.info("Created session.")
190 | try:
191 | run_training(model, batcher, sess_context_manager, sv,
192 | summary_writer) # this is an infinite loop until interrupted
193 | except KeyboardInterrupt:
194 | tf.logging.info("Caught keyboard interrupt on worker. Stopping supervisor...")
195 | sv.stop()
196 |
197 |
198 | def run_training(model, batcher, sess_context_manager, sv, summary_writer):
199 | """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
200 | tf.logging.info("starting run_training")
201 | with sess_context_manager as sess:
202 | if FLAGS.debug: # start the tensorflow debugger
203 | sess = tf_debug.LocalCLIDebugWrapperSession(sess)
204 | sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
205 | while True: # repeats until interrupted
206 | batch = batcher.next_batch()
207 |
208 | tf.logging.info('running training step...')
209 | t0 = time.time()
210 | results = model.run_train_step(sess, batch)
211 | t1 = time.time()
212 | tf.logging.info('seconds for training step: %.3f', t1 - t0)
213 |
214 | loss = results['loss']
215 | tf.logging.info('loss: %f', loss) # print the loss to screen
216 |
217 | if loss < 5:  # stop the training loop once the loss falls below this threshold
218 | break
219 |
220 | if not np.isfinite(loss):
221 | raise Exception("Loss is not finite. Stopping.")
222 |
223 | if FLAGS.coverage:
224 | coverage_loss = results['coverage_loss']
225 | tf.logging.info("coverage_loss: %f", coverage_loss) # print the coverage loss to screen
226 |
227 | # get the summaries and iteration number so we can write summaries to tensorboard
228 | summaries = results['summaries'] # we will write these summaries to tensorboard using summary_writer
229 | train_step = results['global_step'] # we need this to update our running average loss
230 |
231 | summary_writer.add_summary(summaries, train_step) # write the summaries
232 | if train_step % 100 == 0: # flush the summary writer every so often
233 | summary_writer.flush()
234 |
235 |
236 | def run_eval(model, batcher, vocab):
237 | """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
238 | model.build_graph() # build the graph
239 | saver = tf.train.Saver(max_to_keep=3) # we will keep 3 best checkpoints at a time
240 | sess = tf.Session(config=util.get_config())
241 | eval_dir = os.path.join(FLAGS.log_root, "eval") # make a subdir of the root dir for eval data
242 | bestmodel_save_path = os.path.join(eval_dir, 'bestmodel') # this is where checkpoints of best models are saved
243 | summary_writer = tf.summary.FileWriter(eval_dir)
244 | running_avg_loss = 0 # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
245 | best_loss = None # will hold the best loss achieved so far
246 |
247 | while True:
248 | _ = util.load_ckpt(saver, sess, FLAGS.ckpt_dir) # load a new checkpoint
249 | batch = batcher.next_batch() # get the next batch
250 |
251 | # run eval on the batch
252 | t0 = time.time()
253 | results = model.run_eval_step(sess, batch)
254 | t1 = time.time()
255 | tf.logging.info('seconds for batch: %.2f', t1 - t0)
256 |
257 | # print the loss and coverage loss to screen
258 | loss = results['loss']
259 | tf.logging.info('loss: %f', loss)
260 | if FLAGS.coverage:
261 | coverage_loss = results['coverage_loss']
262 | tf.logging.info("coverage_loss: %f", coverage_loss)
263 |
264 | # add summaries
265 | summaries = results['summaries']
266 | train_step = results['global_step']
267 | summary_writer.add_summary(summaries, train_step)
268 |
269 | # calculate running avg loss
270 | running_avg_loss = calc_running_avg_loss(np.asscalar(loss), running_avg_loss, summary_writer, train_step)
271 |
272 | # If running_avg_loss is best so far, save this checkpoint (early stopping).
273 | # These checkpoints will appear as bestmodel- in the eval dir
274 | if best_loss is None or running_avg_loss < best_loss:
275 | tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s', running_avg_loss,
276 | bestmodel_save_path)
277 | saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
278 | best_loss = running_avg_loss
279 |
280 | # flush the summary writer every so often
281 | if train_step % 100 == 0:
282 | summary_writer.flush()
283 |
284 |
285 | def main(unused_argv):
286 | if len(unused_argv) != 1:  # raise an error if flags were entered incorrectly
287 | raise Exception("Problem with flags: %s" % unused_argv)
288 |
289 | tf.logging.set_verbosity(tf.logging.INFO) # choose what level of logging you want
290 | tf.logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))
291 |
292 | # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
293 | FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
294 | if not os.path.exists(FLAGS.log_root):
295 | if FLAGS.mode == "train":
296 | os.makedirs(FLAGS.log_root)
297 | else:
298 | raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root))
299 |
300 | vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary
301 |
302 | # If in decode mode, set batch_size = beam_size
303 | # Reason: in decode mode, we decode one example at a time.
304 | # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses.
305 | if FLAGS.mode == 'decode':
306 | FLAGS.batch_size = FLAGS.beam_size
307 |
308 | # If single_pass=True, check we're in decode mode
309 | if FLAGS.single_pass and FLAGS.mode != 'decode':
310 | raise Exception("The single_pass flag should only be True in decode mode")
311 |
312 | # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
313 | hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm',
314 | 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps', 'max_enc_steps', 'coverage', 'cov_loss_wt',
315 | 'pointer_gen']
316 | hps_dict = {}
317 | for key, val in FLAGS.__flags.items(): # for each flag
318 | if key in hparam_list: # if it's in the list
319 | hps_dict[key] = val # add it to the dict
320 | hps = namedtuple("HParams", hps_dict.keys())(**hps_dict)
321 |
322 | # Create a batcher object that will create minibatches of data
323 | # MAX: This line is commented out because we only run decode mode, where a Batcher is created per input line below. Uncomment if training is needed again.
324 | # batcher = Batcher(FLAGS.data_path, vocab, hps, single_pass=FLAGS.single_pass)
325 |
326 | tf.set_random_seed(111) # a seed value for randomness
327 |
328 | if hps.mode == 'train':
329 | print("creating model...")
330 | model = SummarizationModel(hps, vocab)
331 | setup_training(model, batcher)
332 | elif hps.mode == 'eval':
333 | model = SummarizationModel(hps, vocab)
334 | run_eval(model, batcher, vocab)
335 | elif hps.mode == 'decode':
336 | decode_model_hps = hps # This will be the hyperparameters for the decoder model
337 | decode_model_hps = hps._replace(
338 | max_dec_steps=1) # The model is configured with max_dec_steps=1 because we only ever run one step of the decoder at a time (to do beam search). Note that the batcher is initialized with max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries
339 | for line in sys.stdin: # Each line is a data file path
340 | tf.reset_default_graph()
341 | model = SummarizationModel(decode_model_hps, vocab)
342 | # Create a batcher object that will create minibatches of data
343 | batcher = Batcher(line.strip(), vocab, hps, single_pass=True)
344 | decoder = BeamSearchDecoder(model, batcher, vocab)
345 | decoder.decode()
346 | print('', flush=True)
347 | else:
348 | raise ValueError("The 'mode' flag must be one of train/eval/decode")
349 |
350 |
351 | if __name__ == '__main__':
352 | tf.app.run()
353 |
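354 | # A quick worked example of the decay-0.99 running average in calc_running_avg_loss above
355 | # (sketch only; the loss values are assumptions): the smoothed curve barely moves per step,
356 | # which is why early stopping tracks it instead of the noisy raw loss.
357 | #
358 | #   avg = 0.0
359 | #   for loss in (6.0, 5.0, 4.0):
360 | #       avg = loss if avg == 0 else 0.99 * avg + 0.01 * loss
361 | #   print(round(avg, 4))  # 5.9701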
--------------------------------------------------------------------------------
/core/getpoint/util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file contains some utility functions"""
18 |
19 | import tensorflow as tf
20 | import time
21 | import os
22 | FLAGS = tf.app.flags.FLAGS
23 |
24 | def get_config():
25 | """Returns config for tf.session"""
26 | config = tf.ConfigProto(allow_soft_placement=True)
27 | config.gpu_options.allow_growth=True
28 | return config
29 |
30 | def load_ckpt(saver, sess, ckpt_dir=""):
31 | """Load checkpoint from the ckpt_dir (if unspecified, this is train dir) and restore it to saver and sess, waiting 10 secs in the case of failure. Also returns checkpoint name."""
32 | while True:
33 | try:
34 | latest_filename = "checkpoint_best" if ckpt_dir=="eval" else None
35 | ckpt_dir = os.path.join(FLAGS.ckpt_dir, ckpt_dir)
36 | ckpt_state = tf.train.get_checkpoint_state(ckpt_dir, latest_filename=latest_filename)
37 | tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
38 | saver.restore(sess, ckpt_state.model_checkpoint_path)
39 | return ckpt_state.model_checkpoint_path
40 | except:
41 | tf.logging.info("Failed to load checkpoint from %s. Sleeping for %i secs...", ckpt_dir, 10)
42 | time.sleep(10)
43 |
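44 | # Typical usage from the eval/decode loops (sketch; assumes FLAGS.ckpt_dir is set and a
45 | # session plus saver already exist):
46 | #   ckpt_path = load_ckpt(saver, sess)          # latest checkpoint under FLAGS.ckpt_dir
47 | #   best_path = load_ckpt(saver, sess, "eval")  # best model, via latest_filename='checkpoint_best'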
--------------------------------------------------------------------------------
/core/model.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import logging
18 | import os
19 | from pathlib import Path
20 | from subprocess import Popen, PIPE  # nosec - B404:blacklist
21 | from tempfile import NamedTemporaryFile, TemporaryDirectory
22 | from threading import Lock
23 | from flask import abort
24 |
25 | from maxfw.model import MAXModelWrapper
26 | from config import ASSET_DIR, DEFAULT_MODEL_PATH, DEFAULT_VOCAB_PATH, MODEL_NAME
27 |
28 | from .getpoint.convert import convert_to_bin
29 | from .util import process_punctuation
30 |
31 | logger = logging.getLogger()
32 |
33 |
34 | class ModelWrapper(MAXModelWrapper):
35 |
36 | MODEL_META_DATA = {
37 | 'id': 'max-text-summarizer',
38 | 'name': 'MAX Text Summarizer',
39 | 'description': '{} TensorFlow model trained on CNN/Daily Mail Data'.format(MODEL_NAME),
40 | 'type': 'Text Analysis',
41 | 'source': 'https://developer.ibm.com/exchanges/models/all/max-text-summarizer/',
42 | 'license': 'Apache V2'
43 | }
44 |
45 | _tempfile_mutex = Lock()
46 |
47 | def __init__(self, path=DEFAULT_MODEL_PATH):
48 | logger.info('Loading model from: %s...', path)
49 |
50 | self.log_dir = TemporaryDirectory()
51 | self.p_summarize = Popen(['python', 'core/getpoint/run_summarization.py', '--mode=decode', # nosec - B603
52 | '--ckpt_dir={}'.format(ASSET_DIR),
53 | '--vocab_path={}'.format(DEFAULT_VOCAB_PATH),
54 | '--log_root={}'.format(self.log_dir.name)],
55 | stdin=PIPE, stdout=PIPE)
56 |
57 | def __del__(self):
58 | self.p_summarize.stdin.close()
59 | self.log_dir.cleanup()
60 |
61 | def _pre_process(self, x):
62 | return process_punctuation(x)
63 |
64 | def _predict(self, x):
65 | if all(not c.isalpha() for c in x):
66 | abort(400, 'Input text contains no alphabetic characters.')
67 |
68 | with __class__._tempfile_mutex:
69 | # Create a temporary file for inter-process communication. This
70 | # procedure must not be executed by two threads at the same time,
71 | # to avoid file name conflicts.
72 | try:
73 | # Make use of tmpfs on Linux if available.
74 | directory = Path("/dev/shm/max-ts-{}".format(os.getpid())) # nosec - B108:hardcoded_tmp_directory
75 | # The following two lines may also raise IOError
76 | directory.mkdir(parents=True, exist_ok=True)
77 | bin_file = NamedTemporaryFile(
78 | prefix='generated_sample_', suffix='.bin',
79 | dir=directory.absolute())
80 | except IOError as e:
81 | logger.warning('Failed to create temporary file in RAM. '
82 | 'Falling back to a file on disk: %s', e)
83 | directory = Path("./assets/max-ts-{}".format(os.getpid()))
84 | directory.mkdir(parents=True, exist_ok=True)
85 | bin_file = NamedTemporaryFile(
86 | prefix='generated_sample_', suffix='.bin',
87 | dir=directory.absolute())
88 |
89 | with bin_file:
90 |
91 | bin_file_path = bin_file.name
92 |
93 | convert_to_bin(x, bin_file_path)
94 |
95 | try:
96 | self.p_summarize.stdin.write(bin_file_path.encode('utf8'))
97 | self.p_summarize.stdin.write(b'\n')
98 | self.p_summarize.stdin.flush()
99 | # We feed one paragraph at a time, so the summary arrives as a single line.
100 | summary = self.p_summarize.stdout.readline()
101 | except (IOError, BrokenPipeError) as e:
102 | err_msg = 'Failed to communicate with the summarizer.'
103 | logger.error(err_msg + ' %s', e)
104 | abort(400, err_msg)
105 |
106 | summary = summary.decode('utf-8')
107 |
108 | if len(summary) <= len(x):
109 | return summary
110 |
111 | # Truncate the summary so it is no longer than x. Note that x has already had its punctuation processed.
112 | if not summary[len(x)].isspace():  # the cut lands mid-word, so also drop the trailing partial word
113 | return summary[:len(x)].rsplit(maxsplit=1)[0]
114 | else:
115 | return summary[:len(x)]
116 |
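117 | # Sketch of the truncation rule above, with the model factored out (the toy strings are
118 | # assumptions for illustration): the summary is clipped to len(x), and a cut that lands
119 | # mid-word drops the trailing fragment as well.
120 | #
121 | #   x, summary = 'a b c', 'abc def'                # len(x) == 5; summary[5] == 'e'
122 | #   out = (summary[:len(x)] if summary[len(x)].isspace()
123 | #          else summary[:len(x)].rsplit(maxsplit=1)[0])
124 | #   print(out)  # 'abc' -- 'abc d' would have ended mid-word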
--------------------------------------------------------------------------------
/core/util.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import re
18 |
19 |
20 | def process_punctuation(x):
21 | "this model does not like punctuation touching characters."
22 | return re.sub('([.,!?()])', r' \1 ', x) # https://stackoverflow.com/a/3645946/
23 |
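24 | # Example (illustrative): the regex pads each punctuation mark with spaces, e.g.
25 | #   >>> process_punctuation('Why not summarize?')
26 | #   'Why not summarize ? '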
--------------------------------------------------------------------------------
/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "pip"
4 | directory: "/"
5 | schedule:
6 | interval: "daily"
7 |
--------------------------------------------------------------------------------
/docs/deploy-max-to-ibm-cloud-with-kubernetes-button.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/MAX-Text-Summarizer/d775ac8ad2bdc5c7054685ba2e4cdf96ffaf3aae/docs/deploy-max-to-ibm-cloud-with-kubernetes-button.png
--------------------------------------------------------------------------------
/docs/swagger-screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/MAX-Text-Summarizer/d775ac8ad2bdc5c7054685ba2e4cdf96ffaf3aae/docs/swagger-screenshot.png
--------------------------------------------------------------------------------
/max-text-summarizer.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: max-text-summarizer
5 | spec:
6 | selector:
7 | app: max-text-summarizer
8 | ports:
9 | - port: 5000
10 | type: NodePort
11 | ---
12 | apiVersion: apps/v1
13 | kind: Deployment
14 | metadata:
15 | name: max-text-summarizer
16 | labels:
17 | app: max-text-summarizer
18 | spec:
19 | selector:
20 | matchLabels:
21 | app: max-text-summarizer
22 | replicas: 1
23 | template:
24 | metadata:
25 | labels:
26 | app: max-text-summarizer
27 | spec:
28 | containers:
29 | - name: max-text-summarizer
30 | image: quay.io/codait/max-text-summarizer:latest
31 | ports:
32 | - containerPort: 5000
33 |
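34 | # To deploy this Service/Deployment pair (sketch; assumes kubectl is configured
35 | # against your cluster):
36 | #   kubectl apply -f max-text-summarizer.yaml
37 | #   kubectl get svc max-text-summarizer   # shows the NodePort mapped to container port 5000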
--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
1 | pytest==6.2.0
2 | requests==2.25.0
3 | flake8==3.8.4
4 | bandit==1.6.2
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==1.0.0 # Do not upgrade. We use features specifically in this version (e.g., the format of the training data). See https://github.com/becxer/pointer-generator/#about-this-code
2 | pyrouge==0.1.3
3 |
--------------------------------------------------------------------------------
/samples/README.md:
--------------------------------------------------------------------------------
1 | # Sample Details
2 |
3 | ## Test examples
4 |
5 | This directory contains test samples for the model. These test samples are part of the [CNN / Daily Mail](https://github.com/JafferWilson/Process-Data-of-CNN-DailyMail) dataset and are licensed under the [MIT](https://opensource.org/licenses/MIT) license.
6 |
--------------------------------------------------------------------------------
/samples/sample1.json:
--------------------------------------------------------------------------------
1 | {
2 | "text": [
3 | "nick gordon 's stepfather has revealed in a new interview that he fears for the life of bobbi kristina 's troubled fiance . it has been reported that gordon , 25 , has threatened suicide and has been taking xanax since whitney houston 's daughter was found unconscious in a bathtub in late january . on wednesday , access hollywood spoke exclusively to gordon 's stepfather about his son 's state of mind , asking him if he fears for nick 's life . scroll down for video . speaking out : nick gordon 's father -lrb- left and right -rrb- gave an interview about the 25-year-old fiance of bobbi kristina brown , who reportedly had threatened suicide . access hollywood will air the complete interview thursday . ' i fear ... that if something happens to bobbi kristina , like , if she does n't pull it through , then i will fear for my son 's life , ' the stepfather stated . access hollywood will air the complete interview thursday . speaking to dr phil mcgraw gordon , 25 , said : ` i lost the most legendary singer ever and i 'm scared to lose __krissy__ . i want to let all you guys know i did everything possible in the world to protect them . ' gordon 's impassioned assertion was made during a dramatic intervention staged by dr phil last week and due to air today . weeping and wailing and at times incoherent , gordon had believed that he was to be interviewed by dr phil . gordon had told the eminent psychologist that he wanted to tell his side of the story . he has felt ` vilified ' and ` depicted as a monster ' since january 31 when bobby kristina , 22 , was found face down and unresponsive in the __bath-tub__ at the roswell , georgia home he and bobby kristina shared . breakdown : with his mother , michelle , by his side nick gordon struggles to stay coherent . troubling preview : a trailer for the interview of bobbi kristina brown 's boyfriend nick gordon by dr phil mcgraw was released sunday night . swinging from crying to flying into a rage , the disturbing 30 second video shows a very troubled young man in a lot of pain as he is questioned by dr phil . but on learning of gordon 's emotional , mental and physical state dr phil could not ` in good conscience ' go on with the interview . according to gordon 's mother , michelle , who was by her son 's side during the highly emotional intervention , ` nicholas is at a breaking point . he can not bear not being without __krissy__ , by her side . he 's dealing with it by drinking . i 've begged him to stop . ' dr phil told gordon , who admitted to having taken xanax and to have drunk heavily prior to the meeting : ` nick , you 're out of control . you 've threatened suicide . you deserved to get some help because if you do n't , you know you 're going to wind up dead . ' he added : ` you 've got to get yourself cleaned up . you got to man up and straighten up . ' yesterday dailymail.com revealed that , barely 12 hours before the dramatic encounter with dr phil , gordon was so drunk and high he was unable to walk . the disturbing scenes captured on video showed the true extent of a deterioration described by dr phil as ` exponential . ' in the show , due to broadcast wednesday , a weeping and wailing gordon admitted that he has twice tried to kill himself and confessed : ` i 'm so sorry for everything . ' asked if he still intended to kill himself he said : ` if anything happens to __krissi__ i will . ' he said : ` my pain is horrible . my heart hurts . i have panic attacks . 
' in recent weeks gordon has twice overdosed on a mixture of xanax , alcohol and prescription sleeping pills . gordon had agreed to meet with dr phil believing that he was there to be interviewed . according to dr phil : ` he felt like he was being vilified and presented as some sort of monster . ' instead , on learning of gordon 's rapidly deteriorating mental , emotional and physical state , the eminent psychologist staged an intervention and he is now in rehab . dr phil stated : ' i do n't think he has any chance of turning this round . left to his own devices he will be dead within the week . ' gordon 's mother , michelle , was by her son 's side as he __alternated__ between compliant and aggressive -- at one point threatening to attack camera men as they filmed . she described her son as ` at breaking point . ' she said : ` he can not take too much more of not being able to see __kriss__ . he blames himself . he 's torn up and carrying guilt . ` he 's dealing with it by drinking . i 've begged him to stop . i 've tried to help him but he ca n't let go of the guilt . ' leaning towards gordon 's mother , dr phil said : ` you and i have one mission with one possible outcome and that 's for him to agree to go to rehab to deal with his depression , his guilt ... and to get clean and sober . ' he added : ` his life absolutely hangs in the balance . ' questions still rage regarding just what happened in the early hours of january 31st that led to bobbi kristina , 22 , ending up face down in her tub . just two days ago bobbi kristina 's aunt , __leolah__ brown , made a series of explosive facebook posts in which she alleged that the family had ` strong __evidene__ of foul play ' relating to gordon 's role in events . she posted : ` nick gordon is very disrespectful and inconsiderate ! especially to my family . moreover , he has done things to my niece that i never thought he had in him to do ! ' __leolah__ made her claims in response to being invited onto dr phil 's show . tough to watch : the young man sobs and shakes at times through the trailer . in her message she wrote : ` with all due respect , nick gordon is under investigation for the attempted murder of my niece ... . we have strong evidence of foul play . ' marks were found on her face and arms , marks that gordon has explained as the result of cpr which he administered to her for 15 minutes before emergency services arrived . and speaking to dr phil , gordon insisted : ' i did everything possible in the world to protect them -lsb- whitney and bobbi kristina -rsb- . ' railing against the decision by bobbi 's father , bobby brown , to ban him from his fiancée 's bedside at atlanta 's emory hospital , gordon said : ` my name will be the first she calls . ' according to his mother , michelle , gordon can not forgive himself for his ` failure ' to revive bobbi kristina . his guilt is compounded by the eerily similar situation in which he found himself almost exactly three years earlier . ' i hate bobby brown ! ' his tears quickly turn to anger when he brings up his girlfriend 's father bobby brown , who has blocked him from seeing the 22-year-old in the hospital . ` are you drinking ? ' the emotional man 's anger only increases as dr. phil asks about nick 's sobriety . refuses to go : the 25-year-old aspiring rapper storms out when told he needs help . resigned : ultimately , gordon concedes that he needs help and leaves for a rehab facility . 
then , on 11 february 2012 , it was gordon who tried to resuscitate whitney when she was found unresponsive in her __bath-tub__ just hours before she was due to attend a grammy awards party . speaking to dr phil , michelle said : ` nicholas just continually expresses how much he has failed whitney . ` he administered cpr -lsb- to whitney -rsb- and he called me when he was standing over her body . he could n't understand why he could n't revive her . he said , `` mommy why ? i could n't get the air into her lungs . '' gordon has now entered a rehab facility in atlanta . meanwhile his fiancée continues to fight for her life in a medically-induced coma . it is now six weeks since bobbi kristina -- __krissi__ as she was known to friends - was found face down and unresponsive in the __bath-tub__ of the home she shared with gordon in roswell , georgia . unlike whitney 's death that was ruled accidental , police are treating bobbi kristina 's near drowning and injuries allegedly sustained by the 22-year-old as a criminal investigation . while bobbi kristina fights for her life a troubling picture of her relationship with nick , of drug use and domestic turbulence in the weeks leading up to the incident , has emerged . in an interview with the sun a friend of the couple , steven __stepho__ , claimed they were using various drugs daily including heroin , xanax , pot and heroin substitute __roxicodone__ . sounding eerily similar to bobbi 's mother and father 's __drug-dependent__ relationship , the friend said the pair used whatever they could get their hands on . ` bobbi and nick would spend a lot on drugs every day , it just depended on how much money they had . it was n't unusual for them to spend $ 1,000 a day on drugs . ` there were times when it got really bad - they would be completely passed out for hours , just lying there on the bed . there were times when she would be so knocked out she would burn herself with a cigarette and not even notice . she was always covered in cigarette burns . ' their relationship was also very volatile . ` when whitney died nick was left with nothing , so he knew he had to control __krissi__ to get access to the money . she 'd do whatever he told her . ` he was very manipulative and would even use the drugs to control her . they would argue a lot and there were times when he would be violent with her and push her around . ' according to __stepho__ : ` but __krissi__ really loved him because he was there to fill the gap left by her mother . she was not close to her father and did not have anyone else close to her . nick knew this and took advantage of it . ' troubled relationship : a friend of the couple , steven __stepho__ , claimed the pair were using various drugs daily including heroin , xanax , pot and heroin substitute __roxicodone__ ."
4 | ]
5 | }
6 |
--------------------------------------------------------------------------------
/samples/sample2.json:
--------------------------------------------------------------------------------
1 | {
2 | "text": [
3 | "The Korean Air executive who kicked up a fuss over a bag of nuts will resign from her remaining posts with the airline , the company chairman -- who is also her father -- said Friday . The executive , Heather Cho , found herself at the center of a media storm after she ordered that a plane turn back to the gate and that a flight attendant be removed -- all because she was served nuts in a bag instead of on a plate in first class . Although her role put her in charge of in-flight service , she was only a passenger on the flight and was not flying in an official capacity . The incident , which took place last week at New York 's JFK airport , stirred anger among the South Korean public over Cho 's behavior . Cho , whose Korean name is Cho Hyun-ah , resigned Tuesday from the airline 's catering and in-flight sales business , and from its cabin service and hotel business divisions , the company said . But the 40-year-old kept her title as a vice president of the national carrier , according to company spokesman . That 's going to change , her father , Cho Yang-ho , said Friday as he made a public apology for what happened . She will be resigning from the vice president job and positions held in affiliate companies , he said . Asked by reporters how the incident could have happened , the company chairman blamed himself , saying he 'd raised her badly . ` Outburst of anger ' A local English-language newspaper , The Korea Times , said her behavior has deepened public resentment of South Korea 's large family-owned corporations , known as chaebol . `` Through her outburst of anger , she not only caused inconvenience to KAL passengers , but also to those on other flights , '' the newspaper said in an editorial Tuesday . The most annoying type of airline passenger is ... South Korean authorities are now investigating the incident , which occurred on a flight due to take off for Incheon International Airport near Seoul . Cho arrived at the Ministry of Land , Infrastructure and Transport on Friday as part of the investigation , according to local TV coverage . She spoke in such a low voice that it was inaudible from the TV footage . ` An excessive act ' Korean Air apologized for any inconvenience to those on the flight and said there had been no safety issues involved . The plane arrived at its destination 11 minutes behind schedule , according to the South Korean news agency Yonhap . `` Even though it was not an emergency situation , backing up the plane to order an employee to deplane was an excessive act , '' the airline said earlier this week . `` We will re-educate all our employees to make sure service within the plane meets high standards . '' The airline also issued an apology on Heather Cho 's behalf , Yonhap reported , in which she asked for forgiveness and said she would take `` full responsibility '' for the incident . According to her biography on the website of Nanyang Technological University , Heather Cho joined the airline in 1999 and has since been `` actively involved in establishing a new corporate identity for Korean Air . '' She studied at Cornell University and the University of Southern California . @highlight Airline chairman : Heather Cho will be resigning from all posts with the company @highlight Korean Air said previously she had resigned from some roles but was keeping her VP title @highlight The chairman , who is also her father , blames himself , saying he raised her badly @highlight Cho ordered a plane back to the gate after a flight attendant served nuts in a bag"
4 | ]
5 | }
6 |
--------------------------------------------------------------------------------
/samples/sample3.json:
--------------------------------------------------------------------------------
1 | {
2 | "text": [
3 | "Attorney General Eric Holder urged other countries to enact new criminal laws to help prevent possible terrorist attacks from returning Syrian fighters . In a speech in Oslo , Norway , on Tuesday , Holder highlighted the threat from domestic extremists such as a lone terrorist who three years ago bombed government buildings in the Norwegian capital and gunned down people at a youth camp , killing 77 people in all . Holder said the United States , along with Norway and other European countries , face similar danger from `` violent extremists fighting today in Syria , Iraq or other locations '' who `` may seek to commit acts of terror tomorrow in our countries as well . '' U.S. intelligence estimates that nearly 7,000 foreign fighters have traveled to Syria , including dozens from the United States . The issue of Syrian fighters from Western countries is dominating Holder 's trip this week to Europe , which includes a meeting in London of attorneys general from the U.S. , UK , Canada , Australia and New Zealand . Open visa access among European countries and the U.S. means `` the problem of fighters in Syria returning to any of our countries is a problem for all of our countries , '' Holder said in his speech , excerpts of which were provided in advance by the Justice Department . `` The Syrian conflict has turned that region into a cradle of violent extremism , '' he said . `` But the world can not simply sit back and let it become a training ground from which our nationals can return and launch attacks . '' To combat the threat , Holder is calling on countries to pass laws to criminalize the preparatory steps that suspects often take before an attack , and to allow police to conduct undercover investigations . The U.S. has a law that makes it a crime to provide `` material support '' to terrorists , including supplying money or weapons , or helping to plot an attack . Similar laws are now on the books in Norway and France . Holder also cited the FBI 's success in using undercover sting operations , which have drawn controversy in the U.S. but have been successful in prosecuting dozens of suspects who had admitted to plans to commit terrorism . Many countries do n't allow such operations . But Holder said the use of such operations could help countries thwart attacks . He said the U.S. has already used these tools to carry out prosecutions of people who sought to travel to join the fight in Syria . `` These operations are conducted with extraordinary care and precision , ensuring that law enforcement officials are accountable for the steps they take and that suspects are neither entrapped nor denied legal protections , '' Holder said . He also called on countries to share information with Interpol and one another about their nationals who try to travel to Syria to fight and those who return . And he urged countries to come up with counter-radicalization programs that try to reach communities where young people may be exposed to extremists . `` We must seek to stop individuals from becoming radicalized in the first place by putting in place strong programs to counter violent extremism in its earliest stages , '' Holder said . @highlight Eric Holder urges countries to enact new criminal laws that clamp down on suspects @highlight He cites `` violent extremists fighting today in Syria , Iraq or other locations '' @highlight U.S. 
intelligence estimates that nearly 7,000 foreign fighters have traveled to Syria @highlight Holder urges nations to share information about nationals who try to travel to Syria to fight"
4 | ]
5 | }
6 |
--------------------------------------------------------------------------------
/sha512sums.txt:
--------------------------------------------------------------------------------
1 | 3fbf53cc7c32402a88d4b33762dc84e9ed1036a473cbf64d426805f421686e4c97db31cdc8cba4e043f54b5f1610dbac2529d49952277a562759dc748ab2cbf6 assets/checkpoint
2 | 07aaa28da583bdae86c3557f6563653dccc9b579bbeaf3ab7e907b764170579fe1718aeb21e9281dcbf85905ee21e10c7422cabfcdd85ea9a750bb8bfe5bdf9c assets/graph.pbtxt
3 | 732c621da1f2d5fe26c00cbf224a82e99302f55db98c2359bd10c07e14022b609440afbeba089451d4f766f287d35abaced5e0f7673540afff84f510c3d28e88 assets/model-238410.data-00000-of-00001
4 | b88a63d72d4a35dba130cc2b0b7094ab7fae2da058892b68503dab821c5eca6c4ab89c71481a71a7e497e686069396ba57041dd1903a89d29e2f0cd91f9236ff assets/model-238410.index
5 | 9d62308623b4833a519d0b547bbda6387cb6ce9bfaf32f6233d34b79f13788d798d68603eae48f3d56c0042584e917c7378027203aabbb470289771aeeaae548 assets/model-238410.meta
6 | cb5558e40557be772442c4079575ac41c69c991a238fd54c1de44b9bff47823279766aadb5f5107c260627a25129c7628cda84fec6f775fb68285047a599977b assets/projector_config.pbtxt
7 | 5976b96ab0d2a5ea306e30f2e3bbbc866440a0fa0a5e13e040a1c43e223e6c0dc4fb7d7284822759a9c34b651b8fa61087b5ea5add0ebfdd61e46d2107597269 assets/vocab
8 | 4b404c4f7de95547b0d82bc150f705ed61878fa34ac2941a5559e99f59bf2e231d20b71e09a3aadb2586992dce5ff5043db3eff24dec1655eaf4862685dc483e assets/vocab_metadata.tsv
9 |
--------------------------------------------------------------------------------
/tests/test.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import glob
18 | import json
19 | import pytest
20 | import requests
21 |
22 | from core.util import process_punctuation
23 |
24 |
25 | MODEL_PREDICT_ENDPOINT = 'http://localhost:5000/model/predict'
26 |
27 |
28 | def test_swagger():
29 |
30 | model_endpoint = 'http://localhost:5000/swagger.json'
31 |
32 | r = requests.get(url=model_endpoint)
33 | assert r.status_code == 200
34 | assert r.headers['Content-Type'] == 'application/json'
35 |
36 | json = r.json()
37 | assert 'swagger' in json
38 | assert json.get('info').get('title') == 'MAX Text Summarizer'
39 | assert json.get('info').get('description') == 'Generate a summarized description of a body of text.'
40 |
41 |
42 | def test_metadata():
43 |
44 | model_endpoint = 'http://localhost:5000/model/metadata'
45 |
46 | r = requests.get(url=model_endpoint)
47 | assert r.status_code == 200
48 |
49 | metadata = r.json()
50 | assert metadata['id'] == 'max-text-summarizer'
51 | assert metadata['name'] == 'MAX Text Summarizer'
52 | assert metadata['description'] == 'get_to_the_point TensorFlow model trained on CNN/Daily Mail Data'
53 | assert metadata['license'] == 'Apache V2'
54 |
55 |
56 | def test_predict_valid():
57 | "Test prediction for valid input."
58 | short_text = 'Why not summarize?'
59 | with open('samples/sample1.json') as f:
60 | long_text = json.load(f)['text'][0]
61 |
62 | json_data = {
63 | 'text': [short_text, long_text]
64 | }
65 |
66 | r = requests.post(url=MODEL_PREDICT_ENDPOINT, json=json_data)
67 |
68 | assert r.status_code == 200
69 | response = r.json()
70 | assert response['status'] == 'ok'
71 | assert len(response['summary_text']) == 2
72 | for i, text in enumerate(json_data['text']):
73 | assert len(response['summary_text'][i]) <= len(process_punctuation(text))
74 |
75 |
76 | def test_predict_sample():
77 | "Test prediction for sample inputs."
78 | for sample in glob.iglob('samples/*.json'):
79 | with open(sample) as f:
80 | json_data = json.load(f)
81 |
82 | r = requests.post(url=MODEL_PREDICT_ENDPOINT, json=json_data)
83 |
84 | assert r.status_code == 200
85 | response = r.json()
86 | assert response['status'] == 'ok'
87 | assert len(response['summary_text']) == len(json_data['text'])
88 | for i, text in enumerate(json_data['text']):
89 | assert len(response['summary_text'][i]) <= len(process_punctuation(text))
90 |
91 |
92 | def test_predict_invalid_input_no_string():
93 | "Test invalid input: no string."
94 | json_data = {}
95 | r = requests.post(url=MODEL_PREDICT_ENDPOINT, json=json_data)
96 | assert r.status_code == 400
97 |
98 |
99 | def test_predict_invalid_empty_string():
100 | "Test invalid input: empty and blank strings."
101 | for s in ('', ' \t\f\r\n'):
102 | for text in ([s], ['some text', s]):
103 | json_data = {'text': text}
104 | r = requests.post(url=MODEL_PREDICT_ENDPOINT, json=json_data)
105 | assert r.status_code == 400
106 |
107 |
108 | if __name__ == '__main__':
109 | pytest.main([__file__])
110 |
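111 | # Sketch of a local run (assumes the model API container is already serving on
112 | # localhost:5000, per MODEL_PREDICT_ENDPOINT above):
113 | #   pytest tests/test.py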
--------------------------------------------------------------------------------