├── .bandit
├── .dockerignore
├── .gitignore
├── .travis.yml
├── Dockerfile
├── LICENSE
├── README.md
├── api
│   ├── __init__.py
│   ├── metadata.py
│   └── predict.py
├── app.py
├── assets
│   ├── README.md
│   └── word_counts.txt
├── config.py
├── core
│   ├── __init__.py
│   ├── configuration.py
│   ├── inference_utils
│   │   ├── BUILD
│   │   ├── __init__.py
│   │   ├── caption_generator.py
│   │   ├── inference_wrapper_base.py
│   │   └── vocabulary.py
│   ├── inference_wrapper.py
│   ├── model.py
│   ├── ops
│   │   ├── BUILD
│   │   ├── __init__.py
│   │   ├── image_embedding.py
│   │   ├── image_processing.py
│   │   └── inputs.py
│   └── show_and_tell_model.py
├── docs
│   ├── deploy-max-to-ibm-cloud-with-kubernetes-button.png
│   └── swagger-screenshot.png
├── max-image-caption-generator.yaml
├── requirements-test.txt
├── requirements.txt
├── samples
│   ├── README.md
│   ├── plane.jpg
│   ├── soccer.jpg
│   └── surfing.jpg
├── sha512sums.txt
└── tests
    ├── surfing.jpg
    ├── surfing.png
    └── test.py
/.bandit:
--------------------------------------------------------------------------------
1 | [bandit]
2 | exclude: /tests,/training
3 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyCharm
29 | .idea/
30 | *.iml
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # celery beat schedule file
81 | celerybeat-schedule
82 |
83 | # SageMath parsed files
84 | *.sage.py
85 |
86 | # dotenv
87 | .env
88 |
89 | # virtualenv
90 | .venv
91 | venv/
92 | ENV/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
106 | /.pytest_cache/
107 | .gitignore
108 | .git/
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyCharm
29 | .idea/
30 | *.iml
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # celery beat schedule file
81 | celerybeat-schedule
82 |
83 | # SageMath parsed files
84 | *.sage.py
85 |
86 | # dotenv
87 | .env
88 |
89 | # virtualenv
90 | .venv
91 | venv/
92 | ENV/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
106 | /.pytest_cache/
107 |
108 | # Ignore Mac DS_Store files
109 | .DS_Store
110 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | language: python
18 |
19 | python:
20 | - 3.6
21 |
22 | services:
23 | - docker
24 |
25 | install:
26 | - docker build -t max-image-caption-generator .
27 | - docker run -it --rm -d -p 5000:5000 max-image-caption-generator
28 | - pip install -r requirements-test.txt
29 |
30 | before_script:
31 | - flake8 . --max-line-length=127
32 | - bandit -r .
33 | - sleep 30
34 |
35 | script:
36 | - pytest tests/test.py
37 |
38 | ### Temporarily disable downstream tests ###
39 | #jobs:
40 | # include:
41 | # - stage: test
42 | # script:
43 | # - pytest tests/test.py
44 | # - stage: trigger downstream
45 | # jdk: oraclejdk8
46 | # script: |
47 | # echo "TRAVIS_BRANCH=$TRAVIS_BRANCH TRAVIS_PULL_REQUEST=$TRAVIS_PULL_REQUEST"
48 | # if [[ ($TRAVIS_BRANCH == master) &&
49 | # ($TRAVIS_PULL_REQUEST == false) ]] ; then
50 | # curl -LO --retry 3 https://raw.github.com/mernst/plume-lib/master/bin/trigger-travis.sh
51 | # sh trigger-travis.sh IBM MAX-Image-Caption-Generator-Web-App $TRAVIS_ACCESS_TOKEN
52 | # fi
53 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | FROM quay.io/codait/max-base:v1.5.0
18 |
19 | ARG model_bucket=https://max-cdn.cdn.appdomain.cloud/max-image-caption-generator/1.0.0
20 | ARG model_file=assets.tar.gz
21 |
22 | RUN wget -nv --show-progress --progress=bar:force:noscroll ${model_bucket}/${model_file} --output-document=assets/${model_file} && \
23 | tar -x -C assets/ -f assets/${model_file} -v && rm assets/${model_file}
24 |
25 | COPY requirements.txt .
26 | RUN pip install -r requirements.txt
27 |
28 | COPY . .
29 |
30 | # check file integrity
31 | RUN sha512sum -c sha512sums.txt
32 |
33 | EXPOSE 5000
34 |
35 | CMD python app.py
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://travis-ci.org/IBM/MAX-Image-Caption-Generator) [](http://max-image-caption-generator.codait-prod-41208c73af8fca213512856c7a09db52-0000.us-east.containers.appdomain.cloud)
2 |
3 | [![Deploy MAX to IBM Cloud with Kubernetes](docs/deploy-max-to-ibm-cloud-with-kubernetes-button.png)](http://ibm.biz/max-to-ibm-cloud-tutorial)
4 |
5 | # IBM Developer Model Asset Exchange: Image Caption Generator
6 |
7 | This repository contains code to instantiate and deploy an image caption generation model. This model generates captions from a fixed vocabulary that describe the contents of images in the [COCO Dataset](http://cocodataset.org/#home). The model consists of an _encoder_ model - a deep convolutional net using the Inception-v3 architecture trained on [ImageNet-2012 data](http://www.image-net.org/challenges/LSVRC/2012/) - and a _decoder_ model - an LSTM network that is trained conditioned on the encoding from the image _encoder_ model. The input to the model is an image, and the output is a sentence describing the image content.
8 |
9 | The model is based on the [Show and Tell Image Caption Generator Model](https://github.com/tensorflow/models/tree/master/research/im2txt). The checkpoint files are hosted on [IBM Cloud Object Storage](https://max-cdn.cdn.appdomain.cloud/max-image-caption-generator/1.0.0/assets.tar.gz). The code in this repository deploys the model as a web service in a Docker container. This repository was developed as part of the [IBM Code Model Asset Exchange](https://developer.ibm.com/code/exchanges/models/).
10 |
11 | ## Model Metadata
12 | | Domain | Application | Industry | Framework | Training Data | Input Data Format |
13 | | ------------- | -------- | -------- | --------- | --------- | -------------- |
14 | | Vision | Image Caption Generator | General | TensorFlow | [COCO](http://cocodataset.org/#home) | Images |
15 |
16 | ## References
17 | * _O. Vinyals, A. Toshev, S. Bengio, D. Erhan_, ["Show and Tell: Lessons learned from the 2015 MSCOCO Image Captioning Challenge"](https://doi.org/10.1109/TPAMI.2016.2587640), IEEE Transactions on Pattern Analysis and Machine Intelligence, 2017.
18 | * [im2txt TensorFlow Model GitHub Page](https://github.com/tensorflow/models/tree/master/research/im2txt)
19 | * [COCO Dataset Project Page](http://cocodataset.org/#home)
20 |
21 | ## Licenses
22 |
23 | | Component | License | Link |
24 | | ------------- | -------- | -------- |
25 | | This repository | [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) | [LICENSE](LICENSE) |
26 | | Model Weights | [MIT](https://opensource.org/licenses/MIT) | [Pretrained Show and Tell Model](https://github.com/KranthiGV/Pretrained-Show-and-Tell-model) |
27 | | Model Code (3rd party) | [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) | [im2txt](https://github.com/tensorflow/models/tree/master/research/im2txt) |
28 | | Test assets | Various | [Sample README](samples/README.md) |
29 |
30 | ## Prerequisites
31 |
32 | * `docker`: The [Docker](https://www.docker.com/) command-line interface. Follow the [installation instructions](https://docs.docker.com/install/) for your system.
33 | * The minimum recommended resources for this model are 2 GB of memory and 2 CPUs.
34 | * If you are on x86-64/AMD64, your CPU must support [AVX](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) at the minimum.
35 |
36 | # Deployment options
37 |
38 | * [Deploy from Quay](#deploy-from-quay)
39 | * [Deploy on Red Hat OpenShift](#deploy-on-red-hat-openshift)
40 | * [Deploy on Kubernetes](#deploy-on-kubernetes)
41 | * [Run Locally](#run-locally)
42 |
43 | ## Deploy from Quay
44 |
45 | To run the docker image, which automatically starts the model serving API, run:
46 |
47 | ```
48 | $ docker run -it -p 5000:5000 quay.io/codait/max-image-caption-generator
49 | ```
50 |
51 | This will pull a pre-built image from the Quay.io container registry (or use an existing image if already cached locally) and run it.
52 | If you'd rather check out and build the model locally, you can follow the [run locally](#run-locally) steps below.
53 |
54 | ## Deploy on Red Hat OpenShift
55 |
56 | You can deploy the model-serving microservice on Red Hat OpenShift by following the instructions for the OpenShift web console or the OpenShift Container Platform CLI [in this tutorial](https://developer.ibm.com/tutorials/deploy-a-model-asset-exchange-microservice-on-red-hat-openshift/), specifying `quay.io/codait/max-image-caption-generator` as the image name.
57 |
58 | ## Deploy on Kubernetes
59 |
60 | You can also deploy the model on Kubernetes using the latest docker image on Quay.
61 |
62 | On your Kubernetes cluster, run the following commands:
63 |
64 | ```
65 | $ kubectl apply -f https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/master/max-image-caption-generator.yaml
66 | ```
67 |
68 | The model will be available internally at port `5000`, but can also be accessed externally through the `NodePort`.
69 |
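To verify the deployment from outside the cluster, you can send a request to any worker node on the service's `NodePort`. The snippet below is only a sketch in Python using the `requests` library; the node address and port are placeholders, so substitute the values reported for the service in your own cluster (for example via `kubectl get service`).

```python
# Sketch only: NODE_ADDRESS and NODE_PORT are placeholders; use the values
# reported for the max-image-caption-generator service in your own cluster.
import requests

NODE_ADDRESS = '192.0.2.10'   # hypothetical worker-node address
NODE_PORT = 30000             # hypothetical NodePort assigned to the service

# The API serves its Swagger documentation page at the root URL, so a simple
# GET is enough to confirm the service is reachable from outside the cluster.
response = requests.get('http://{}:{}/'.format(NODE_ADDRESS, NODE_PORT))
print(response.status_code)
```
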
70 | A more elaborate tutorial on how to deploy this MAX model to production on [IBM Cloud](https://ibm.biz/Bdz2XM) can be found [here](http://ibm.biz/max-to-ibm-cloud-tutorial).
71 |
72 | ## Run Locally
73 |
74 | 1. [Build the Model](#1-build-the-model)
75 | 2. [Deploy the Model](#2-deploy-the-model)
76 | 3. [Use the Model](#3-use-the-model)
77 | 4. [Development](#4-development)
78 | 5. [Cleanup](#5-cleanup)
79 |
80 | ### 1. Build the Model
81 |
82 | Clone this repository locally. In a terminal, run the following command:
83 |
84 | ```
85 | $ git clone https://github.com/IBM/MAX-Image-Caption-Generator.git
86 | ```
87 |
88 | Change directory into the repository base folder:
89 |
90 | ```
91 | $ cd MAX-Image-Caption-Generator
92 | ```
93 |
94 | To build the docker image locally, run:
95 |
96 | ```
97 | $ docker build -t max-image-caption-generator .
98 | ```
99 |
100 | All required model assets will be downloaded during the build process. _Note_ that currently this docker image is CPU only (we will add support for GPU images later).
101 |
102 | ### 2. Deploy the Model
103 |
104 | To run the docker image, which automatically starts the model serving API, run:
105 |
106 | ```
107 | $ docker run -it -p 5000:5000 max-image-caption-generator
108 | ```
109 |
110 | ### 3. Use the Model
111 |
112 | The API server automatically generates an interactive Swagger documentation page. Go to `http://localhost:5000` to load it. From there you can explore the API and also create test requests.
113 |
114 | Use the `model/predict` endpoint to load a test file and get captions for the image from the API.
115 |
116 | ![Swagger UI Screenshot](docs/swagger-screenshot.png)
117 |
118 | You can also test it on the command line, for example:
119 |
120 | ```
121 | $ curl -F "image=@samples/surfing.jpg" -X POST http://localhost:5000/model/predict
122 | ```
123 |
124 | ```json
125 | {
126 | "status": "ok",
127 | "predictions": [
128 | {
129 | "index": "0",
130 | "caption": "a man riding a wave on top of a surfboard .",
131 | "probability": 0.038827644239537
132 | },
133 | {
134 | "index": "1",
135 | "caption": "a person riding a surf board on a wave",
136 | "probability": 0.017933410519265
137 | },
138 | {
139 | "index": "2",
140 | "caption": "a man riding a wave on a surfboard in the ocean .",
141 | "probability": 0.0056628732021868
142 | }
143 | ]
144 | }
145 | ```
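The prediction endpoint can also be called programmatically. Below is a minimal sketch that uses the Python `requests` library (not included in this repository's requirements) to send one of the sample images to a locally running container:

```python
# Minimal sketch: call the locally running model-serving API with the
# `requests` library (install it separately; it is not in requirements.txt).
import requests

# The /model/predict endpoint expects a multipart form field named "image".
with open('samples/surfing.jpg', 'rb') as f:
    response = requests.post(
        'http://localhost:5000/model/predict',
        files={'image': ('surfing.jpg', f, 'image/jpeg')}
    )

for prediction in response.json()['predictions']:
    print(prediction['probability'], prediction['caption'])
```
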
146 |
147 | ### 4. Development
148 |
149 | To run the Flask API app in debug mode, edit `config.py` to set `DEBUG = True` under the application settings. You will then need to rebuild the docker image (see [step 1](#1-build-the-model)).
150 |
151 | ### 5. Cleanup
152 |
153 | To stop the Docker container, type `CTRL` + `C` in your terminal.
154 |
155 | ## Links
156 |
157 | * [Image Caption Generator Web App](https://developer.ibm.com/patterns/create-a-web-app-to-interact-with-machine-learning-generated-image-captions): A reference application created by the IBM CODAIT team that uses the Image Caption Generator
158 |
159 | ## Resources and Contributions
160 |
161 | If you are interested in contributing to the Model Asset Exchange project or have any queries, please follow the instructions [here](https://github.com/CODAIT/max-central-repo).
162 |
--------------------------------------------------------------------------------
/api/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from .metadata import ModelMetadataAPI # noqa
18 | from .predict import ModelPredictAPI # noqa
19 |
--------------------------------------------------------------------------------
/api/metadata.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from config import MODEL_META_DATA
18 | from maxfw.core import MAX_API, MetadataAPI, METADATA_SCHEMA
19 |
20 |
21 | class ModelMetadataAPI(MetadataAPI):
22 |
23 | @MAX_API.marshal_with(METADATA_SCHEMA)
24 | def get(self):
25 | """Return the metadata associated with the model"""
26 | return MODEL_META_DATA
27 |
--------------------------------------------------------------------------------
/api/predict.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from maxfw.core import MAX_API, PredictAPI
18 | from core.model import ModelWrapper
19 |
20 | from flask import abort
21 | from flask_restx import fields
22 | from werkzeug.datastructures import FileStorage
23 |
24 |
25 | # Set up parser for image input data
26 | image_parser = MAX_API.parser()
27 | image_parser.add_argument('image', type=FileStorage, location='files', required=True, help="An image file")
28 |
29 | label_prediction = MAX_API.model('LabelPrediction', {
30 | 'index': fields.String(required=False, description='Labels ranked by highest probability'),
31 | 'caption': fields.String(required=True, description='Caption generated by image'),
32 | 'probability': fields.Float(required=True, description="Probability of the caption")
33 | })
34 |
35 | predict_response = MAX_API.model('ModelPredictResponse', {
36 | 'status': fields.String(required=True, description='Response status message'),
37 | 'predictions': fields.List(fields.Nested(label_prediction), description='Predicted captions and probabilities')
38 | })
39 |
40 |
41 | class ModelPredictAPI(PredictAPI):
42 |
43 | model_wrapper = ModelWrapper()
44 |
45 | @MAX_API.doc('predict')
46 | @MAX_API.expect(image_parser)
47 | @MAX_API.marshal_with(predict_response)
48 | def post(self):
49 | """Make a prediction given input data"""
50 |
51 | result = {'status': 'error'}
52 | args = image_parser.parse_args()
53 | if not args['image'].mimetype.endswith(('jpg', 'jpeg', 'png')):
54 | abort(400, 'Invalid file type/extension. Please provide an image in JPEG or PNG format.')
55 | image_data = args['image'].read()
56 |
57 | preds = self.model_wrapper.predict(image_data)
58 |
59 | label_preds = [{'index': p[0], 'caption': p[1], 'probability': p[2]} for p in preds]
60 | result['predictions'] = label_preds
61 | result['status'] = 'ok'
62 |
63 | return result
64 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from maxfw.core import MAXApp
18 | from api import ModelMetadataAPI, ModelPredictAPI
19 | from config import API_TITLE, API_DESC, API_VERSION
20 |
21 | max_app = MAXApp(API_TITLE, API_DESC, API_VERSION)
22 | max_app.add_api(ModelMetadataAPI, '/metadata')
23 | max_app.add_api(ModelPredictAPI, '/predict')
24 | max_app.run()
25 |
--------------------------------------------------------------------------------
/assets/README.md:
--------------------------------------------------------------------------------
1 | # Asset Details
2 |
3 | ## Model files
4 |
5 | Model files are from the [Pretrained Show and Tell Model](https://github.com/KranthiGV/Pretrained-Show-and-Tell-model), where they are available under a [MIT License](https://opensource.org/licenses/MIT).
6 |
7 | _Note_ that the model files are hosted on [IBM Cloud Object Storage](https://max-cdn.cdn.appdomain.cloud/max-image-caption-generator/1.0.0/assets.tar.gz).
8 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | # Application settings
18 |
19 | # Flask settings
20 | DEBUG = False
21 |
22 | # Flask-restplus settings
23 | RESTPLUS_MASK_SWAGGER = False
24 | SWAGGER_UI_DOC_EXPANSION = 'none'
25 |
26 | # API metadata
27 | API_TITLE = 'MAX Image Caption Generator'
28 | API_DESC = 'Generate captions that describe the contents of images.'
29 | API_VERSION = '1.1.0'
30 |
31 | # default model
32 | MODEL_NAME = 'im2txt'
33 | DEFAULT_MODEL_PATH = 'assets/checkpoint/model2.ckpt-2000000'
34 | VOCAB_FILE = './assets/word_counts.txt'
35 | # for image models, may not be required
36 | MODEL_INPUT_IMG_SIZE = (299, 299)
37 | MODEL_LICENSE = 'Apache 2.0'
38 |
39 | MODEL_META_DATA = {
40 | 'id': API_TITLE.lower().replace(' ', '-'),
41 | 'name': API_TITLE,
42 | 'description': '{} TensorFlow model trained on MSCOCO'.format(MODEL_NAME),
43 | 'type': 'Image-to-Text Translation',
44 | 'license': MODEL_LICENSE,
45 | 'source': 'https://developer.ibm.com/exchanges/models/all/max-image-caption-generator/'
46 | }
47 |
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
--------------------------------------------------------------------------------
/core/configuration.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Image-to-text model and training configurations."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 |
23 | class ModelConfig(object):
24 | """Wrapper class for model hyperparameters."""
25 |
26 | def __init__(self):
27 | """Sets the default model hyperparameters."""
28 | # File pattern of sharded TFRecord file containing SequenceExample protos.
29 | # Must be provided in training and evaluation modes.
30 | self.input_file_pattern = None
31 |
32 | # Image format ("jpeg" or "png").
33 | self.image_format = "jpeg"
34 |
35 | # Approximate number of values per input shard. Used to ensure sufficient
36 | # mixing between shards in training.
37 | self.values_per_input_shard = 2300
38 | # Minimum number of shards to keep in the input queue.
39 | self.input_queue_capacity_factor = 2
40 | # Number of threads for prefetching SequenceExample protos.
41 | self.num_input_reader_threads = 1
42 |
43 | # Name of the SequenceExample context feature containing image data.
44 | self.image_feature_name = "image/data"
45 | # Name of the SequenceExample feature list containing integer captions.
46 | self.caption_feature_name = "image/caption_ids"
47 |
48 | # Number of unique words in the vocab (plus 1, for <UNK>).
49 | # The default value is larger than the expected actual vocab size to allow
50 | # for differences between tokenizer versions used in preprocessing. There is
51 | # no harm in using a value greater than the actual vocab size, but using a
52 | # value less than the actual vocab size will result in an error.
53 | self.vocab_size = 12000
54 |
55 | # Number of threads for image preprocessing. Should be a multiple of 2.
56 | self.num_preprocess_threads = 4
57 |
58 | # Batch size.
59 | self.batch_size = 32
60 |
61 | # File containing an Inception v3 checkpoint to initialize the variables
62 | # of the Inception model. Must be provided when starting training for the
63 | # first time.
64 | self.inception_checkpoint_file = None
65 |
66 | # Dimensions of Inception v3 input images.
67 | self.image_height = 299
68 | self.image_width = 299
69 |
70 | # Scale used to initialize model variables.
71 | self.initializer_scale = 0.08
72 |
73 | # LSTM input and output dimensionality, respectively.
74 | self.embedding_size = 512
75 | self.num_lstm_units = 512
76 |
77 | # If < 1.0, the dropout keep probability applied to LSTM variables.
78 | self.lstm_dropout_keep_prob = 0.7
79 |
80 |
81 | class TrainingConfig(object):
82 | """Wrapper class for training hyperparameters."""
83 |
84 | def __init__(self):
85 | """Sets the default training hyperparameters."""
86 | # Number of examples per epoch of training data.
87 | self.num_examples_per_epoch = 586363
88 |
89 | # Optimizer for training the model.
90 | self.optimizer = "SGD"
91 |
92 | # Learning rate for the initial phase of training.
93 | self.initial_learning_rate = 2.0
94 | self.learning_rate_decay_factor = 0.5
95 | self.num_epochs_per_decay = 8.0
96 |
97 | # Learning rate when fine tuning the Inception v3 parameters.
98 | self.train_inception_learning_rate = 0.0005
99 |
100 | # If not None, clip gradients to this value.
101 | self.clip_gradients = 5.0
102 |
103 | # How many model checkpoints to keep.
104 | self.max_checkpoints_to_keep = 5
105 |
--------------------------------------------------------------------------------
/core/inference_utils/BUILD:
--------------------------------------------------------------------------------
1 | package(default_visibility = ["//im2txt:internal"])
2 |
3 | licenses(["notice"]) # Apache 2.0
4 |
5 | exports_files(["LICENSE"])
6 |
7 | py_library(
8 | name = "inference_wrapper_base",
9 | srcs = ["inference_wrapper_base.py"],
10 | srcs_version = "PY2AND3",
11 | )
12 |
13 | py_library(
14 | name = "vocabulary",
15 | srcs = ["vocabulary.py"],
16 | srcs_version = "PY2AND3",
17 | )
18 |
19 | py_library(
20 | name = "caption_generator",
21 | srcs = ["caption_generator.py"],
22 | srcs_version = "PY2AND3",
23 | )
24 |
25 | py_test(
26 | name = "caption_generator_test",
27 | srcs = ["caption_generator_test.py"],
28 | deps = [
29 | ":caption_generator",
30 | ],
31 | )
32 |
--------------------------------------------------------------------------------
/core/inference_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/a2286854a6d8bc2d3de8d698f44fdf835e24d110/core/inference_utils/__init__.py
--------------------------------------------------------------------------------
/core/inference_utils/caption_generator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Class for generating captions from an image-to-text model."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import heapq
22 | import math
23 |
24 | import numpy as np
25 |
26 |
27 | class Caption(object):
28 | """Represents a complete or partial caption."""
29 |
30 | def __init__(self, sentence, state, logprob, score, metadata=None):
31 | """Initializes the Caption.
32 |
33 | Args:
34 | sentence: List of word ids in the caption.
35 | state: Model state after generating the previous word.
36 | logprob: Log-probability of the caption.
37 | score: Score of the caption.
38 | metadata: Optional metadata associated with the partial sentence. If not
39 | None, a list of strings with the same length as 'sentence'.
40 | """
41 | self.sentence = sentence
42 | self.state = state
43 | self.logprob = logprob
44 | self.score = score
45 | self.metadata = metadata
46 |
47 | def __cmp__(self, other):
48 | """Compares Captions by score."""
49 | if not isinstance(other, Caption):
50 | raise ValueError("'%s' isn't an instance of Caption" % other)
51 | if self.score == other.score:
52 | return 0
53 | elif self.score < other.score:
54 | return -1
55 | else:
56 | return 1
57 |
58 | # For Python 3 compatibility (__cmp__ is deprecated).
59 | def __lt__(self, other):
60 | if not isinstance(other, Caption):
61 | raise ValueError("'%s' isn't an instance of Caption" % other)
62 | return self.score < other.score
63 |
64 | # Also for Python 3 compatibility.
65 | def __eq__(self, other):
66 | if not isinstance(other, Caption):
67 | raise ValueError("'%s' isn't an instance of Caption" % other)
68 | return self.score == other.score
69 |
70 |
71 | class TopN(object):
72 | """Maintains the top n elements of an incrementally provided set."""
73 |
74 | def __init__(self, n):
75 | self._n = n
76 | self._data = []
77 |
78 | def size(self):
79 | if self._data is None:
80 | raise ValueError("'self._data' is None")
81 | return len(self._data)
82 |
83 | def push(self, x):
84 | """Pushes a new element."""
85 | if self._data is None:
86 | raise ValueError("'self._data' is None")
87 | if len(self._data) < self._n:
88 | heapq.heappush(self._data, x)
89 | else:
90 | heapq.heappushpop(self._data, x)
91 |
92 | def extract(self, sort=False):
93 | """Extracts all elements from the TopN. This is a destructive operation.
94 |
95 | The only method that can be called immediately after extract() is reset().
96 |
97 | Args:
98 | sort: Whether to return the elements in descending sorted order.
99 |
100 | Returns:
101 | A list of data; the top n elements provided to the set.
102 | """
103 | if self._data is None:
104 | raise ValueError("'self._data' is None")
105 | data = self._data
106 | self._data = None
107 | if sort:
108 | data.sort(reverse=True)
109 | return data
110 |
111 | def reset(self):
112 | """Returns the TopN to an empty state."""
113 | self._data = []
114 |
115 |
116 | class CaptionGenerator(object):
117 | """Class to generate captions from an image-to-text model."""
118 |
119 | def __init__(self,
120 | model,
121 | vocab,
122 | beam_size=3,
123 | max_caption_length=20,
124 | length_normalization_factor=0.0):
125 | """Initializes the generator.
126 |
127 | Args:
128 | model: Object encapsulating a trained image-to-text model. Must have
129 | methods feed_image() and inference_step(). For example, an instance of
130 | InferenceWrapperBase.
131 | vocab: A Vocabulary object.
132 | beam_size: Beam size to use when generating captions.
133 | max_caption_length: The maximum caption length before stopping the search.
134 | length_normalization_factor: If != 0, a number x such that captions are
135 | scored by logprob/length^x, rather than logprob. This changes the
136 | relative scores of captions depending on their lengths. For example, if
137 | x > 0 then longer captions will be favored.
138 | """
139 | self.vocab = vocab
140 | self.model = model
141 |
142 | self.beam_size = beam_size
143 | self.max_caption_length = max_caption_length
144 | self.length_normalization_factor = length_normalization_factor
145 |
146 | def beam_search(self, sess, encoded_image):
147 | """Runs beam search caption generation on a single image.
148 |
149 | Args:
150 | sess: TensorFlow Session object.
151 | encoded_image: An encoded image string.
152 |
153 | Returns:
154 | A list of Caption sorted by descending score.
155 | """
156 | # Feed in the image to get the initial state.
157 | initial_state = self.model.feed_image(sess, encoded_image)
158 |
159 | initial_beam = Caption(
160 | sentence=[self.vocab.start_id],
161 | state=initial_state[0],
162 | logprob=0.0,
163 | score=0.0,
164 | metadata=[""])
165 | partial_captions = TopN(self.beam_size)
166 | partial_captions.push(initial_beam)
167 | complete_captions = TopN(self.beam_size)
168 |
169 | # Run beam search.
170 | for _ in range(self.max_caption_length - 1):
171 | partial_captions_list = partial_captions.extract()
172 | partial_captions.reset()
173 | input_feed = np.array([c.sentence[-1] for c in partial_captions_list])
174 | state_feed = np.array([c.state for c in partial_captions_list])
175 |
176 | softmax, new_states, metadata = self.model.inference_step(sess,
177 | input_feed,
178 | state_feed)
179 |
180 | for i, partial_caption in enumerate(partial_captions_list):
181 | word_probabilities = softmax[i]
182 | state = new_states[i]
183 | # For this partial caption, get the beam_size most probable next words.
184 | words_and_probs = list(enumerate(word_probabilities))
185 | words_and_probs.sort(key=lambda x: -x[1])
186 | words_and_probs = words_and_probs[0:self.beam_size]
187 | # Each next word gives a new partial caption.
188 | for w, p in words_and_probs:
189 | if p < 1e-12:
190 | continue # Avoid log(0).
191 | sentence = partial_caption.sentence + [w]
192 | logprob = partial_caption.logprob + math.log(p)
193 | score = logprob
194 | if metadata:
195 | metadata_list = partial_caption.metadata + [metadata[i]]
196 | else:
197 | metadata_list = None
198 | if w == self.vocab.end_id:
199 | if self.length_normalization_factor > 0:
200 | score /= len(sentence) ** self.length_normalization_factor
201 | beam = Caption(sentence, state, logprob, score, metadata_list)
202 | complete_captions.push(beam)
203 | else:
204 | beam = Caption(sentence, state, logprob, score, metadata_list)
205 | partial_captions.push(beam)
206 | if partial_captions.size() == 0:
207 | # We have run out of partial candidates; happens when beam_size = 1.
208 | break
209 |
210 | # If we have no complete captions then fall back to the partial captions.
211 | # But never output a mixture of complete and partial captions because a
212 | # partial caption could have a higher score than all the complete captions.
213 | if not complete_captions.size():
214 | complete_captions = partial_captions
215 |
216 | return complete_captions.extract(sort=True)
217 |
--------------------------------------------------------------------------------
/core/inference_utils/inference_wrapper_base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Base wrapper class for performing inference with an image-to-text model.
16 |
17 | Subclasses must implement the following methods:
18 |
19 | build_model():
20 | Builds the model for inference and returns the model object.
21 |
22 | feed_image():
23 | Takes an encoded image and returns the initial model state, where "state"
24 | is a numpy array whose specifics are defined by the subclass, e.g.
25 | concatenated LSTM state. It's assumed that feed_image() will be called
26 | precisely once at the start of inference for each image. Subclasses may
27 | compute and/or save per-image internal context in this method.
28 |
29 | inference_step():
30 | Takes a batch of inputs and states at a single time-step. Returns the
31 | softmax output corresponding to the inputs, and the new states of the batch.
32 | Optionally also returns metadata about the current inference step, e.g. a
33 | serialized numpy array containing activations from a particular model layer.
34 |
35 | Client usage:
36 | 1. Build the model inference graph via build_graph_from_config() or
37 | build_graph_from_proto().
38 | 2. Call the resulting restore_fn to load the model checkpoint.
39 | 3. For each image in a batch of images:
40 | a) Call feed_image() once to get the initial state.
41 | b) For each step of caption generation, call inference_step().
42 | """
43 |
44 | from __future__ import absolute_import
45 | from __future__ import division
46 | from __future__ import print_function
47 |
48 | import os.path
49 |
50 | import tensorflow as tf
51 |
52 |
53 | # pylint: disable=unused-argument
54 |
55 |
56 | class InferenceWrapperBase(object):
57 | """Base wrapper class for performing inference with an image-to-text model."""
58 |
59 | def __init__(self):
60 | pass
61 |
62 | def build_model(self, model_config):
63 | """Builds the model for inference.
64 |
65 | Args:
66 | model_config: Object containing configuration for building the model.
67 |
68 | Returns:
69 | model: The model object.
70 | """
71 | tf.compat.v1.logging.fatal("Please implement build_model in subclass")
72 |
73 | def _create_restore_fn(self, checkpoint_path, saver):
74 | """Creates a function that restores a model from checkpoint.
75 |
76 | Args:
77 | checkpoint_path: Checkpoint file or a directory containing a checkpoint
78 | file.
79 | saver: Saver for restoring variables from the checkpoint file.
80 |
81 | Returns:
82 | restore_fn: A function such that restore_fn(sess) loads model variables
83 | from the checkpoint file.
84 |
85 | Raises:
86 | ValueError: If checkpoint_path does not refer to a checkpoint file or a
87 | directory containing a checkpoint file.
88 | """
89 | if tf.compat.v1.gfile.IsDirectory(checkpoint_path):
90 | checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
91 | if not checkpoint_path:
92 | raise ValueError("No checkpoint file found in: %s" % checkpoint_path)
93 |
94 | def _restore_fn(sess):
95 | tf.compat.v1.logging.info("Loading model from checkpoint: %s", checkpoint_path)
96 | saver.restore(sess, checkpoint_path)
97 | tf.compat.v1.logging.info("Successfully loaded checkpoint: %s",
98 | os.path.basename(checkpoint_path))
99 |
100 | return _restore_fn
101 |
102 | def build_graph_from_config(self, model_config, checkpoint_path):
103 | """Builds the inference graph from a configuration object.
104 |
105 | Args:
106 | model_config: Object containing configuration for building the model.
107 | checkpoint_path: Checkpoint file or a directory containing a checkpoint
108 | file.
109 |
110 | Returns:
111 | restore_fn: A function such that restore_fn(sess) loads model variables
112 | from the checkpoint file.
113 | """
114 | tf.compat.v1.logging.info("Building model.")
115 | self.build_model(model_config)
116 | saver = tf.compat.v1.train.Saver()
117 |
118 | return self._create_restore_fn(checkpoint_path, saver)
119 |
120 | def build_graph_from_proto(self, graph_def_file, saver_def_file,
121 | checkpoint_path):
122 | """Builds the inference graph from serialized GraphDef and SaverDef protos.
123 |
124 | Args:
125 | graph_def_file: File containing a serialized GraphDef proto.
126 | saver_def_file: File containing a serialized SaverDef proto.
127 | checkpoint_path: Checkpoint file or a directory containing a checkpoint
128 | file.
129 |
130 | Returns:
131 | restore_fn: A function such that restore_fn(sess) loads model variables
132 | from the checkpoint file.
133 | """
134 | # Load the Graph.
135 | tf.compat.v1.logging.info("Loading GraphDef from file: %s", graph_def_file)
136 | graph_def = tf.GraphDef()
137 | with tf.compat.v1.gfile.FastGFile(graph_def_file, "rb") as f:
138 | graph_def.ParseFromString(f.read())
139 | tf.import_graph_def(graph_def, name="")
140 |
141 | # Load the Saver.
142 | tf.compat.v1.logging.info("Loading SaverDef from file: %s", saver_def_file)
143 | saver_def = tf.compat.v1.train.SaverDef()
144 | with tf.compat.v1.gfile.FastGFile(saver_def_file, "rb") as f:
145 | saver_def.ParseFromString(f.read())
146 | saver = tf.compat.v1.train.Saver(saver_def=saver_def)
147 |
148 | return self._create_restore_fn(checkpoint_path, saver)
149 |
150 | def feed_image(self, sess, encoded_image):
151 | """Feeds an image and returns the initial model state.
152 |
153 | See comments at the top of file.
154 |
155 | Args:
156 | sess: TensorFlow Session object.
157 | encoded_image: An encoded image string.
158 |
159 | Returns:
160 | state: A numpy array of shape [1, state_size].
161 | """
162 | tf.compat.v1.logging.fatal("Please implement feed_image in subclass")
163 |
164 | def inference_step(self, sess, input_feed, state_feed):
165 | """Runs one step of inference.
166 |
167 | Args:
168 | sess: TensorFlow Session object.
169 | input_feed: A numpy array of shape [batch_size].
170 | state_feed: A numpy array of shape [batch_size, state_size].
171 |
172 | Returns:
173 | softmax_output: A numpy array of shape [batch_size, vocab_size].
174 | new_state: A numpy array of shape [batch_size, state_size].
175 | metadata: Optional. If not None, a string containing metadata about the
176 | current inference step (e.g. serialized numpy array containing
177 |         activations from a particular model layer).
178 | """
179 | tf.compat.v1.logging.fatal("Please implement inference_step in subclass")
180 |
181 | # pylint: enable=unused-argument
182 |
--------------------------------------------------------------------------------
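A minimal usage sketch for the wrapper API defined above, mirroring how core/model.py drives it; the "assets/checkpoint" path and the direct Session construction are illustrative assumptions, not part of the base class itself.

    import tensorflow as tf
    from core import configuration, inference_wrapper

    tf.compat.v1.disable_eager_execution()

    g = tf.Graph()
    with g.as_default():
        model = inference_wrapper.InferenceWrapper()
        # build_graph_from_config() returns a restore_fn(sess) closure.
        restore_fn = model.build_graph_from_config(configuration.ModelConfig(),
                                                   "assets/checkpoint")  # assumed path
    g.finalize()

    sess = tf.compat.v1.Session(graph=g)
    restore_fn(sess)  # loads the checkpoint variables into the session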
/core/inference_utils/vocabulary.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Vocabulary class for an image-to-text model."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 |
24 | class Vocabulary(object):
25 | """Vocabulary class for an image-to-text model."""
26 |
27 | def __init__(self,
28 | vocab_file,
29 |                start_word="<S>",
30 |                end_word="</S>",
31 |                unk_word="<UNK>"):
32 | """Initializes the vocabulary.
33 |
34 | Args:
35 | vocab_file: File containing the vocabulary, where the words are the first
36 | whitespace-separated token on each line (other tokens are ignored) and
37 | the word ids are the corresponding line numbers.
38 | start_word: Special word denoting sentence start.
39 | end_word: Special word denoting sentence end.
40 | unk_word: Special word denoting unknown words.
41 | """
42 | if not tf.compat.v1.gfile.Exists(vocab_file):
43 | tf.compat.v1.logging.fatal("Vocab file %s not found.", vocab_file)
44 | tf.compat.v1.logging.info("Initializing vocabulary from file: %s", vocab_file)
45 |
46 | with tf.compat.v1.gfile.GFile(vocab_file, mode="r") as f:
47 | reverse_vocab = list(f.readlines())
48 | reverse_vocab = [line.split()[0] for line in reverse_vocab]
49 | if start_word not in reverse_vocab:
50 |             raise ValueError("Start word '%s' is not in 'reverse_vocab'" % start_word)
51 | if end_word not in reverse_vocab:
52 |             raise ValueError("End word '%s' is not in 'reverse_vocab'" % end_word)
53 | if unk_word not in reverse_vocab:
54 | reverse_vocab.append(unk_word)
55 | vocab = dict([(x, y) for (y, x) in enumerate(reverse_vocab)])
56 |
57 | tf.compat.v1.logging.info("Created vocabulary with %d words" % len(vocab))
58 |
59 | self.vocab = vocab # vocab[word] = id
60 | self.reverse_vocab = reverse_vocab # reverse_vocab[id] = word
61 |
62 | # Save special word ids.
63 | self.start_id = vocab[start_word]
64 | self.end_id = vocab[end_word]
65 | self.unk_id = vocab[unk_word]
66 |
67 | def word_to_id(self, word):
68 | """Returns the integer word id of a word string."""
69 | if word in self.vocab:
70 | return self.vocab[word]
71 | else:
72 | return self.unk_id
73 |
74 | def id_to_word(self, word_id):
75 | """Returns the word string of an integer word id."""
76 | if word_id >= len(self.reverse_vocab):
77 | return self.reverse_vocab[self.unk_id]
78 | else:
79 | return self.reverse_vocab[word_id]
80 |
--------------------------------------------------------------------------------
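A small, self-contained sketch of the vocabulary format described in the docstring above: one word per line (only the first whitespace-separated token matters; the word id is the line number), with the im2txt <S>/</S>/<UNK> special tokens assumed. The toy file and its contents are made up for illustration.

    from core.inference_utils.vocabulary import Vocabulary

    # Toy vocab file in "<word> <count>" form; the counts are ignored.
    with open("/tmp/toy_vocab.txt", "w") as f:
        f.write("<S> 1000\n</S> 1000\na 500\nsurfboard 3\n")

    vocab = Vocabulary("/tmp/toy_vocab.txt")          # "<UNK>" is appended automatically
    assert vocab.word_to_id("a") == 2                 # id == line number
    assert vocab.word_to_id("zebra") == vocab.unk_id  # unknown words fall back to unk_id
    assert vocab.id_to_word(3) == "surfboard"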
/core/inference_wrapper.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Model wrapper class for performing inference with a ShowAndTellModel."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from core import show_and_tell_model
23 | from core.inference_utils import inference_wrapper_base
24 |
25 |
26 | class InferenceWrapper(inference_wrapper_base.InferenceWrapperBase):
27 | """Model wrapper class for performing inference with a ShowAndTellModel."""
28 |
29 | def __init__(self):
30 | super(InferenceWrapper, self).__init__()
31 |
32 | def build_model(self, model_config):
33 | model = show_and_tell_model.ShowAndTellModel(model_config, mode="inference")
34 | model.build()
35 | return model
36 |
37 | def feed_image(self, sess, encoded_image):
38 | initial_state = sess.run(fetches="lstm/initial_state:0",
39 | feed_dict={"image_feed:0": encoded_image})
40 | return initial_state
41 |
42 | def inference_step(self, sess, input_feed, state_feed):
43 | softmax_output, state_output = sess.run(
44 | fetches=["softmax:0", "lstm/state:0"],
45 | feed_dict={
46 | "input_feed:0": input_feed,
47 | "lstm/state_feed:0": state_feed,
48 | })
49 | return softmax_output, state_output, None
50 |
--------------------------------------------------------------------------------
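The two methods above are all a decoder loop needs. The repo itself runs beam search through caption_generator.py; the sketch below is only a greedy illustration of the feed_image()/inference_step() contract, assuming sess, model and vocab have been built as in core/model.py.

    import numpy as np

    def greedy_caption(sess, model, vocab, encoded_image, max_len=20):
        state = model.feed_image(sess, encoded_image)   # shape [1, state_size]
        word_id = vocab.start_id
        words = []
        for _ in range(max_len):
            softmax, state, _ = model.inference_step(sess,
                                                     input_feed=np.array([word_id]),
                                                     state_feed=state)
            word_id = int(np.argmax(softmax[0]))        # greedy pick, no beam
            if word_id == vocab.end_id:
                break
            words.append(vocab.id_to_word(word_id))
        return " ".join(words)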
/core/model.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from maxfw.model import MAXModelWrapper
18 |
19 | import math
20 | import logging
21 |
22 | import tensorflow as tf
23 |
24 | from core import configuration
25 | from core import inference_wrapper
26 | from core.inference_utils import vocabulary
27 | from core.inference_utils import caption_generator
28 |
29 | from config import DEFAULT_MODEL_PATH, VOCAB_FILE
30 |
31 | logger = logging.getLogger()
32 |
33 | tf.compat.v1.disable_eager_execution()
34 |
35 |
36 | class ModelWrapper(MAXModelWrapper):
37 |
38 | def __init__(self, path=DEFAULT_MODEL_PATH):
39 | # TODO Replace this part with SavedModel
40 | g = tf.Graph()
41 | with g.as_default():
42 | model = inference_wrapper.InferenceWrapper()
43 | restore_fn = model.build_graph_from_config(configuration.ModelConfig(),
44 | path)
45 | g.finalize()
46 | self.model = model
47 | sess = tf.compat.v1.Session(graph=g)
48 | # Load the model from checkpoint.
49 | restore_fn(sess)
50 | self.sess = sess
51 |
52 | def _predict(self, image_data):
53 | # Create the vocabulary.
54 | vocab = vocabulary.Vocabulary(VOCAB_FILE)
55 |
56 | # Prepare the caption generator. Here we are implicitly using the default
57 | # beam search parameters. See caption_generator.py for a description of the
58 | # available beam search parameters.
59 | generator = caption_generator.CaptionGenerator(self.model, vocab)
60 |
61 | captions = generator.beam_search(self.sess, image_data)
62 |
63 | results = []
64 | for i, caption in enumerate(captions):
65 | # Ignore begin and end words.
66 | sentence = [vocab.id_to_word(w) for w in caption.sentence[1:-1]]
67 | sentence = " ".join(sentence)
68 | # print(" %d) %s (p=%f)" % (i, sentence, math.exp(caption.logprob)))
69 | results.append((i, sentence, math.exp(caption.logprob)))
70 |
71 | return results
72 |
--------------------------------------------------------------------------------
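_predict() returns a list of (index, sentence, probability) tuples, one per beam-search candidate. A short illustration of consuming them, assuming `wrapper` is a ModelWrapper instance and `image_data` holds the raw image bytes:

    results = wrapper._predict(image_data)
    for index, sentence, probability in results:
        print("%d) %s (p=%f)" % (index, sentence, probability))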
/core/ops/BUILD:
--------------------------------------------------------------------------------
1 | package(default_visibility = ["//im2txt:internal"])
2 |
3 | licenses(["notice"]) # Apache 2.0
4 |
5 | exports_files(["LICENSE"])
6 |
7 | py_library(
8 | name = "image_processing",
9 | srcs = ["image_processing.py"],
10 | srcs_version = "PY2AND3",
11 | )
12 |
13 | py_library(
14 | name = "image_embedding",
15 | srcs = ["image_embedding.py"],
16 | srcs_version = "PY2AND3",
17 | )
18 |
19 | py_test(
20 | name = "image_embedding_test",
21 | size = "small",
22 | srcs = ["image_embedding_test.py"],
23 | deps = [
24 | ":image_embedding",
25 | ],
26 | )
27 |
28 | py_library(
29 | name = "inputs",
30 | srcs = ["inputs.py"],
31 | srcs_version = "PY2AND3",
32 | )
33 |
--------------------------------------------------------------------------------
/core/ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/a2286854a6d8bc2d3de8d698f44fdf835e24d110/core/ops/__init__.py
--------------------------------------------------------------------------------
/core/ops/image_embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Image embedding ops."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow as tf
23 |
24 | import tf_slim as slim
25 | from tf_slim.nets.inception_v3 import inception_v3_base
26 |
27 |
28 | def inception_v3(images,
29 | trainable=True,
30 | is_training=True,
31 | weight_decay=0.00004,
32 | stddev=0.1,
33 | dropout_keep_prob=0.8,
34 | use_batch_norm=True,
35 | batch_norm_params=None,
36 | add_summaries=True,
37 | scope="InceptionV3"):
38 | """Builds an Inception V3 subgraph for image embeddings.
39 |
40 | Args:
41 | images: A float32 Tensor of shape [batch, height, width, channels].
42 | trainable: Whether the inception submodel should be trainable or not.
43 | is_training: Boolean indicating training mode or not.
44 | weight_decay: Coefficient for weight regularization.
45 |     stddev: The standard deviation of the truncated normal weight initializer.
46 | dropout_keep_prob: Dropout keep probability.
47 | use_batch_norm: Whether to use batch normalization.
48 | batch_norm_params: Parameters for batch normalization. See
49 |       tf_slim.batch_norm for details.
50 | add_summaries: Whether to add activation summaries.
51 | scope: Optional Variable scope.
52 |
53 | Returns:
54 | end_points: A dictionary of activations from inception_v3 layers.
55 | """
56 | # Only consider the inception model to be in training mode if it's trainable.
57 | is_inception_model_training = trainable and is_training
58 |
59 | if use_batch_norm:
60 | # Default parameters for batch normalization.
61 | if not batch_norm_params:
62 | batch_norm_params = {
63 | "is_training": is_inception_model_training,
64 | "trainable": trainable,
65 | # Decay for the moving averages.
66 | "decay": 0.9997,
67 | # Epsilon to prevent 0s in variance.
68 | "epsilon": 0.001,
69 | # Collection containing the moving mean and moving variance.
70 | "variables_collections": {
71 | "beta": None,
72 | "gamma": None,
73 | "moving_mean": ["moving_vars"],
74 | "moving_variance": ["moving_vars"],
75 | }
76 | }
77 | else:
78 | batch_norm_params = None
79 |
80 | if trainable:
81 |         weights_regularizer = slim.l2_regularizer(weight_decay)
82 | else:
83 | weights_regularizer = None
84 |
85 | with tf.compat.v1.variable_scope(scope, "InceptionV3", [images]) as scope:
86 | with slim.arg_scope(
87 | [slim.conv2d, slim.fully_connected],
88 | weights_regularizer=weights_regularizer,
89 | trainable=trainable):
90 | with slim.arg_scope(
91 | [slim.conv2d],
92 | weights_initializer=tf.compat.v1.truncated_normal_initializer(stddev=stddev),
93 | activation_fn=tf.nn.relu,
94 | normalizer_fn=slim.batch_norm,
95 | normalizer_params=batch_norm_params):
96 | net, end_points = inception_v3_base(images, scope=scope)
97 | with tf.compat.v1.variable_scope("logits"):
98 | shape = net.get_shape()
99 | net = slim.avg_pool2d(net, shape[1:3], padding="VALID", scope="pool")
100 | net = slim.dropout(
101 | net,
102 | keep_prob=dropout_keep_prob,
103 | is_training=is_inception_model_training,
104 | scope="dropout")
105 | net = slim.flatten(net, scope="flatten")
106 |
107 | # Add summaries.
108 | if add_summaries:
109 | for v in end_points.values():
110 | slim.summarize_activation(v)
111 |
112 | return net
113 |
--------------------------------------------------------------------------------
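A minimal sketch of wiring inception_v3() into a graph, mirroring ShowAndTellModel.build_image_embeddings(); the 299x299 input size matches the Show and Tell configuration but is an assumption here.

    import tensorflow as tf
    from core.ops import image_embedding

    tf.compat.v1.disable_eager_execution()

    images = tf.compat.v1.placeholder(tf.float32, shape=[None, 299, 299, 3])
    net = image_embedding.inception_v3(images, trainable=False, is_training=False)
    # `net` is the flattened, pooled InceptionV3 output, ready to be projected
    # into the caption embedding space by a fully connected layer.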
/core/ops/image_processing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Helper functions for image preprocessing."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow as tf
23 |
24 |
25 | def distort_image(image, thread_id):
26 | """Perform random distortions on an image.
27 |
28 | Args:
29 | image: A float32 Tensor of shape [height, width, 3] with values in [0, 1).
30 | thread_id: Preprocessing thread id used to select the ordering of color
31 | distortions. There should be a multiple of 2 preprocessing threads.
32 |
33 | Returns:
34 | distorted_image: A float32 Tensor of shape [height, width, 3] with values in
35 | [0, 1].
36 | """
37 | # Randomly flip horizontally.
38 | with tf.name_scope("flip_horizontal"):
39 | image = tf.image.random_flip_left_right(image)
40 |
41 | # Randomly distort the colors based on thread id.
42 | color_ordering = thread_id % 2
43 | with tf.name_scope("distort_color"):
44 | if color_ordering == 0:
45 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
46 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
47 | image = tf.image.random_hue(image, max_delta=0.032)
48 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
49 | elif color_ordering == 1:
50 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
51 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
52 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
53 | image = tf.image.random_hue(image, max_delta=0.032)
54 |
55 | # The random_* ops do not necessarily clamp.
56 | image = tf.clip_by_value(image, 0.0, 1.0)
57 |
58 | return image
59 |
60 |
61 | def process_image(encoded_image,
62 | is_training,
63 | height,
64 | width,
65 | resize_height=346,
66 | resize_width=346,
67 | thread_id=0,
68 | image_format="jpeg"):
69 | """Decode an image, resize and apply random distortions.
70 |
71 | In training, images are distorted slightly differently depending on thread_id.
72 |
73 | Args:
74 | encoded_image: String Tensor containing the image.
75 | is_training: Boolean; whether preprocessing for training or eval.
76 | height: Height of the output image.
77 | width: Width of the output image.
78 | resize_height: If > 0, resize height before crop to final dimensions.
79 | resize_width: If > 0, resize width before crop to final dimensions.
80 | thread_id: Preprocessing thread id used to select the ordering of color
81 | distortions. There should be a multiple of 2 preprocessing threads.
82 | image_format: "jpeg" or "png".
83 |
84 | Returns:
85 | A float32 Tensor of shape [height, width, 3] with values in [-1, 1].
86 |
87 | Raises:
88 | ValueError: If image_format is invalid.
89 | """
90 |
91 | # Helper function to log an image summary to the visualizer. Summaries are
92 | # only logged in thread 0.
93 | def image_summary(name, image):
94 | if not thread_id:
95 | tf.summary.image(name, tf.expand_dims(image, 0))
96 |
97 | # Decode image into a float32 Tensor of shape [?, ?, 3] with values in [0, 1).
98 | with tf.name_scope("decode"):
99 | if image_format == "jpeg":
100 | image = tf.image.decode_jpeg(encoded_image, channels=3)
101 | elif image_format == "png":
102 | image = tf.image.decode_png(encoded_image, channels=3)
103 | else:
104 | raise ValueError("Invalid image format: %s" % image_format)
105 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
106 | image_summary("original_image", image)
107 |
108 | # Resize image.
109 | if (resize_height > 0) != (resize_width > 0):
110 | raise ValueError("Invalid resize parameters height: '{0}' width: '{1}'".format(resize_height, resize_width))
111 |
112 | if resize_height:
113 | image = tf.image.resize(image,
114 | size=[resize_height, resize_width],
115 | method=tf.image.ResizeMethod.BILINEAR)
116 |
117 | # Crop to final dimensions.
118 | if is_training:
119 |         image = tf.image.random_crop(image, [height, width, 3])
120 | else:
121 | # Central crop, assuming resize_height > height, resize_width > width.
122 | image = tf.image.resize_with_crop_or_pad(image, height, width)
123 |
124 | image_summary("resized_image", image)
125 |
126 | # Randomly distort the image.
127 | if is_training:
128 | image = distort_image(image, thread_id)
129 |
130 | image_summary("final_image", image)
131 |
132 | # Rescale to [-1,1] instead of [0, 1]
133 | image = tf.subtract(image, 0.5)
134 | image = tf.multiply(image, 2.0)
135 | return image
136 |
--------------------------------------------------------------------------------
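A hedged sketch of running process_image() on a single encoded JPEG in graph mode; the 299x299 target size and the sample path are illustrative assumptions.

    import tensorflow as tf
    from core.ops import image_processing

    tf.compat.v1.disable_eager_execution()

    image_feed = tf.compat.v1.placeholder(tf.string, shape=[])
    processed = image_processing.process_image(image_feed, is_training=False,
                                               height=299, width=299)

    with tf.compat.v1.Session() as sess:
        with open("samples/surfing.jpg", "rb") as f:
            pixels = sess.run(processed, feed_dict={image_feed: f.read()})
    print(pixels.shape)  # (299, 299, 3), values in [-1, 1]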
/core/ops/inputs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Input ops."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow as tf
23 |
24 |
25 | def parse_sequence_example(serialized, image_feature, caption_feature):
26 | """Parses a tensorflow.SequenceExample into an image and caption.
27 |
28 | Args:
29 | serialized: A scalar string Tensor; a single serialized SequenceExample.
30 | image_feature: Name of SequenceExample context feature containing image
31 | data.
32 | caption_feature: Name of SequenceExample feature list containing integer
33 | captions.
34 |
35 | Returns:
36 | encoded_image: A scalar string Tensor containing a JPEG encoded image.
37 | caption: A 1-D uint64 Tensor with dynamically specified length.
38 | """
39 |     context, sequence = tf.io.parse_single_sequence_example(
40 |         serialized,
41 |         context_features={
42 |             image_feature: tf.io.FixedLenFeature([], dtype=tf.string)
43 |         },
44 |         sequence_features={
45 |             caption_feature: tf.io.FixedLenSequenceFeature([], dtype=tf.int64),
46 |         })
47 |
48 | encoded_image = context[image_feature]
49 | caption = sequence[caption_feature]
50 | return encoded_image, caption
51 |
52 |
53 | def prefetch_input_data(reader,
54 | file_pattern,
55 | is_training,
56 | batch_size,
57 | values_per_shard,
58 | input_queue_capacity_factor=16,
59 | num_reader_threads=1,
60 | shard_queue_name="filename_queue",
61 | value_queue_name="input_queue"):
62 | """Prefetches string values from disk into an input queue.
63 |
64 | In training the capacity of the queue is important because a larger queue
65 | means better mixing of training examples between shards. The minimum number of
66 | values kept in the queue is values_per_shard * input_queue_capacity_factor,
67 | where input_queue_capacity_factor should be chosen to trade off better mixing
68 | with memory usage.
69 |
70 | Args:
71 | reader: Instance of tf.ReaderBase.
72 | file_pattern: Comma-separated list of file patterns (e.g.
73 | /tmp/train_data-?????-of-00100).
74 | is_training: Boolean; whether prefetching for training or eval.
75 | batch_size: Model batch size used to determine queue capacity.
76 | values_per_shard: Approximate number of values per shard.
77 | input_queue_capacity_factor: Minimum number of values to keep in the queue
78 | in multiples of values_per_shard. See comments above.
79 | num_reader_threads: Number of reader threads to fill the queue.
80 | shard_queue_name: Name for the shards filename queue.
81 | value_queue_name: Name for the values input queue.
82 |
83 | Returns:
84 | A Queue containing prefetched string values.
85 | """
86 | data_files = []
87 | for pattern in file_pattern.split(","):
88 | data_files.extend(tf.compat.v1.gfile.Glob(pattern))
89 | if not data_files:
90 |         tf.compat.v1.logging.fatal("Found no input files matching %s", file_pattern)
91 |     else:
92 |         tf.compat.v1.logging.info("Prefetching values from %d files matching %s",
93 |                                   len(data_files), file_pattern)
94 |
95 | if is_training:
96 | filename_queue = tf.train.string_input_producer(
97 | data_files, shuffle=True, capacity=16, name=shard_queue_name)
98 | min_queue_examples = values_per_shard * input_queue_capacity_factor
99 | capacity = min_queue_examples + 100 * batch_size
100 | values_queue = tf.RandomShuffleQueue(
101 | capacity=capacity,
102 | min_after_dequeue=min_queue_examples,
103 | dtypes=[tf.string],
104 | name="random_" + value_queue_name)
105 | else:
106 | filename_queue = tf.train.string_input_producer(
107 | data_files, shuffle=False, capacity=1, name=shard_queue_name)
108 | capacity = values_per_shard + 3 * batch_size
109 | values_queue = tf.FIFOQueue(
110 | capacity=capacity, dtypes=[tf.string], name="fifo_" + value_queue_name)
111 |
112 | enqueue_ops = []
113 | for _ in range(num_reader_threads):
114 | _, value = reader.read(filename_queue)
115 | enqueue_ops.append(values_queue.enqueue([value]))
116 | tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner(
117 | values_queue, enqueue_ops))
118 | tf.summary.scalar(
119 | "queue/%s/fraction_of_%d_full" % (values_queue.name, capacity),
120 | tf.cast(values_queue.size(), tf.float32) * (1. / capacity))
121 |
122 | return values_queue
123 |
124 |
125 | def batch_with_dynamic_pad(images_and_captions,
126 | batch_size,
127 | queue_capacity,
128 | add_summaries=True):
129 | """Batches input images and captions.
130 |
131 | This function splits the caption into an input sequence and a target sequence,
132 | where the target sequence is the input sequence right-shifted by 1. Input and
133 | target sequences are batched and padded up to the maximum length of sequences
134 | in the batch. A mask is created to distinguish real words from padding words.
135 |
136 | Example:
137 | Actual captions in the batch ('-' denotes padded character):
138 | [
139 | [ 1 2 3 4 5 ],
140 | [ 1 2 3 4 - ],
141 | [ 1 2 3 - - ],
142 | ]
143 |
144 | input_seqs:
145 | [
146 | [ 1 2 3 4 ],
147 | [ 1 2 3 - ],
148 | [ 1 2 - - ],
149 | ]
150 |
151 | target_seqs:
152 | [
153 | [ 2 3 4 5 ],
154 | [ 2 3 4 - ],
155 | [ 2 3 - - ],
156 | ]
157 |
158 | mask:
159 | [
160 | [ 1 1 1 1 ],
161 | [ 1 1 1 0 ],
162 | [ 1 1 0 0 ],
163 | ]
164 |
165 | Args:
166 | images_and_captions: A list of pairs [image, caption], where image is a
167 | Tensor of shape [height, width, channels] and caption is a 1-D Tensor of
168 | any length. Each pair will be processed and added to the queue in a
169 | separate thread.
170 | batch_size: Batch size.
171 | queue_capacity: Queue capacity.
172 | add_summaries: If true, add caption length summaries.
173 |
174 | Returns:
175 | images: A Tensor of shape [batch_size, height, width, channels].
176 | input_seqs: An int32 Tensor of shape [batch_size, padded_length].
177 | target_seqs: An int32 Tensor of shape [batch_size, padded_length].
178 | mask: An int32 0/1 Tensor of shape [batch_size, padded_length].
179 | """
180 | enqueue_list = []
181 | for image, caption in images_and_captions:
182 | caption_length = tf.shape(caption)[0]
183 | input_length = tf.expand_dims(tf.subtract(caption_length, 1), 0)
184 |
185 | input_seq = tf.slice(caption, [0], input_length)
186 | target_seq = tf.slice(caption, [1], input_length)
187 | indicator = tf.ones(input_length, dtype=tf.int32)
188 | enqueue_list.append([image, input_seq, target_seq, indicator])
189 |
190 | images, input_seqs, target_seqs, mask = tf.train.batch_join(
191 | enqueue_list,
192 | batch_size=batch_size,
193 | capacity=queue_capacity,
194 | dynamic_pad=True,
195 | name="batch_and_pad")
196 |
197 | if add_summaries:
198 | lengths = tf.add(tf.reduce_sum(mask, 1), 1)
199 | tf.summary.scalar("caption_length/batch_min", tf.reduce_min(lengths))
200 | tf.summary.scalar("caption_length/batch_max", tf.reduce_max(lengths))
201 | tf.summary.scalar("caption_length/batch_mean", tf.reduce_mean(lengths))
202 |
203 | return images, input_seqs, target_seqs, mask
204 |
--------------------------------------------------------------------------------
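The input/target split documented in batch_with_dynamic_pad() can be restated in plain Python for a single caption (illustration only, not repo code):

    caption = [1, 2, 3, 4, 5]
    input_seq = caption[:-1]      # [1, 2, 3, 4]
    target_seq = caption[1:]      # [2, 3, 4, 5]  (input shifted right by one)
    mask = [1] * len(input_seq)   # [1, 1, 1, 1]  (0s would mark padding)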
/core/show_and_tell_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Image-to-text implementation based on http://arxiv.org/abs/1411.4555.
17 |
18 | "Show and Tell: A Neural Image Caption Generator"
19 | Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan
20 | """
21 |
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | import tensorflow as tf
27 | import tf_slim as slim
28 |
29 | from core.ops import image_embedding
30 | from core.ops import image_processing
31 | from core.ops import inputs as input_ops
32 |
33 |
34 | class ShowAndTellModel(object):
35 | """Image-to-text implementation based on http://arxiv.org/abs/1411.4555.
36 |
37 | "Show and Tell: A Neural Image Caption Generator"
38 | Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan
39 | """
40 |
41 | def __init__(self, config, mode, train_inception=False):
42 | """Basic setup.
43 |
44 | Args:
45 | config: Object containing configuration parameters.
46 | mode: "train", "eval" or "inference".
47 | train_inception: Whether the inception submodel variables are trainable.
48 | """
49 | if mode not in ["train", "eval", "inference"]:
50 | raise ValueError("Variable 'mode' is invalid: %s" % mode)
51 | self.config = config
52 | self.mode = mode
53 | self.train_inception = train_inception
54 |
55 | # Reader for the input data.
56 | self.reader = tf.compat.v1.TFRecordReader()
57 |
58 | # To match the "Show and Tell" paper we initialize all variables with a
59 | # random uniform initializer.
60 | self.initializer = tf.random_uniform_initializer(
61 | minval=-self.config.initializer_scale,
62 | maxval=self.config.initializer_scale)
63 |
64 | # A float32 Tensor with shape [batch_size, height, width, channels].
65 | self.images = None
66 |
67 | # An int32 Tensor with shape [batch_size, padded_length].
68 | self.input_seqs = None
69 |
70 | # An int32 Tensor with shape [batch_size, padded_length].
71 | self.target_seqs = None
72 |
73 | # An int32 0/1 Tensor with shape [batch_size, padded_length].
74 | self.input_mask = None
75 |
76 | # A float32 Tensor with shape [batch_size, embedding_size].
77 | self.image_embeddings = None
78 |
79 | # A float32 Tensor with shape [batch_size, padded_length, embedding_size].
80 | self.seq_embeddings = None
81 |
82 | # A float32 scalar Tensor; the total loss for the trainer to optimize.
83 | self.total_loss = None
84 |
85 | # A float32 Tensor with shape [batch_size * padded_length].
86 | self.target_cross_entropy_losses = None
87 |
88 | # A float32 Tensor with shape [batch_size * padded_length].
89 | self.target_cross_entropy_loss_weights = None
90 |
91 | # Collection of variables from the inception submodel.
92 | self.inception_variables = []
93 |
94 | # Function to restore the inception submodel from checkpoint.
95 | self.init_fn = None
96 |
97 | # Global step Tensor.
98 | self.global_step = None
99 |
100 | def is_training(self):
101 | """Returns true if the model is built for training mode."""
102 | return self.mode == "train"
103 |
104 | def process_image(self, encoded_image, thread_id=0):
105 | """Decodes and processes an image string.
106 |
107 | Args:
108 | encoded_image: A scalar string Tensor; the encoded image.
109 | thread_id: Preprocessing thread id used to select the ordering of color
110 | distortions.
111 |
112 | Returns:
113 | A float32 Tensor of shape [height, width, 3]; the processed image.
114 | """
115 | return image_processing.process_image(encoded_image,
116 | is_training=self.is_training(),
117 | height=self.config.image_height,
118 | width=self.config.image_width,
119 | thread_id=thread_id,
120 | image_format=self.config.image_format)
121 |
122 | def build_inputs(self):
123 | """Input prefetching, preprocessing and batching.
124 |
125 | Outputs:
126 | self.images
127 | self.input_seqs
128 | self.target_seqs (training and eval only)
129 | self.input_mask (training and eval only)
130 | """
131 | if self.mode == "inference":
132 | # In inference mode, images and inputs are fed via placeholders.
133 | image_feed = tf.compat.v1.placeholder(dtype=tf.string, shape=[], name="image_feed")
134 | input_feed = tf.compat.v1.placeholder(dtype=tf.int64,
135 | shape=[None], # batch_size
136 | name="input_feed")
137 |
138 | # Process image and insert batch dimensions.
139 | images = tf.expand_dims(self.process_image(image_feed), 0)
140 | input_seqs = tf.expand_dims(input_feed, 1)
141 |
142 | # No target sequences or input mask in inference mode.
143 | target_seqs = None
144 | input_mask = None
145 | else:
146 | # Prefetch serialized SequenceExample protos.
147 | input_queue = input_ops.prefetch_input_data(
148 | self.reader,
149 | self.config.input_file_pattern,
150 | is_training=self.is_training(),
151 | batch_size=self.config.batch_size,
152 | values_per_shard=self.config.values_per_input_shard,
153 | input_queue_capacity_factor=self.config.input_queue_capacity_factor,
154 | num_reader_threads=self.config.num_input_reader_threads)
155 |
156 | # Image processing and random distortion. Split across multiple threads
157 | # with each thread applying a slightly different distortion.
158 | if self.config.num_preprocess_threads % 2 != 0:
159 | raise ValueError("'self.config.num_preprocess_threads' is invalid: %s" % self.config.num_preprocess_threads)
160 | images_and_captions = []
161 | for thread_id in range(self.config.num_preprocess_threads):
162 | serialized_sequence_example = input_queue.dequeue()
163 | encoded_image, caption = input_ops.parse_sequence_example(
164 | serialized_sequence_example,
165 | image_feature=self.config.image_feature_name,
166 | caption_feature=self.config.caption_feature_name)
167 | image = self.process_image(encoded_image, thread_id=thread_id)
168 | images_and_captions.append([image, caption])
169 |
170 | # Batch inputs.
171 | queue_capacity = (2 * self.config.num_preprocess_threads *
172 | self.config.batch_size)
173 | images, input_seqs, target_seqs, input_mask = (
174 | input_ops.batch_with_dynamic_pad(images_and_captions,
175 | batch_size=self.config.batch_size,
176 | queue_capacity=queue_capacity))
177 |
178 | self.images = images
179 | self.input_seqs = input_seqs
180 | self.target_seqs = target_seqs
181 | self.input_mask = input_mask
182 |
183 | def build_image_embeddings(self):
184 | """Builds the image model subgraph and generates image embeddings.
185 |
186 | Inputs:
187 | self.images
188 |
189 | Outputs:
190 | self.image_embeddings
191 | """
192 | inception_output = image_embedding.inception_v3(
193 | self.images,
194 | trainable=self.train_inception,
195 | is_training=self.is_training())
196 | self.inception_variables = tf.compat.v1.get_collection(
197 | tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope="InceptionV3")
198 |
199 | # Map inception output into embedding space.
200 | with tf.compat.v1.variable_scope("image_embedding") as scope:
201 | image_embeddings = slim.layers.fully_connected(
202 | inputs=inception_output,
203 | num_outputs=self.config.embedding_size,
204 | activation_fn=None,
205 | weights_initializer=self.initializer,
206 | biases_initializer=None,
207 | scope=scope)
208 |
209 | # Save the embedding size in the graph.
210 | tf.constant(self.config.embedding_size, name="embedding_size")
211 |
212 | self.image_embeddings = image_embeddings
213 |
214 | def build_seq_embeddings(self):
215 | """Builds the input sequence embeddings.
216 |
217 | Inputs:
218 | self.input_seqs
219 |
220 | Outputs:
221 | self.seq_embeddings
222 | """
223 | with tf.compat.v1.variable_scope("seq_embedding"), tf.device("/cpu:0"):
224 | embedding_map = tf.compat.v1.get_variable(
225 | name="map",
226 | shape=[self.config.vocab_size, self.config.embedding_size],
227 | initializer=self.initializer)
228 | seq_embeddings = tf.nn.embedding_lookup(embedding_map, self.input_seqs)
229 |
230 | self.seq_embeddings = seq_embeddings
231 |
232 | def build_model(self):
233 | """Builds the model.
234 |
235 | Inputs:
236 | self.image_embeddings
237 | self.seq_embeddings
238 | self.target_seqs (training and eval only)
239 | self.input_mask (training and eval only)
240 |
241 | Outputs:
242 | self.total_loss (training and eval only)
243 | self.target_cross_entropy_losses (training and eval only)
244 | self.target_cross_entropy_loss_weights (training and eval only)
245 | """
246 | # This LSTM cell has biases and outputs tanh(new_c) * sigmoid(o), but the
247 | # modified LSTM in the "Show and Tell" paper has no biases and outputs
248 | # new_c * sigmoid(o).
249 | lstm_cell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(
250 | num_units=self.config.num_lstm_units, state_is_tuple=True)
251 | if self.mode == "train":
252 |             lstm_cell = tf.compat.v1.nn.rnn_cell.DropoutWrapper(
253 | lstm_cell,
254 | input_keep_prob=self.config.lstm_dropout_keep_prob,
255 | output_keep_prob=self.config.lstm_dropout_keep_prob)
256 |
257 | with tf.compat.v1.variable_scope("lstm", initializer=self.initializer) as lstm_scope:
258 | # Feed the image embeddings to set the initial LSTM state.
259 | zero_state = lstm_cell.zero_state(
260 | batch_size=self.image_embeddings.get_shape()[0], dtype=tf.float32)
261 | _, initial_state = lstm_cell(self.image_embeddings, zero_state)
262 |
263 | # Allow the LSTM variables to be reused.
264 | lstm_scope.reuse_variables()
265 |
266 | if self.mode == "inference":
267 | # In inference mode, use concatenated states for convenient feeding and
268 | # fetching.
269 | tf.concat(axis=1, values=initial_state, name="initial_state")
270 |
271 | # Placeholder for feeding a batch of concatenated states.
272 | state_feed = tf.compat.v1.placeholder(dtype=tf.float32,
273 | shape=[None, sum(lstm_cell.state_size)],
274 | name="state_feed")
275 | state_tuple = tf.split(value=state_feed, num_or_size_splits=2, axis=1)
276 |
277 | # Run a single LSTM step.
278 | lstm_outputs, state_tuple = lstm_cell(
279 | inputs=tf.squeeze(self.seq_embeddings, axis=[1]),
280 | state=state_tuple)
281 |
282 |                 # Concatenate the resulting state.
283 | tf.concat(axis=1, values=state_tuple, name="state")
284 | else:
285 | # Run the batch of sequence embeddings through the LSTM.
286 | sequence_length = tf.reduce_sum(self.input_mask, 1)
287 |             lstm_outputs, _ = tf.compat.v1.nn.dynamic_rnn(cell=lstm_cell,
288 |                                                           inputs=self.seq_embeddings,
289 |                                                           sequence_length=sequence_length,
290 |                                                           initial_state=initial_state,
291 |                                                           dtype=tf.float32,
292 |                                                           scope=lstm_scope)
293 |
294 | # Stack batches vertically.
295 | lstm_outputs = tf.reshape(lstm_outputs, [-1, lstm_cell.output_size])
296 |
297 | with tf.compat.v1.variable_scope("logits") as logits_scope:
298 | logits = slim.layers.fully_connected(
299 | inputs=lstm_outputs,
300 | num_outputs=self.config.vocab_size,
301 | activation_fn=None,
302 | weights_initializer=self.initializer,
303 | scope=logits_scope)
304 |
305 | if self.mode == "inference":
306 | tf.nn.softmax(logits, name="softmax")
307 | else:
308 | targets = tf.reshape(self.target_seqs, [-1])
309 |                 weights = tf.cast(tf.reshape(self.input_mask, [-1]), tf.float32)
310 |
311 | # Compute losses.
312 | losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets,
313 | logits=logits)
314 |                 batch_loss = tf.math.divide(tf.reduce_sum(tf.multiply(losses, weights)),
315 |                                             tf.reduce_sum(weights),
316 |                                             name="batch_loss")
317 |                 tf.compat.v1.losses.add_loss(batch_loss)
318 |                 total_loss = tf.compat.v1.losses.get_total_loss()
319 |
320 | # Add summaries.
321 | tf.summary.scalar("losses/batch_loss", batch_loss)
322 | tf.summary.scalar("losses/total_loss", total_loss)
323 |                 for var in tf.compat.v1.trainable_variables():
324 | tf.summary.histogram("parameters/" + var.op.name, var)
325 |
326 | self.total_loss = total_loss
327 | self.target_cross_entropy_losses = losses # Used in evaluation.
328 | self.target_cross_entropy_loss_weights = weights # Used in evaluation.
329 |
330 | def setup_inception_initializer(self):
331 | """Sets up the function to restore inception variables from checkpoint."""
332 | if self.mode != "inference":
333 | # Restore inception variables only.
334 |             saver = tf.compat.v1.train.Saver(self.inception_variables)
335 |
336 | def restore_fn(sess):
337 |                 tf.compat.v1.logging.info("Restoring Inception variables from checkpoint file %s",
338 |                                           self.config.inception_checkpoint_file)
339 | saver.restore(sess, self.config.inception_checkpoint_file)
340 |
341 | self.init_fn = restore_fn
342 |
343 | def setup_global_step(self):
344 | """Sets up the global step Tensor."""
345 | global_step = tf.compat.v1.Variable(
346 | initial_value=0,
347 | name="global_step",
348 | trainable=False,
349 | collections=[tf.compat.v1.GraphKeys.GLOBAL_STEP, tf.compat.v1.GraphKeys.GLOBAL_VARIABLES])
350 |
351 | self.global_step = global_step
352 |
353 | def build(self):
354 | """Creates all ops for training and evaluation."""
355 | self.build_inputs()
356 | self.build_image_embeddings()
357 | self.build_seq_embeddings()
358 | self.build_model()
359 | self.setup_inception_initializer()
360 | self.setup_global_step()
361 |
--------------------------------------------------------------------------------
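The masked batch loss computed in build_model() (training/eval branch) is a weighted average of per-token cross-entropy values; a NumPy restatement with made-up numbers:

    import numpy as np

    losses = np.array([0.7, 0.2, 1.1, 0.5])   # per-token cross entropy, flattened
    weights = np.array([1., 1., 1., 0.])      # flattened input_mask; 0 marks padding
    batch_loss = np.sum(losses * weights) / np.sum(weights)
    print(batch_loss)  # 0.666..., the padded token is excluded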
/docs/deploy-max-to-ibm-cloud-with-kubernetes-button.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/a2286854a6d8bc2d3de8d698f44fdf835e24d110/docs/deploy-max-to-ibm-cloud-with-kubernetes-button.png
--------------------------------------------------------------------------------
/docs/swagger-screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/a2286854a6d8bc2d3de8d698f44fdf835e24d110/docs/swagger-screenshot.png
--------------------------------------------------------------------------------
/max-image-caption-generator.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | apiVersion: v1
18 | kind: Service
19 | metadata:
20 | name: max-image-caption-generator
21 | spec:
22 | selector:
23 | app: max-image-caption-generator
24 | ports:
25 | - port: 5000
26 | type: NodePort
27 | ---
28 | apiVersion: apps/v1
29 | kind: Deployment
30 | metadata:
31 | name: max-image-caption-generator
32 | labels:
33 | app: max-image-caption-generator
34 | spec:
35 | selector:
36 | matchLabels:
37 | app: max-image-caption-generator
38 | replicas: 1
39 | template:
40 | metadata:
41 | labels:
42 | app: max-image-caption-generator
43 | spec:
44 | containers:
45 | - name: max-image-caption-generator
46 | image: quay.io/codait/max-image-caption-generator:latest
47 | ports:
48 | - containerPort: 5000
49 |
--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
1 | pytest==6.1.2
2 | requests==2.25.0
3 | flake8==3.8.4
4 | bandit==1.6.2
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==2.6.0
2 | tf_slim==1.1.0
3 | Pillow==8.3.1
4 | # NumPy version is dictated by tensorflow
5 |
--------------------------------------------------------------------------------
/samples/README.md:
--------------------------------------------------------------------------------
1 | # Sample Details
2 |
3 | ## Images
4 |
5 | All test images are from [Pexels](https://www.pexels.com) and licensed under the [CC0 License](https://creativecommons.org/publicdomain/zero/1.0/).
6 |
7 | * [`plane.jpg`](https://www.pexels.com/photo/flight-sky-clouds-aircraft-8394/)
8 | * [`soccer.jpg`](https://www.pexels.com/photo/action-athletes-ball-blur-274422/)
9 | * [`surfing.jpg`](https://www.pexels.com/photo/action-beach-fun-leisure-416676/)
--------------------------------------------------------------------------------
/samples/plane.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/a2286854a6d8bc2d3de8d698f44fdf835e24d110/samples/plane.jpg
--------------------------------------------------------------------------------
/samples/soccer.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/a2286854a6d8bc2d3de8d698f44fdf835e24d110/samples/soccer.jpg
--------------------------------------------------------------------------------
/samples/surfing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/a2286854a6d8bc2d3de8d698f44fdf835e24d110/samples/surfing.jpg
--------------------------------------------------------------------------------
/sha512sums.txt:
--------------------------------------------------------------------------------
1 | dd6bcd865046b826bfb9559253c8114aa5db0b07ec6c329d25df712a7a2aaab019d5e82f3b7807dd162ad034ce8e6a0bdd3c54ab38868172cdd1a881faa030d1 assets/checkpoint/checkpoint
2 | eda91f4245224b43dd6c36f16dd49c30bb65d0aea51ca84a13a710bd1b18454e590cd2ff7822afe9e8501a7fa1c0ff9ec98b3dca55ed8415436d95e0f1a48c52 assets/checkpoint/model2.ckpt-2000000.data-00000-of-00001
3 | 39293d83b5d1988537e44810027bb32f7d5f09148da086f4dde1ce1f18f8a53cc8a67cb415e3ae151d618f15975af796dec2cfbf769bb36168ced81494f90dee assets/checkpoint/model2.ckpt-2000000.index
4 | acc365f7f48a67a0eee1815dd64bbdad960a7d371cee10e99b295510cae95932567e6af4faafcfb01092ff63c205d9ccf5702fbbc2d2c9c7f59445f5f098e60d assets/checkpoint/model2.ckpt-2000000.meta
5 |
--------------------------------------------------------------------------------
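These digests can be checked with `sha512sum -c sha512sums.txt` once the model assets have been downloaded; a rough Python equivalent (assuming the checkpoint files exist locally) is sketched below.

    import hashlib

    def verify(sums_file="sha512sums.txt"):
        for line in open(sums_file):
            if not line.strip():
                continue
            expected, path = line.split(None, 1)
            digest = hashlib.sha512(open(path.strip(), "rb").read()).hexdigest()
            print(path.strip(), "OK" if digest == expected else "MISMATCH")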
/tests/surfing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/a2286854a6d8bc2d3de8d698f44fdf835e24d110/tests/surfing.jpg
--------------------------------------------------------------------------------
/tests/surfing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/a2286854a6d8bc2d3de8d698f44fdf835e24d110/tests/surfing.png
--------------------------------------------------------------------------------
/tests/test.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import pytest
18 | import requests
19 |
20 |
21 | def test_swagger():
22 |
23 | model_endpoint = 'http://localhost:5000/swagger.json'
24 |
25 | r = requests.get(url=model_endpoint)
26 | assert r.status_code == 200
27 | assert r.headers['Content-Type'] == 'application/json'
28 |
29 | json = r.json()
30 | assert 'swagger' in json
31 | assert json.get('info') and json.get('info').get('title') == 'MAX Image Caption Generator'
32 |
33 |
34 | def test_metadata():
35 |
36 | model_endpoint = 'http://localhost:5000/model/metadata'
37 |
38 | r = requests.get(url=model_endpoint)
39 | assert r.status_code == 200
40 |
41 | metadata = r.json()
42 | assert metadata['id'] == 'max-image-caption-generator'
43 | assert metadata['name'] == 'MAX Image Caption Generator'
44 | assert metadata['description'] == 'im2txt TensorFlow model trained on MSCOCO'
45 | assert metadata['license'] == 'Apache 2.0'
46 | assert metadata['type'] == 'Image-to-Text Translation'
47 | assert 'max-image-caption-generator' in metadata['source']
48 |
49 |
50 | def _check_response(r):
51 | caption_text = 'a man riding a wave on top of a surfboard .'
52 | assert r.status_code == 200
53 | response = r.json()
54 | assert response['status'] == 'ok'
55 | assert response['predictions'][0]['caption'] == caption_text
56 |
57 |
58 | def test_predict():
59 | model_endpoint = 'http://localhost:5000/model/predict'
60 | formats = ['jpg', 'png']
61 | file_path = 'tests/surfing.{}'
62 |
63 | for f in formats:
64 | p = file_path.format(f)
65 | with open(p, 'rb') as file:
66 | file_form = {'image': (p, file, 'image/{}'.format(f))}
67 | r = requests.post(url=model_endpoint, files=file_form)
68 | _check_response(r)
69 |
70 |
71 | if __name__ == '__main__':
72 | pytest.main([__file__])
73 |
--------------------------------------------------------------------------------