├── .bandit ├── .dockerignore ├── .gitignore ├── .travis.yml ├── Dockerfile ├── LICENSE ├── README.md ├── api ├── __init__.py ├── metadata.py └── predict.py ├── app.py ├── assets ├── README.md └── word_counts.txt ├── config.py ├── core ├── __init__.py ├── configuration.py ├── inference_utils │ ├── BUILD │ ├── __init__.py │ ├── caption_generator.py │ ├── inference_wrapper_base.py │ └── vocabulary.py ├── inference_wrapper.py ├── model.py ├── ops │ ├── BUILD │ ├── __init__.py │ ├── image_embedding.py │ ├── image_processing.py │ └── inputs.py └── show_and_tell_model.py ├── docs ├── deploy-max-to-ibm-cloud-with-kubernetes-button.png └── swagger-screenshot.png ├── max-image-caption-generator.yaml ├── requirements-test.txt ├── requirements.txt ├── samples ├── README.md ├── plane.jpg ├── soccer.jpg └── surfing.jpg ├── sha512sums.txt └── tests ├── surfing.jpg ├── surfing.png └── test.py /.bandit: -------------------------------------------------------------------------------- 1 | [bandit] 2 | exclude: /tests,/training 3 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyCharm 29 | .idea/ 30 | *.iml 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # dotenv 87 | .env 88 | 89 | # virtualenv 90 | .venv 91 | venv/ 92 | ENV/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | /.pytest_cache/ 107 | .gitignore 108 | .git/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyCharm 29 | .idea/ 30 | *.iml 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # dotenv 87 | .env 88 | 89 | # virtualenv 90 | .venv 91 | venv/ 92 | ENV/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | /.pytest_cache/ 107 | 108 | # Ignore Mac DS_Store files 109 | .DS_Store 110 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | language: python 18 | 19 | python: 20 | - 3.6 21 | 22 | services: 23 | - docker 24 | 25 | install: 26 | - docker build -t max-image-caption-generator . 27 | - docker run -it --rm -d -p 5000:5000 max-image-caption-generator 28 | - pip install -r requirements-test.txt 29 | 30 | before_script: 31 | - flake8 . --max-line-length=127 32 | - bandit -r . 33 | - sleep 30 34 | 35 | script: 36 | - pytest tests/test.py 37 | 38 | ### Temporarily disable downstream tests ### 39 | #jobs: 40 | # include: 41 | # - stage: test 42 | # script: 43 | # - pytest tests/test.py 44 | # - stage: trigger downstream 45 | # jdk: oraclejdk8 46 | # script: | 47 | # echo "TRAVIS_BRANCH=$TRAVIS_BRANCH TRAVIS_PULL_REQUEST=$TRAVIS_PULL_REQUEST" 48 | # if [[ ($TRAVIS_BRANCH == master) && 49 | # ($TRAVIS_PULL_REQUEST == false) ]] ; then 50 | # curl -LO --retry 3 https://raw.github.com/mernst/plume-lib/master/bin/trigger-travis.sh 51 | # sh trigger-travis.sh IBM MAX-Image-Caption-Generator-Web-App $TRAVIS_ACCESS_TOKEN 52 | # fi 53 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | FROM quay.io/codait/max-base:v1.5.0 18 | 19 | ARG model_bucket=https://max-cdn.cdn.appdomain.cloud/max-image-caption-generator/1.0.0 20 | ARG model_file=assets.tar.gz 21 | 22 | RUN wget -nv --show-progress --progress=bar:force:noscroll ${model_bucket}/${model_file} --output-document=assets/${model_file} && \ 23 | tar -x -C assets/ -f assets/${model_file} -v && rm assets/${model_file} 24 | 25 | COPY requirements.txt . 26 | RUN pip install -r requirements.txt 27 | 28 | COPY . . 29 | 30 | # check file integrity 31 | RUN sha512sum -c sha512sums.txt 32 | 33 | EXPOSE 5000 34 | 35 | CMD python app.py 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/IBM/MAX-Image-Caption-Generator.svg?branch=master)](https://travis-ci.org/IBM/MAX-Image-Caption-Generator) [![Website Status](https://img.shields.io/website/http/max-image-caption-generator.codait-prod-41208c73af8fca213512856c7a09db52-0000.us-east.containers.appdomain.cloud/swagger.json.svg?label=api+demo)](http://max-image-caption-generator.codait-prod-41208c73af8fca213512856c7a09db52-0000.us-east.containers.appdomain.cloud) 2 | 3 | [![Deploy to IBM Cloud with Kubernetes](docs/deploy-max-to-ibm-cloud-with-kubernetes-button.png)](http://ibm.biz/max-to-ibm-cloud-tutorial) 4 | 5 | # IBM Developer Model Asset Exchange: Image Caption Generator 6 | 7 | This repository contains code to instantiate and deploy an image caption generation model. This model generates captions from a fixed vocabulary that describe the contents of images in the [COCO Dataset](http://cocodataset.org/#home). The model consists of an _encoder_ model - a deep convolutional net using the Inception-v3 architecture trained on [ImageNet-2012 data](http://www.image-net.org/challenges/LSVRC/2012/) - and a _decoder_ model - an LSTM network that is trained conditioned on the encoding from the image _encoder_ model. The input to the model is an image, and the output is a sentence describing the image content. 8 | 9 | The model is based on the [Show and Tell Image Caption Generator Model](https://github.com/tensorflow/models/tree/master/research/im2txt). The checkpoint files are hosted on [IBM Cloud Object Storage](https://max-cdn.cdn.appdomain.cloud/max-image-caption-generator/1.0.0/assets.tar.gz). The code in this repository deploys the model as a web service in a Docker container. This repository was developed as part of the [IBM Code Model Asset Exchange](https://developer.ibm.com/code/exchanges/models/). 10 | 11 | ## Model Metadata 12 | | Domain | Application | Industry | Framework | Training Data | Input Data Format | 13 | | ------------- | -------- | -------- | --------- | --------- | -------------- | 14 | | Vision | Image Caption Generator | General | TensorFlow | [COCO](http://cocodataset.org/#home) | Images | 15 | 16 | ## References 17 | * _O. Vinyals, A. Toshev, S. Bengio, D. Erhan._, ["Show and Tell: Lessons learned from the 2015 MSCOCO Image Captioning Challenge"](https://doi.org/10.1109/TPAMI.2016.2587640), IEEE Transactions on Pattern Analysis and Machine Intelligence, 2017.
18 | * [im2txt TensorFlow Model GitHub Page](https://github.com/tensorflow/models/tree/master/research/im2txt) 19 | * [COCO Dataset Project Page](http://cocodataset.org/#home) 20 | 21 | ## Licenses 22 | 23 | | Component | License | Link | 24 | | ------------- | -------- | -------- | 25 | | This repository | [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) | [LICENSE](LICENSE) | 26 | | Model Weights | [MIT](https://opensource.org/licenses/MIT) | [Pretrained Show and Tell Model](https://github.com/KranthiGV/Pretrained-Show-and-Tell-model) | 27 | | Model Code (3rd party) | [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) | [im2txt](https://github.com/tensorflow/models/tree/master/research/im2txt) | 28 | | Test assets | Various | [Sample README](samples/README.md) | 29 | 30 | ## Pre-requisites: 31 | 32 | * `docker`: The [Docker](https://www.docker.com/) command-line interface. Follow the [installation instructions](https://docs.docker.com/install/) for your system. 33 | * The minimum recommended resources for this model is 2GB Memory and 2 CPUs. 34 | * If you are on x86-64/AMD64, your CPU must support [AVX](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) at the minimum. 35 | 36 | # Deployment options 37 | 38 | * [Deploy from Quay](#deploy-from-quay) 39 | * [Deploy on Red Hat OpenShift](#deploy-on-red-hat-openshift) 40 | * [Deploy on Kubernetes](#deploy-on-kubernetes) 41 | * [Run Locally](#run-locally) 42 | 43 | ## Deploy from Quay 44 | 45 | To run the docker image, which automatically starts the model serving API, run: 46 | 47 | ``` 48 | $ docker run -it -p 5000:5000 quay.io/codait/max-image-caption-generator 49 | ``` 50 | 51 | This will pull a pre-built image from the Quay.io container registry (or use an existing image if already cached locally) and run it. 52 | If you'd rather checkout and build the model locally you can follow the [run locally](#run-locally) steps below. 53 | 54 | ## Deploy on Red Hat OpenShift 55 | 56 | You can deploy the model-serving microservice on Red Hat OpenShift by following the instructions for the OpenShift web console or the OpenShift Container Platform CLI [in this tutorial](https://developer.ibm.com/tutorials/deploy-a-model-asset-exchange-microservice-on-red-hat-openshift/), specifying `quay.io/codait/max-image-caption-generator` as the image name. 57 | 58 | ## Deploy on Kubernetes 59 | 60 | You can also deploy the model on Kubernetes using the latest docker image on Quay. 61 | 62 | On your Kubernetes cluster, run the following commands: 63 | 64 | ``` 65 | $ kubectl apply -f https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/master/max-image-caption-generator.yaml 66 | ``` 67 | 68 | The model will be available internally at port `5000`, but can also be accessed externally through the `NodePort`. 69 | 70 | A more elaborate tutorial on how to deploy this MAX model to production on [IBM Cloud](https://ibm.biz/Bdz2XM) can be found [here](http://ibm.biz/max-to-ibm-cloud-tutorial). 71 | 72 | ## Run Locally 73 | 74 | 1. [Build the Model](#1-build-the-model) 75 | 2. [Deploy the Model](#2-deploy-the-model) 76 | 3. [Use the Model](#3-use-the-model) 77 | 4. [Development](#4-development) 78 | 5. [Cleanup](#5-cleanup) 79 | 80 | ### 1. Build the Model 81 | 82 | Clone this repository locally. 
In a terminal, run the following command: 83 | 84 | ``` 85 | $ git clone https://github.com/IBM/MAX-Image-Caption-Generator.git 86 | ``` 87 | 88 | Change directory into the repository base folder: 89 | 90 | ``` 91 | $ cd MAX-Image-Caption-Generator 92 | ``` 93 | 94 | To build the docker image locally, run: 95 | 96 | ``` 97 | $ docker build -t max-image-caption-generator . 98 | ``` 99 | 100 | All required model assets will be downloaded during the build process. _Note_ that currently this docker image is CPU only (we will add support for GPU images later). 101 | 102 | ### 2. Deploy the Model 103 | 104 | To run the docker image, which automatically starts the model serving API, run: 105 | 106 | ``` 107 | $ docker run -it -p 5000:5000 max-image-caption-generator 108 | ``` 109 | 110 | ### 3. Use the Model 111 | 112 | The API server automatically generates an interactive Swagger documentation page. Go to `http://localhost:5000` to load it. From there you can explore the API and also create test requests. 113 | 114 | Use the `model/predict` endpoint to load a test file and get captions for the image from the API. 115 | 116 | ![pic](/docs/swagger-screenshot.png "Swagger Screenshot") 117 | 118 | You can also test it on the command line, for example: 119 | 120 | ``` 121 | $ curl -F "image=@samples/surfing.jpg" -X POST http://localhost:5000/model/predict 122 | ``` 123 | 124 | ```json 125 | { 126 | "status": "ok", 127 | "predictions": [ 128 | { 129 | "index": "0", 130 | "caption": "a man riding a wave on top of a surfboard .", 131 | "probability": 0.038827644239537 132 | }, 133 | { 134 | "index": "1", 135 | "caption": "a person riding a surf board on a wave", 136 | "probability": 0.017933410519265 137 | }, 138 | { 139 | "index": "2", 140 | "caption": "a man riding a wave on a surfboard in the ocean .", 141 | "probability": 0.0056628732021868 142 | } 143 | ] 144 | } 145 | ``` 146 | 147 | ### 4. Development 148 | 149 | To run the Flask API app in debug mode, edit `config.py` to set `DEBUG = True` under the application settings. You will then need to rebuild the docker image (see [step 1](#1-build-the-model)). 150 | 151 | ### 5. Cleanup 152 | 153 | To stop the Docker container, type `CTRL` + `C` in your terminal. 154 | 155 | ## Links 156 | 157 | * [Image Caption Generator Web App](https://developer.ibm.com/patterns/create-a-web-app-to-interact-with-machine-learning-generated-image-captions): A reference application created by the IBM CODAIT team that uses the Image Caption Generator 158 | 159 | ## Resources and Contributions 160 | 161 | If you are interested in contributing to the Model Asset Exchange project or have any queries, please follow the instructions [here](https://github.com/CODAIT/max-central-repo). 162 | -------------------------------------------------------------------------------- /api/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .metadata import ModelMetadataAPI # noqa 18 | from .predict import ModelPredictAPI # noqa 19 | -------------------------------------------------------------------------------- /api/metadata.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from config import MODEL_META_DATA 18 | from maxfw.core import MAX_API, MetadataAPI, METADATA_SCHEMA 19 | 20 | 21 | class ModelMetadataAPI(MetadataAPI): 22 | 23 | @MAX_API.marshal_with(METADATA_SCHEMA) 24 | def get(self): 25 | """Return the metadata associated with the model""" 26 | return MODEL_META_DATA 27 | -------------------------------------------------------------------------------- /api/predict.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | from maxfw.core import MAX_API, PredictAPI 18 | from core.model import ModelWrapper 19 | 20 | from flask import abort 21 | from flask_restx import fields 22 | from werkzeug.datastructures import FileStorage 23 | 24 | 25 | # Set up parser for image input data 26 | image_parser = MAX_API.parser() 27 | image_parser.add_argument('image', type=FileStorage, location='files', required=True, help="An image file") 28 | 29 | label_prediction = MAX_API.model('LabelPrediction', { 30 | 'index': fields.String(required=False, description='Labels ranked by highest probability'), 31 | 'caption': fields.String(required=True, description='Caption generated by image'), 32 | 'probability': fields.Float(required=True, description="Probability of the caption") 33 | }) 34 | 35 | predict_response = MAX_API.model('ModelPredictResponse', { 36 | 'status': fields.String(required=True, description='Response status message'), 37 | 'predictions': fields.List(fields.Nested(label_prediction), description='Predicted captions and probabilities') 38 | }) 39 | 40 | 41 | class ModelPredictAPI(PredictAPI): 42 | 43 | model_wrapper = ModelWrapper() 44 | 45 | @MAX_API.doc('predict') 46 | @MAX_API.expect(image_parser) 47 | @MAX_API.marshal_with(predict_response) 48 | def post(self): 49 | """Make a prediction given input data""" 50 | 51 | result = {'status': 'error'} 52 | args = image_parser.parse_args() 53 | if not args['image'].mimetype.endswith(('jpg', 'jpeg', 'png')): 54 | abort(400, 'Invalid file type/extension. Please provide an image in JPEG or PNG format.') 55 | image_data = args['image'].read() 56 | 57 | preds = self.model_wrapper.predict(image_data) 58 | 59 | label_preds = [{'index': p[0], 'caption': p[1], 'probability': p[2]} for p in [x for x in preds]] 60 | result['predictions'] = label_preds 61 | result['status'] = 'ok' 62 | 63 | return result 64 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from maxfw.core import MAXApp 18 | from api import ModelMetadataAPI, ModelPredictAPI 19 | from config import API_TITLE, API_DESC, API_VERSION 20 | 21 | max_app = MAXApp(API_TITLE, API_DESC, API_VERSION) 22 | max_app.add_api(ModelMetadataAPI, '/metadata') 23 | max_app.add_api(ModelPredictAPI, '/predict') 24 | max_app.run() 25 | -------------------------------------------------------------------------------- /assets/README.md: -------------------------------------------------------------------------------- 1 | # Asset Details 2 | 3 | ## Model files 4 | 5 | Model files are from the [Pretrained Show and Tell Model](https://github.com/KranthiGV/Pretrained-Show-and-Tell-model), where they are available under a [MIT License](https://opensource.org/licenses/MIT). 
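For local experimentation outside of Docker, the same archive that the Dockerfile downloads at build time can be fetched and unpacked with a few lines of Python. This is a minimal sketch: the archive URL is the one referenced in the Dockerfile and in the note below, and extracting into `assets/` simply mirrors the Docker build.

```python
import tarfile
import urllib.request

# Archive referenced by the Dockerfile (model_bucket/model_file build arguments).
ASSETS_URL = "https://max-cdn.cdn.appdomain.cloud/max-image-caption-generator/1.0.0/assets.tar.gz"

# Download the checkpoint and vocabulary archive, then unpack it into assets/.
archive_path, _ = urllib.request.urlretrieve(ASSETS_URL, "assets.tar.gz")
with tarfile.open(archive_path, "r:gz") as tar:
    tar.extractall(path="assets/")
```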
6 | 7 | _Note_ the model files are hosted on [IBM Cloud Object Storage](https://max-cdn.cdn.appdomain.cloud/max-image-caption-generator/1.0.0/assets.tar.gz). 8 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | # Application settings 18 | 19 | # Flask settings 20 | DEBUG = False 21 | 22 | # Flask-restplus settings 23 | RESTPLUS_MASK_SWAGGER = False 24 | SWAGGER_UI_DOC_EXPANSION = 'none' 25 | 26 | # API metadata 27 | API_TITLE = 'MAX Image Caption Generator' 28 | API_DESC = 'Generate captions that describe the contents of images.' 29 | API_VERSION = '1.1.0' 30 | 31 | # default model 32 | MODEL_NAME = 'im2txt' 33 | DEFAULT_MODEL_PATH = 'assets/checkpoint/model2.ckpt-2000000' 34 | VOCAB_FILE = './assets/word_counts.txt' 35 | # for image models, may not be required 36 | MODEL_INPUT_IMG_SIZE = (299, 299) 37 | MODEL_LICENSE = 'Apache 2.0' 38 | 39 | MODEL_META_DATA = { 40 | 'id': API_TITLE.lower().replace(' ', '-'), 41 | 'name': API_TITLE, 42 | 'description': '{} TensorFlow model trained on MSCOCO'.format(MODEL_NAME), 43 | 'type': 'Image-to-Text Translation', 44 | 'license': MODEL_LICENSE, 45 | 'source': 'https://developer.ibm.com/exchanges/models/all/max-image-caption-generator/' 46 | } 47 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | -------------------------------------------------------------------------------- /core/configuration.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Image-to-text model and training configurations.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | class ModelConfig(object): 24 | """Wrapper class for model hyperparameters.""" 25 | 26 | def __init__(self): 27 | """Sets the default model hyperparameters.""" 28 | # File pattern of sharded TFRecord file containing SequenceExample protos. 29 | # Must be provided in training and evaluation modes. 30 | self.input_file_pattern = None 31 | 32 | # Image format ("jpeg" or "png"). 33 | self.image_format = "jpeg" 34 | 35 | # Approximate number of values per input shard. Used to ensure sufficient 36 | # mixing between shards in training. 37 | self.values_per_input_shard = 2300 38 | # Minimum number of shards to keep in the input queue. 39 | self.input_queue_capacity_factor = 2 40 | # Number of threads for prefetching SequenceExample protos. 41 | self.num_input_reader_threads = 1 42 | 43 | # Name of the SequenceExample context feature containing image data. 44 | self.image_feature_name = "image/data" 45 | # Name of the SequenceExample feature list containing integer captions. 46 | self.caption_feature_name = "image/caption_ids" 47 | 48 | # Number of unique words in the vocab (plus 1, for ). 49 | # The default value is larger than the expected actual vocab size to allow 50 | # for differences between tokenizer versions used in preprocessing. There is 51 | # no harm in using a value greater than the actual vocab size, but using a 52 | # value less than the actual vocab size will result in an error. 53 | self.vocab_size = 12000 54 | 55 | # Number of threads for image preprocessing. Should be a multiple of 2. 56 | self.num_preprocess_threads = 4 57 | 58 | # Batch size. 59 | self.batch_size = 32 60 | 61 | # File containing an Inception v3 checkpoint to initialize the variables 62 | # of the Inception model. Must be provided when starting training for the 63 | # first time. 64 | self.inception_checkpoint_file = None 65 | 66 | # Dimensions of Inception v3 input images. 67 | self.image_height = 299 68 | self.image_width = 299 69 | 70 | # Scale used to initialize model variables. 71 | self.initializer_scale = 0.08 72 | 73 | # LSTM input and output dimensionality, respectively. 74 | self.embedding_size = 512 75 | self.num_lstm_units = 512 76 | 77 | # If < 1.0, the dropout keep probability applied to LSTM variables. 78 | self.lstm_dropout_keep_prob = 0.7 79 | 80 | 81 | class TrainingConfig(object): 82 | """Wrapper class for training hyperparameters.""" 83 | 84 | def __init__(self): 85 | """Sets the default training hyperparameters.""" 86 | # Number of examples per epoch of training data. 87 | self.num_examples_per_epoch = 586363 88 | 89 | # Optimizer for training the model. 90 | self.optimizer = "SGD" 91 | 92 | # Learning rate for the initial phase of training. 
93 | self.initial_learning_rate = 2.0 94 | self.learning_rate_decay_factor = 0.5 95 | self.num_epochs_per_decay = 8.0 96 | 97 | # Learning rate when fine tuning the Inception v3 parameters. 98 | self.train_inception_learning_rate = 0.0005 99 | 100 | # If not None, clip gradients to this value. 101 | self.clip_gradients = 5.0 102 | 103 | # How many model checkpoints to keep. 104 | self.max_checkpoints_to_keep = 5 105 | -------------------------------------------------------------------------------- /core/inference_utils/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//im2txt:internal"]) 2 | 3 | licenses(["notice"]) # Apache 2.0 4 | 5 | exports_files(["LICENSE"]) 6 | 7 | py_library( 8 | name = "inference_wrapper_base", 9 | srcs = ["inference_wrapper_base.py"], 10 | srcs_version = "PY2AND3", 11 | ) 12 | 13 | py_library( 14 | name = "vocabulary", 15 | srcs = ["vocabulary.py"], 16 | srcs_version = "PY2AND3", 17 | ) 18 | 19 | py_library( 20 | name = "caption_generator", 21 | srcs = ["caption_generator.py"], 22 | srcs_version = "PY2AND3", 23 | ) 24 | 25 | py_test( 26 | name = "caption_generator_test", 27 | srcs = ["caption_generator_test.py"], 28 | deps = [ 29 | ":caption_generator", 30 | ], 31 | ) 32 | -------------------------------------------------------------------------------- /core/inference_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/a2286854a6d8bc2d3de8d698f44fdf835e24d110/core/inference_utils/__init__.py -------------------------------------------------------------------------------- /core/inference_utils/caption_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Class for generating captions from an image-to-text model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import heapq 22 | import math 23 | 24 | import numpy as np 25 | 26 | 27 | class Caption(object): 28 | """Represents a complete or partial caption.""" 29 | 30 | def __init__(self, sentence, state, logprob, score, metadata=None): 31 | """Initializes the Caption. 32 | 33 | Args: 34 | sentence: List of word ids in the caption. 35 | state: Model state after generating the previous word. 36 | logprob: Log-probability of the caption. 37 | score: Score of the caption. 38 | metadata: Optional metadata associated with the partial sentence. If not 39 | None, a list of strings with the same length as 'sentence'. 
40 | """ 41 | self.sentence = sentence 42 | self.state = state 43 | self.logprob = logprob 44 | self.score = score 45 | self.metadata = metadata 46 | 47 | def __cmp__(self, other): 48 | """Compares Captions by score.""" 49 | if not isinstance(other, Caption): 50 | raise ValueError("'%s' isn't an instance of Caption" % other) 51 | if self.score == other.score: 52 | return 0 53 | elif self.score < other.score: 54 | return -1 55 | else: 56 | return 1 57 | 58 | # For Python 3 compatibility (__cmp__ is deprecated). 59 | def __lt__(self, other): 60 | if not isinstance(other, Caption): 61 | raise ValueError("'%s' isn't an instance of Caption" % other) 62 | return self.score < other.score 63 | 64 | # Also for Python 3 compatibility. 65 | def __eq__(self, other): 66 | if not isinstance(other, Caption): 67 | raise ValueError("'%s' isn't an instance of Caption" % other) 68 | return self.score == other.score 69 | 70 | 71 | class TopN(object): 72 | """Maintains the top n elements of an incrementally provided set.""" 73 | 74 | def __init__(self, n): 75 | self._n = n 76 | self._data = [] 77 | 78 | def size(self): 79 | if self._data is None: 80 | raise ValueError("'self._data' is None") 81 | return len(self._data) 82 | 83 | def push(self, x): 84 | """Pushes a new element.""" 85 | if self._data is None: 86 | raise ValueError("'self._data' is none") 87 | if len(self._data) < self._n: 88 | heapq.heappush(self._data, x) 89 | else: 90 | heapq.heappushpop(self._data, x) 91 | 92 | def extract(self, sort=False): 93 | """Extracts all elements from the TopN. This is a destructive operation. 94 | 95 | The only method that can be called immediately after extract() is reset(). 96 | 97 | Args: 98 | sort: Whether to return the elements in descending sorted order. 99 | 100 | Returns: 101 | A list of data; the top n elements provided to the set. 102 | """ 103 | if self._data is None: 104 | raise ValueError("'self._data' is None") 105 | data = self._data 106 | self._data = None 107 | if sort: 108 | data.sort(reverse=True) 109 | return data 110 | 111 | def reset(self): 112 | """Returns the TopN to an empty state.""" 113 | self._data = [] 114 | 115 | 116 | class CaptionGenerator(object): 117 | """Class to generate captions from an image-to-text model.""" 118 | 119 | def __init__(self, 120 | model, 121 | vocab, 122 | beam_size=3, 123 | max_caption_length=20, 124 | length_normalization_factor=0.0): 125 | """Initializes the generator. 126 | 127 | Args: 128 | model: Object encapsulating a trained image-to-text model. Must have 129 | methods feed_image() and inference_step(). For example, an instance of 130 | InferenceWrapperBase. 131 | vocab: A Vocabulary object. 132 | beam_size: Beam size to use when generating captions. 133 | max_caption_length: The maximum caption length before stopping the search. 134 | length_normalization_factor: If != 0, a number x such that captions are 135 | scored by logprob/length^x, rather than logprob. This changes the 136 | relative scores of captions depending on their lengths. For example, if 137 | x > 0 then longer captions will be favored. 138 | """ 139 | self.vocab = vocab 140 | self.model = model 141 | 142 | self.beam_size = beam_size 143 | self.max_caption_length = max_caption_length 144 | self.length_normalization_factor = length_normalization_factor 145 | 146 | def beam_search(self, sess, encoded_image): 147 | """Runs beam search caption generation on a single image. 148 | 149 | Args: 150 | sess: TensorFlow Session object. 151 | encoded_image: An encoded image string. 
152 | 153 | Returns: 154 | A list of Caption sorted by descending score. 155 | """ 156 | # Feed in the image to get the initial state. 157 | initial_state = self.model.feed_image(sess, encoded_image) 158 | 159 | initial_beam = Caption( 160 | sentence=[self.vocab.start_id], 161 | state=initial_state[0], 162 | logprob=0.0, 163 | score=0.0, 164 | metadata=[""]) 165 | partial_captions = TopN(self.beam_size) 166 | partial_captions.push(initial_beam) 167 | complete_captions = TopN(self.beam_size) 168 | 169 | # Run beam search. 170 | for _ in range(self.max_caption_length - 1): 171 | partial_captions_list = partial_captions.extract() 172 | partial_captions.reset() 173 | input_feed = np.array([c.sentence[-1] for c in partial_captions_list]) 174 | state_feed = np.array([c.state for c in partial_captions_list]) 175 | 176 | softmax, new_states, metadata = self.model.inference_step(sess, 177 | input_feed, 178 | state_feed) 179 | 180 | for i, partial_caption in enumerate(partial_captions_list): 181 | word_probabilities = softmax[i] 182 | state = new_states[i] 183 | # For this partial caption, get the beam_size most probable next words. 184 | words_and_probs = list(enumerate(word_probabilities)) 185 | words_and_probs.sort(key=lambda x: -x[1]) 186 | words_and_probs = words_and_probs[0:self.beam_size] 187 | # Each next word gives a new partial caption. 188 | for w, p in words_and_probs: 189 | if p < 1e-12: 190 | continue # Avoid log(0). 191 | sentence = partial_caption.sentence + [w] 192 | logprob = partial_caption.logprob + math.log(p) 193 | score = logprob 194 | if metadata: 195 | metadata_list = partial_caption.metadata + [metadata[i]] 196 | else: 197 | metadata_list = None 198 | if w == self.vocab.end_id: 199 | if self.length_normalization_factor > 0: 200 | score /= len(sentence) ** self.length_normalization_factor 201 | beam = Caption(sentence, state, logprob, score, metadata_list) 202 | complete_captions.push(beam) 203 | else: 204 | beam = Caption(sentence, state, logprob, score, metadata_list) 205 | partial_captions.push(beam) 206 | if partial_captions.size() == 0: 207 | # We have run out of partial candidates; happens when beam_size = 1. 208 | break 209 | 210 | # If we have no complete captions then fall back to the partial captions. 211 | # But never output a mixture of complete and partial captions because a 212 | # partial caption could have a higher score than all the complete captions. 213 | if not complete_captions.size(): 214 | complete_captions = partial_captions 215 | 216 | return complete_captions.extract(sort=True) 217 | -------------------------------------------------------------------------------- /core/inference_utils/inference_wrapper_base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Base wrapper class for performing inference with an image-to-text model. 16 | 17 | Subclasses must implement the following methods: 18 | 19 | build_model(): 20 | Builds the model for inference and returns the model object. 21 | 22 | feed_image(): 23 | Takes an encoded image and returns the initial model state, where "state" 24 | is a numpy array whose specifics are defined by the subclass, e.g. 25 | concatenated LSTM state. It's assumed that feed_image() will be called 26 | precisely once at the start of inference for each image. Subclasses may 27 | compute and/or save per-image internal context in this method. 28 | 29 | inference_step(): 30 | Takes a batch of inputs and states at a single time-step. Returns the 31 | softmax output corresponding to the inputs, and the new states of the batch. 32 | Optionally also returns metadata about the current inference step, e.g. a 33 | serialized numpy array containing activations from a particular model layer. 34 | 35 | Client usage: 36 | 1. Build the model inference graph via build_graph_from_config() or 37 | build_graph_from_proto(). 38 | 2. Call the resulting restore_fn to load the model checkpoint. 39 | 3. For each image in a batch of images: 40 | a) Call feed_image() once to get the initial state. 41 | b) For each step of caption generation, call inference_step(). 42 | """ 43 | 44 | from __future__ import absolute_import 45 | from __future__ import division 46 | from __future__ import print_function 47 | 48 | import os.path 49 | 50 | import tensorflow as tf 51 | 52 | 53 | # pylint: disable=unused-argument 54 | 55 | 56 | class InferenceWrapperBase(object): 57 | """Base wrapper class for performing inference with an image-to-text model.""" 58 | 59 | def __init__(self): 60 | pass 61 | 62 | def build_model(self, model_config): 63 | """Builds the model for inference. 64 | 65 | Args: 66 | model_config: Object containing configuration for building the model. 67 | 68 | Returns: 69 | model: The model object. 70 | """ 71 | tf.compat.v1.logging.fatal("Please implement build_model in subclass") 72 | 73 | def _create_restore_fn(self, checkpoint_path, saver): 74 | """Creates a function that restores a model from checkpoint. 75 | 76 | Args: 77 | checkpoint_path: Checkpoint file or a directory containing a checkpoint 78 | file. 79 | saver: Saver for restoring variables from the checkpoint file. 80 | 81 | Returns: 82 | restore_fn: A function such that restore_fn(sess) loads model variables 83 | from the checkpoint file. 84 | 85 | Raises: 86 | ValueError: If checkpoint_path does not refer to a checkpoint file or a 87 | directory containing a checkpoint file. 88 | """ 89 | if tf.compat.v1.gfile.IsDirectory(checkpoint_path): 90 | checkpoint_path = tf.train.latest_checkpoint(checkpoint_path) 91 | if not checkpoint_path: 92 | raise ValueError("No checkpoint file found in: %s" % checkpoint_path) 93 | 94 | def _restore_fn(sess): 95 | tf.compat.v1.logging.info("Loading model from checkpoint: %s", checkpoint_path) 96 | saver.restore(sess, checkpoint_path) 97 | tf.compat.v1.logging.info("Successfully loaded checkpoint: %s", 98 | os.path.basename(checkpoint_path)) 99 | 100 | return _restore_fn 101 | 102 | def build_graph_from_config(self, model_config, checkpoint_path): 103 | """Builds the inference graph from a configuration object. 104 | 105 | Args: 106 | model_config: Object containing configuration for building the model. 
107 | checkpoint_path: Checkpoint file or a directory containing a checkpoint 108 | file. 109 | 110 | Returns: 111 | restore_fn: A function such that restore_fn(sess) loads model variables 112 | from the checkpoint file. 113 | """ 114 | tf.compat.v1.logging.info("Building model.") 115 | self.build_model(model_config) 116 | saver = tf.compat.v1.train.Saver() 117 | 118 | return self._create_restore_fn(checkpoint_path, saver) 119 | 120 | def build_graph_from_proto(self, graph_def_file, saver_def_file, 121 | checkpoint_path): 122 | """Builds the inference graph from serialized GraphDef and SaverDef protos. 123 | 124 | Args: 125 | graph_def_file: File containing a serialized GraphDef proto. 126 | saver_def_file: File containing a serialized SaverDef proto. 127 | checkpoint_path: Checkpoint file or a directory containing a checkpoint 128 | file. 129 | 130 | Returns: 131 | restore_fn: A function such that restore_fn(sess) loads model variables 132 | from the checkpoint file. 133 | """ 134 | # Load the Graph. 135 | tf.compat.v1.logging.info("Loading GraphDef from file: %s", graph_def_file) 136 | graph_def = tf.GraphDef() 137 | with tf.compat.v1.gfile.FastGFile(graph_def_file, "rb") as f: 138 | graph_def.ParseFromString(f.read()) 139 | tf.import_graph_def(graph_def, name="") 140 | 141 | # Load the Saver. 142 | tf.compat.v1.logging.info("Loading SaverDef from file: %s", saver_def_file) 143 | saver_def = tf.compat.v1.train.SaverDef() 144 | with tf.compat.v1.gfile.FastGFile(saver_def_file, "rb") as f: 145 | saver_def.ParseFromString(f.read()) 146 | saver = tf.compat.v1.train.Saver(saver_def=saver_def) 147 | 148 | return self._create_restore_fn(checkpoint_path, saver) 149 | 150 | def feed_image(self, sess, encoded_image): 151 | """Feeds an image and returns the initial model state. 152 | 153 | See comments at the top of file. 154 | 155 | Args: 156 | sess: TensorFlow Session object. 157 | encoded_image: An encoded image string. 158 | 159 | Returns: 160 | state: A numpy array of shape [1, state_size]. 161 | """ 162 | tf.compat.v1.logging.fatal("Please implement feed_image in subclass") 163 | 164 | def inference_step(self, sess, input_feed, state_feed): 165 | """Runs one step of inference. 166 | 167 | Args: 168 | sess: TensorFlow Session object. 169 | input_feed: A numpy array of shape [batch_size]. 170 | state_feed: A numpy array of shape [batch_size, state_size]. 171 | 172 | Returns: 173 | softmax_output: A numpy array of shape [batch_size, vocab_size]. 174 | new_state: A numpy array of shape [batch_size, state_size]. 175 | metadata: Optional. If not None, a string containing metadata about the 176 | current inference step (e.g. serialized numpy array containing 177 | activations from a particular model layer.). 178 | """ 179 | tf.compat.v1.logging.fatal("Please implement inference_step in subclass") 180 | 181 | # pylint: enable=unused-argument 182 | -------------------------------------------------------------------------------- /core/inference_utils/vocabulary.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Vocabulary class for an image-to-text model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | 24 | class Vocabulary(object): 25 | """Vocabulary class for an image-to-text model.""" 26 | 27 | def __init__(self, 28 | vocab_file, 29 | start_word="<S>", 30 | end_word="</S>", 31 | unk_word="<UNK>"): 32 | """Initializes the vocabulary. 33 | 34 | Args: 35 | vocab_file: File containing the vocabulary, where the words are the first 36 | whitespace-separated token on each line (other tokens are ignored) and 37 | the word ids are the corresponding line numbers. 38 | start_word: Special word denoting sentence start. 39 | end_word: Special word denoting sentence end. 40 | unk_word: Special word denoting unknown words. 41 | """ 42 | if not tf.compat.v1.gfile.Exists(vocab_file): 43 | tf.compat.v1.logging.fatal("Vocab file %s not found.", vocab_file) 44 | tf.compat.v1.logging.info("Initializing vocabulary from file: %s", vocab_file) 45 | 46 | with tf.compat.v1.gfile.GFile(vocab_file, mode="r") as f: 47 | reverse_vocab = list(f.readlines()) 48 | reverse_vocab = [line.split()[0] for line in reverse_vocab] 49 | if start_word not in reverse_vocab: 50 | raise ValueError("Start word '%s' is not in reverse_vocab" % start_word) 51 | if end_word not in reverse_vocab: 52 | raise ValueError("End word '%s' is not in reverse_vocab" % end_word) 53 | if unk_word not in reverse_vocab: 54 | reverse_vocab.append(unk_word) 55 | vocab = dict([(x, y) for (y, x) in enumerate(reverse_vocab)]) 56 | 57 | tf.compat.v1.logging.info("Created vocabulary with %d words" % len(vocab)) 58 | 59 | self.vocab = vocab # vocab[word] = id 60 | self.reverse_vocab = reverse_vocab # reverse_vocab[id] = word 61 | 62 | # Save special word ids. 63 | self.start_id = vocab[start_word] 64 | self.end_id = vocab[end_word] 65 | self.unk_id = vocab[unk_word] 66 | 67 | def word_to_id(self, word): 68 | """Returns the integer word id of a word string.""" 69 | if word in self.vocab: 70 | return self.vocab[word] 71 | else: 72 | return self.unk_id 73 | 74 | def id_to_word(self, word_id): 75 | """Returns the word string of an integer word id.""" 76 | if word_id >= len(self.reverse_vocab): 77 | return self.reverse_vocab[self.unk_id] 78 | else: 79 | return self.reverse_vocab[word_id] 80 | -------------------------------------------------------------------------------- /core/inference_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Model wrapper class for performing inference with a ShowAndTellModel.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from core import show_and_tell_model 23 | from core.inference_utils import inference_wrapper_base 24 | 25 | 26 | class InferenceWrapper(inference_wrapper_base.InferenceWrapperBase): 27 | """Model wrapper class for performing inference with a ShowAndTellModel.""" 28 | 29 | def __init__(self): 30 | super(InferenceWrapper, self).__init__() 31 | 32 | def build_model(self, model_config): 33 | model = show_and_tell_model.ShowAndTellModel(model_config, mode="inference") 34 | model.build() 35 | return model 36 | 37 | def feed_image(self, sess, encoded_image): 38 | initial_state = sess.run(fetches="lstm/initial_state:0", 39 | feed_dict={"image_feed:0": encoded_image}) 40 | return initial_state 41 | 42 | def inference_step(self, sess, input_feed, state_feed): 43 | softmax_output, state_output = sess.run( 44 | fetches=["softmax:0", "lstm/state:0"], 45 | feed_dict={ 46 | "input_feed:0": input_feed, 47 | "lstm/state_feed:0": state_feed, 48 | }) 49 | return softmax_output, state_output, None 50 | -------------------------------------------------------------------------------- /core/model.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from maxfw.model import MAXModelWrapper 18 | 19 | import math 20 | import logging 21 | 22 | import tensorflow as tf 23 | 24 | from core import configuration 25 | from core import inference_wrapper 26 | from core.inference_utils import vocabulary 27 | from core.inference_utils import caption_generator 28 | 29 | from config import DEFAULT_MODEL_PATH, VOCAB_FILE 30 | 31 | logger = logging.getLogger() 32 | 33 | tf.compat.v1.disable_eager_execution() 34 | 35 | 36 | class ModelWrapper(MAXModelWrapper): 37 | 38 | def __init__(self, path=DEFAULT_MODEL_PATH): 39 | # TODO Replace this part with SavedModel 40 | g = tf.Graph() 41 | with g.as_default(): 42 | model = inference_wrapper.InferenceWrapper() 43 | restore_fn = model.build_graph_from_config(configuration.ModelConfig(), 44 | path) 45 | g.finalize() 46 | self.model = model 47 | sess = tf.compat.v1.Session(graph=g) 48 | # Load the model from checkpoint. 
49 | restore_fn(sess) 50 | self.sess = sess 51 | 52 | def _predict(self, image_data): 53 | # Create the vocabulary. 54 | vocab = vocabulary.Vocabulary(VOCAB_FILE) 55 | 56 | # Prepare the caption generator. Here we are implicitly using the default 57 | # beam search parameters. See caption_generator.py for a description of the 58 | # available beam search parameters. 59 | generator = caption_generator.CaptionGenerator(self.model, vocab) 60 | 61 | captions = generator.beam_search(self.sess, image_data) 62 | 63 | results = [] 64 | for i, caption in enumerate(captions): 65 | # Ignore begin and end words. 66 | sentence = [vocab.id_to_word(w) for w in caption.sentence[1:-1]] 67 | sentence = " ".join(sentence) 68 | # print(" %d) %s (p=%f)" % (i, sentence, math.exp(caption.logprob))) 69 | results.append((i, sentence, math.exp(caption.logprob))) 70 | 71 | return results 72 | -------------------------------------------------------------------------------- /core/ops/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//im2txt:internal"]) 2 | 3 | licenses(["notice"]) # Apache 2.0 4 | 5 | exports_files(["LICENSE"]) 6 | 7 | py_library( 8 | name = "image_processing", 9 | srcs = ["image_processing.py"], 10 | srcs_version = "PY2AND3", 11 | ) 12 | 13 | py_library( 14 | name = "image_embedding", 15 | srcs = ["image_embedding.py"], 16 | srcs_version = "PY2AND3", 17 | ) 18 | 19 | py_test( 20 | name = "image_embedding_test", 21 | size = "small", 22 | srcs = ["image_embedding_test.py"], 23 | deps = [ 24 | ":image_embedding", 25 | ], 26 | ) 27 | 28 | py_library( 29 | name = "inputs", 30 | srcs = ["inputs.py"], 31 | srcs_version = "PY2AND3", 32 | ) 33 | -------------------------------------------------------------------------------- /core/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/a2286854a6d8bc2d3de8d698f44fdf835e24d110/core/ops/__init__.py -------------------------------------------------------------------------------- /core/ops/image_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Image embedding ops.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | import tf_slim as slim 25 | from tf_slim.nets.inception_v3 import inception_v3_base 26 | 27 | 28 | def inception_v3(images, 29 | trainable=True, 30 | is_training=True, 31 | weight_decay=0.00004, 32 | stddev=0.1, 33 | dropout_keep_prob=0.8, 34 | use_batch_norm=True, 35 | batch_norm_params=None, 36 | add_summaries=True, 37 | scope="InceptionV3"): 38 | """Builds an Inception V3 subgraph for image embeddings. 39 | 40 | Args: 41 | images: A float32 Tensor of shape [batch, height, width, channels]. 42 | trainable: Whether the inception submodel should be trainable or not. 43 | is_training: Boolean indicating training mode or not. 44 | weight_decay: Coefficient for weight regularization. 45 | stddev: The standard deviation of the trunctated normal weight initializer. 46 | dropout_keep_prob: Dropout keep probability. 47 | use_batch_norm: Whether to use batch normalization. 48 | batch_norm_params: Parameters for batch normalization. See 49 | tf.contrib.layers.batch_norm for details. 50 | add_summaries: Whether to add activation summaries. 51 | scope: Optional Variable scope. 52 | 53 | Returns: 54 | end_points: A dictionary of activations from inception_v3 layers. 55 | """ 56 | # Only consider the inception model to be in training mode if it's trainable. 57 | is_inception_model_training = trainable and is_training 58 | 59 | if use_batch_norm: 60 | # Default parameters for batch normalization. 61 | if not batch_norm_params: 62 | batch_norm_params = { 63 | "is_training": is_inception_model_training, 64 | "trainable": trainable, 65 | # Decay for the moving averages. 66 | "decay": 0.9997, 67 | # Epsilon to prevent 0s in variance. 68 | "epsilon": 0.001, 69 | # Collection containing the moving mean and moving variance. 70 | "variables_collections": { 71 | "beta": None, 72 | "gamma": None, 73 | "moving_mean": ["moving_vars"], 74 | "moving_variance": ["moving_vars"], 75 | } 76 | } 77 | else: 78 | batch_norm_params = None 79 | 80 | if trainable: 81 | weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay) 82 | else: 83 | weights_regularizer = None 84 | 85 | with tf.compat.v1.variable_scope(scope, "InceptionV3", [images]) as scope: 86 | with slim.arg_scope( 87 | [slim.conv2d, slim.fully_connected], 88 | weights_regularizer=weights_regularizer, 89 | trainable=trainable): 90 | with slim.arg_scope( 91 | [slim.conv2d], 92 | weights_initializer=tf.compat.v1.truncated_normal_initializer(stddev=stddev), 93 | activation_fn=tf.nn.relu, 94 | normalizer_fn=slim.batch_norm, 95 | normalizer_params=batch_norm_params): 96 | net, end_points = inception_v3_base(images, scope=scope) 97 | with tf.compat.v1.variable_scope("logits"): 98 | shape = net.get_shape() 99 | net = slim.avg_pool2d(net, shape[1:3], padding="VALID", scope="pool") 100 | net = slim.dropout( 101 | net, 102 | keep_prob=dropout_keep_prob, 103 | is_training=is_inception_model_training, 104 | scope="dropout") 105 | net = slim.flatten(net, scope="flatten") 106 | 107 | # Add summaries. 
108 | if add_summaries: 109 | for v in end_points.values(): 110 | slim.summarize_activation(v) 111 | 112 | return net 113 | -------------------------------------------------------------------------------- /core/ops/image_processing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Helper functions for image preprocessing.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | 25 | def distort_image(image, thread_id): 26 | """Perform random distortions on an image. 27 | 28 | Args: 29 | image: A float32 Tensor of shape [height, width, 3] with values in [0, 1). 30 | thread_id: Preprocessing thread id used to select the ordering of color 31 | distortions. There should be a multiple of 2 preprocessing threads. 32 | 33 | Returns: 34 | distorted_image: A float32 Tensor of shape [height, width, 3] with values in 35 | [0, 1]. 36 | """ 37 | # Randomly flip horizontally. 38 | with tf.name_scope("flip_horizontal"): 39 | image = tf.image.random_flip_left_right(image) 40 | 41 | # Randomly distort the colors based on thread id. 42 | color_ordering = thread_id % 2 43 | with tf.name_scope("distort_color"): 44 | if color_ordering == 0: 45 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 46 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 47 | image = tf.image.random_hue(image, max_delta=0.032) 48 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 49 | elif color_ordering == 1: 50 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 51 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 52 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 53 | image = tf.image.random_hue(image, max_delta=0.032) 54 | 55 | # The random_* ops do not necessarily clamp. 56 | image = tf.clip_by_value(image, 0.0, 1.0) 57 | 58 | return image 59 | 60 | 61 | def process_image(encoded_image, 62 | is_training, 63 | height, 64 | width, 65 | resize_height=346, 66 | resize_width=346, 67 | thread_id=0, 68 | image_format="jpeg"): 69 | """Decode an image, resize and apply random distortions. 70 | 71 | In training, images are distorted slightly differently depending on thread_id. 72 | 73 | Args: 74 | encoded_image: String Tensor containing the image. 75 | is_training: Boolean; whether preprocessing for training or eval. 76 | height: Height of the output image. 77 | width: Width of the output image. 78 | resize_height: If > 0, resize height before crop to final dimensions. 79 | resize_width: If > 0, resize width before crop to final dimensions. 80 | thread_id: Preprocessing thread id used to select the ordering of color 81 | distortions. 
There should be a multiple of 2 preprocessing threads. 82 | image_format: "jpeg" or "png". 83 | 84 | Returns: 85 | A float32 Tensor of shape [height, width, 3] with values in [-1, 1]. 86 | 87 | Raises: 88 | ValueError: If image_format is invalid. 89 | """ 90 | 91 | # Helper function to log an image summary to the visualizer. Summaries are 92 | # only logged in thread 0. 93 | def image_summary(name, image): 94 | if not thread_id: 95 | tf.summary.image(name, tf.expand_dims(image, 0)) 96 | 97 | # Decode image into a float32 Tensor of shape [?, ?, 3] with values in [0, 1). 98 | with tf.name_scope("decode"): 99 | if image_format == "jpeg": 100 | image = tf.image.decode_jpeg(encoded_image, channels=3) 101 | elif image_format == "png": 102 | image = tf.image.decode_png(encoded_image, channels=3) 103 | else: 104 | raise ValueError("Invalid image format: %s" % image_format) 105 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 106 | image_summary("original_image", image) 107 | 108 | # Resize image. 109 | if (resize_height > 0) != (resize_width > 0): 110 | raise ValueError("Invalid resize parameters height: '{0}' width: '{1}'".format(resize_height, resize_width)) 111 | 112 | if resize_height: 113 | image = tf.image.resize(image, 114 | size=[resize_height, resize_width], 115 | method=tf.image.ResizeMethod.BILINEAR) 116 | 117 | # Crop to final dimensions. 118 | if is_training: 119 | image = tf.random_crop(image, [height, width, 3]) 120 | else: 121 | # Central crop, assuming resize_height > height, resize_width > width. 122 | image = tf.image.resize_with_crop_or_pad(image, height, width) 123 | 124 | image_summary("resized_image", image) 125 | 126 | # Randomly distort the image. 127 | if is_training: 128 | image = distort_image(image, thread_id) 129 | 130 | image_summary("final_image", image) 131 | 132 | # Rescale to [-1,1] instead of [0, 1] 133 | image = tf.subtract(image, 0.5) 134 | image = tf.multiply(image, 2.0) 135 | return image 136 | -------------------------------------------------------------------------------- /core/ops/inputs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Input ops.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | 25 | def parse_sequence_example(serialized, image_feature, caption_feature): 26 | """Parses a tensorflow.SequenceExample into an image and caption. 27 | 28 | Args: 29 | serialized: A scalar string Tensor; a single serialized SequenceExample. 30 | image_feature: Name of SequenceExample context feature containing image 31 | data. 32 | caption_feature: Name of SequenceExample feature list containing integer 33 | captions. 
34 | 35 | Returns: 36 | encoded_image: A scalar string Tensor containing a JPEG encoded image. 37 | caption: A 1-D uint64 Tensor with dynamically specified length. 38 | """ 39 | context, sequence = tf.parse_single_sequence_example( 40 | serialized, 41 | context_features={ 42 | image_feature: tf.FixedLenFeature([], dtype=tf.string) 43 | }, 44 | sequence_features={ 45 | caption_feature: tf.FixedLenSequenceFeature([], dtype=tf.int64), 46 | }) 47 | 48 | encoded_image = context[image_feature] 49 | caption = sequence[caption_feature] 50 | return encoded_image, caption 51 | 52 | 53 | def prefetch_input_data(reader, 54 | file_pattern, 55 | is_training, 56 | batch_size, 57 | values_per_shard, 58 | input_queue_capacity_factor=16, 59 | num_reader_threads=1, 60 | shard_queue_name="filename_queue", 61 | value_queue_name="input_queue"): 62 | """Prefetches string values from disk into an input queue. 63 | 64 | In training the capacity of the queue is important because a larger queue 65 | means better mixing of training examples between shards. The minimum number of 66 | values kept in the queue is values_per_shard * input_queue_capacity_factor, 67 | where input_queue_memory factor should be chosen to trade-off better mixing 68 | with memory usage. 69 | 70 | Args: 71 | reader: Instance of tf.ReaderBase. 72 | file_pattern: Comma-separated list of file patterns (e.g. 73 | /tmp/train_data-?????-of-00100). 74 | is_training: Boolean; whether prefetching for training or eval. 75 | batch_size: Model batch size used to determine queue capacity. 76 | values_per_shard: Approximate number of values per shard. 77 | input_queue_capacity_factor: Minimum number of values to keep in the queue 78 | in multiples of values_per_shard. See comments above. 79 | num_reader_threads: Number of reader threads to fill the queue. 80 | shard_queue_name: Name for the shards filename queue. 81 | value_queue_name: Name for the values input queue. 82 | 83 | Returns: 84 | A Queue containing prefetched string values. 85 | """ 86 | data_files = [] 87 | for pattern in file_pattern.split(","): 88 | data_files.extend(tf.compat.v1.gfile.Glob(pattern)) 89 | if not data_files: 90 | tf.logging.fatal("Found no input files matching %s", file_pattern) 91 | else: 92 | tf.logging.info("Prefetching values from %d files matching %s", 93 | len(data_files), file_pattern) 94 | 95 | if is_training: 96 | filename_queue = tf.train.string_input_producer( 97 | data_files, shuffle=True, capacity=16, name=shard_queue_name) 98 | min_queue_examples = values_per_shard * input_queue_capacity_factor 99 | capacity = min_queue_examples + 100 * batch_size 100 | values_queue = tf.RandomShuffleQueue( 101 | capacity=capacity, 102 | min_after_dequeue=min_queue_examples, 103 | dtypes=[tf.string], 104 | name="random_" + value_queue_name) 105 | else: 106 | filename_queue = tf.train.string_input_producer( 107 | data_files, shuffle=False, capacity=1, name=shard_queue_name) 108 | capacity = values_per_shard + 3 * batch_size 109 | values_queue = tf.FIFOQueue( 110 | capacity=capacity, dtypes=[tf.string], name="fifo_" + value_queue_name) 111 | 112 | enqueue_ops = [] 113 | for _ in range(num_reader_threads): 114 | _, value = reader.read(filename_queue) 115 | enqueue_ops.append(values_queue.enqueue([value])) 116 | tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner( 117 | values_queue, enqueue_ops)) 118 | tf.summary.scalar( 119 | "queue/%s/fraction_of_%d_full" % (values_queue.name, capacity), 120 | tf.cast(values_queue.size(), tf.float32) * (1. 
/ capacity)) 121 | 122 | return values_queue 123 | 124 | 125 | def batch_with_dynamic_pad(images_and_captions, 126 | batch_size, 127 | queue_capacity, 128 | add_summaries=True): 129 | """Batches input images and captions. 130 | 131 | This function splits the caption into an input sequence and a target sequence, 132 | where the target sequence is the input sequence right-shifted by 1. Input and 133 | target sequences are batched and padded up to the maximum length of sequences 134 | in the batch. A mask is created to distinguish real words from padding words. 135 | 136 | Example: 137 | Actual captions in the batch ('-' denotes padded character): 138 | [ 139 | [ 1 2 3 4 5 ], 140 | [ 1 2 3 4 - ], 141 | [ 1 2 3 - - ], 142 | ] 143 | 144 | input_seqs: 145 | [ 146 | [ 1 2 3 4 ], 147 | [ 1 2 3 - ], 148 | [ 1 2 - - ], 149 | ] 150 | 151 | target_seqs: 152 | [ 153 | [ 2 3 4 5 ], 154 | [ 2 3 4 - ], 155 | [ 2 3 - - ], 156 | ] 157 | 158 | mask: 159 | [ 160 | [ 1 1 1 1 ], 161 | [ 1 1 1 0 ], 162 | [ 1 1 0 0 ], 163 | ] 164 | 165 | Args: 166 | images_and_captions: A list of pairs [image, caption], where image is a 167 | Tensor of shape [height, width, channels] and caption is a 1-D Tensor of 168 | any length. Each pair will be processed and added to the queue in a 169 | separate thread. 170 | batch_size: Batch size. 171 | queue_capacity: Queue capacity. 172 | add_summaries: If true, add caption length summaries. 173 | 174 | Returns: 175 | images: A Tensor of shape [batch_size, height, width, channels]. 176 | input_seqs: An int32 Tensor of shape [batch_size, padded_length]. 177 | target_seqs: An int32 Tensor of shape [batch_size, padded_length]. 178 | mask: An int32 0/1 Tensor of shape [batch_size, padded_length]. 179 | """ 180 | enqueue_list = [] 181 | for image, caption in images_and_captions: 182 | caption_length = tf.shape(caption)[0] 183 | input_length = tf.expand_dims(tf.subtract(caption_length, 1), 0) 184 | 185 | input_seq = tf.slice(caption, [0], input_length) 186 | target_seq = tf.slice(caption, [1], input_length) 187 | indicator = tf.ones(input_length, dtype=tf.int32) 188 | enqueue_list.append([image, input_seq, target_seq, indicator]) 189 | 190 | images, input_seqs, target_seqs, mask = tf.train.batch_join( 191 | enqueue_list, 192 | batch_size=batch_size, 193 | capacity=queue_capacity, 194 | dynamic_pad=True, 195 | name="batch_and_pad") 196 | 197 | if add_summaries: 198 | lengths = tf.add(tf.reduce_sum(mask, 1), 1) 199 | tf.summary.scalar("caption_length/batch_min", tf.reduce_min(lengths)) 200 | tf.summary.scalar("caption_length/batch_max", tf.reduce_max(lengths)) 201 | tf.summary.scalar("caption_length/batch_mean", tf.reduce_mean(lengths)) 202 | 203 | return images, input_seqs, target_seqs, mask 204 | -------------------------------------------------------------------------------- /core/show_and_tell_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Image-to-text implementation based on http://arxiv.org/abs/1411.4555. 17 | 18 | "Show and Tell: A Neural Image Caption Generator" 19 | Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import tensorflow as tf 27 | import tf_slim as slim 28 | 29 | from core.ops import image_embedding 30 | from core.ops import image_processing 31 | from core.ops import inputs as input_ops 32 | 33 | 34 | class ShowAndTellModel(object): 35 | """Image-to-text implementation based on http://arxiv.org/abs/1411.4555. 36 | 37 | "Show and Tell: A Neural Image Caption Generator" 38 | Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan 39 | """ 40 | 41 | def __init__(self, config, mode, train_inception=False): 42 | """Basic setup. 43 | 44 | Args: 45 | config: Object containing configuration parameters. 46 | mode: "train", "eval" or "inference". 47 | train_inception: Whether the inception submodel variables are trainable. 48 | """ 49 | if mode not in ["train", "eval", "inference"]: 50 | raise ValueError("Variable 'mode' is invalid: %s" % mode) 51 | self.config = config 52 | self.mode = mode 53 | self.train_inception = train_inception 54 | 55 | # Reader for the input data. 56 | self.reader = tf.compat.v1.TFRecordReader() 57 | 58 | # To match the "Show and Tell" paper we initialize all variables with a 59 | # random uniform initializer. 60 | self.initializer = tf.random_uniform_initializer( 61 | minval=-self.config.initializer_scale, 62 | maxval=self.config.initializer_scale) 63 | 64 | # A float32 Tensor with shape [batch_size, height, width, channels]. 65 | self.images = None 66 | 67 | # An int32 Tensor with shape [batch_size, padded_length]. 68 | self.input_seqs = None 69 | 70 | # An int32 Tensor with shape [batch_size, padded_length]. 71 | self.target_seqs = None 72 | 73 | # An int32 0/1 Tensor with shape [batch_size, padded_length]. 74 | self.input_mask = None 75 | 76 | # A float32 Tensor with shape [batch_size, embedding_size]. 77 | self.image_embeddings = None 78 | 79 | # A float32 Tensor with shape [batch_size, padded_length, embedding_size]. 80 | self.seq_embeddings = None 81 | 82 | # A float32 scalar Tensor; the total loss for the trainer to optimize. 83 | self.total_loss = None 84 | 85 | # A float32 Tensor with shape [batch_size * padded_length]. 86 | self.target_cross_entropy_losses = None 87 | 88 | # A float32 Tensor with shape [batch_size * padded_length]. 89 | self.target_cross_entropy_loss_weights = None 90 | 91 | # Collection of variables from the inception submodel. 92 | self.inception_variables = [] 93 | 94 | # Function to restore the inception submodel from checkpoint. 95 | self.init_fn = None 96 | 97 | # Global step Tensor. 98 | self.global_step = None 99 | 100 | def is_training(self): 101 | """Returns true if the model is built for training mode.""" 102 | return self.mode == "train" 103 | 104 | def process_image(self, encoded_image, thread_id=0): 105 | """Decodes and processes an image string. 106 | 107 | Args: 108 | encoded_image: A scalar string Tensor; the encoded image. 109 | thread_id: Preprocessing thread id used to select the ordering of color 110 | distortions. 
111 | 112 | Returns: 113 | A float32 Tensor of shape [height, width, 3]; the processed image. 114 | """ 115 | return image_processing.process_image(encoded_image, 116 | is_training=self.is_training(), 117 | height=self.config.image_height, 118 | width=self.config.image_width, 119 | thread_id=thread_id, 120 | image_format=self.config.image_format) 121 | 122 | def build_inputs(self): 123 | """Input prefetching, preprocessing and batching. 124 | 125 | Outputs: 126 | self.images 127 | self.input_seqs 128 | self.target_seqs (training and eval only) 129 | self.input_mask (training and eval only) 130 | """ 131 | if self.mode == "inference": 132 | # In inference mode, images and inputs are fed via placeholders. 133 | image_feed = tf.compat.v1.placeholder(dtype=tf.string, shape=[], name="image_feed") 134 | input_feed = tf.compat.v1.placeholder(dtype=tf.int64, 135 | shape=[None], # batch_size 136 | name="input_feed") 137 | 138 | # Process image and insert batch dimensions. 139 | images = tf.expand_dims(self.process_image(image_feed), 0) 140 | input_seqs = tf.expand_dims(input_feed, 1) 141 | 142 | # No target sequences or input mask in inference mode. 143 | target_seqs = None 144 | input_mask = None 145 | else: 146 | # Prefetch serialized SequenceExample protos. 147 | input_queue = input_ops.prefetch_input_data( 148 | self.reader, 149 | self.config.input_file_pattern, 150 | is_training=self.is_training(), 151 | batch_size=self.config.batch_size, 152 | values_per_shard=self.config.values_per_input_shard, 153 | input_queue_capacity_factor=self.config.input_queue_capacity_factor, 154 | num_reader_threads=self.config.num_input_reader_threads) 155 | 156 | # Image processing and random distortion. Split across multiple threads 157 | # with each thread applying a slightly different distortion. 158 | if self.config.num_preprocess_threads % 2 != 0: 159 | raise ValueError("'self.config.num_preprocess_threads' is invalid: %s" % self.config.num_preprocess_threads) 160 | images_and_captions = [] 161 | for thread_id in range(self.config.num_preprocess_threads): 162 | serialized_sequence_example = input_queue.dequeue() 163 | encoded_image, caption = input_ops.parse_sequence_example( 164 | serialized_sequence_example, 165 | image_feature=self.config.image_feature_name, 166 | caption_feature=self.config.caption_feature_name) 167 | image = self.process_image(encoded_image, thread_id=thread_id) 168 | images_and_captions.append([image, caption]) 169 | 170 | # Batch inputs. 171 | queue_capacity = (2 * self.config.num_preprocess_threads * 172 | self.config.batch_size) 173 | images, input_seqs, target_seqs, input_mask = ( 174 | input_ops.batch_with_dynamic_pad(images_and_captions, 175 | batch_size=self.config.batch_size, 176 | queue_capacity=queue_capacity)) 177 | 178 | self.images = images 179 | self.input_seqs = input_seqs 180 | self.target_seqs = target_seqs 181 | self.input_mask = input_mask 182 | 183 | def build_image_embeddings(self): 184 | """Builds the image model subgraph and generates image embeddings. 185 | 186 | Inputs: 187 | self.images 188 | 189 | Outputs: 190 | self.image_embeddings 191 | """ 192 | inception_output = image_embedding.inception_v3( 193 | self.images, 194 | trainable=self.train_inception, 195 | is_training=self.is_training()) 196 | self.inception_variables = tf.compat.v1.get_collection( 197 | tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope="InceptionV3") 198 | 199 | # Map inception output into embedding space. 
200 | with tf.compat.v1.variable_scope("image_embedding") as scope: 201 | image_embeddings = slim.layers.fully_connected( 202 | inputs=inception_output, 203 | num_outputs=self.config.embedding_size, 204 | activation_fn=None, 205 | weights_initializer=self.initializer, 206 | biases_initializer=None, 207 | scope=scope) 208 | 209 | # Save the embedding size in the graph. 210 | tf.constant(self.config.embedding_size, name="embedding_size") 211 | 212 | self.image_embeddings = image_embeddings 213 | 214 | def build_seq_embeddings(self): 215 | """Builds the input sequence embeddings. 216 | 217 | Inputs: 218 | self.input_seqs 219 | 220 | Outputs: 221 | self.seq_embeddings 222 | """ 223 | with tf.compat.v1.variable_scope("seq_embedding"), tf.device("/cpu:0"): 224 | embedding_map = tf.compat.v1.get_variable( 225 | name="map", 226 | shape=[self.config.vocab_size, self.config.embedding_size], 227 | initializer=self.initializer) 228 | seq_embeddings = tf.nn.embedding_lookup(embedding_map, self.input_seqs) 229 | 230 | self.seq_embeddings = seq_embeddings 231 | 232 | def build_model(self): 233 | """Builds the model. 234 | 235 | Inputs: 236 | self.image_embeddings 237 | self.seq_embeddings 238 | self.target_seqs (training and eval only) 239 | self.input_mask (training and eval only) 240 | 241 | Outputs: 242 | self.total_loss (training and eval only) 243 | self.target_cross_entropy_losses (training and eval only) 244 | self.target_cross_entropy_loss_weights (training and eval only) 245 | """ 246 | # This LSTM cell has biases and outputs tanh(new_c) * sigmoid(o), but the 247 | # modified LSTM in the "Show and Tell" paper has no biases and outputs 248 | # new_c * sigmoid(o). 249 | lstm_cell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell( 250 | num_units=self.config.num_lstm_units, state_is_tuple=True) 251 | if self.mode == "train": 252 | lstm_cell = tf.contrib.rnn.DropoutWrapper( 253 | lstm_cell, 254 | input_keep_prob=self.config.lstm_dropout_keep_prob, 255 | output_keep_prob=self.config.lstm_dropout_keep_prob) 256 | 257 | with tf.compat.v1.variable_scope("lstm", initializer=self.initializer) as lstm_scope: 258 | # Feed the image embeddings to set the initial LSTM state. 259 | zero_state = lstm_cell.zero_state( 260 | batch_size=self.image_embeddings.get_shape()[0], dtype=tf.float32) 261 | _, initial_state = lstm_cell(self.image_embeddings, zero_state) 262 | 263 | # Allow the LSTM variables to be reused. 264 | lstm_scope.reuse_variables() 265 | 266 | if self.mode == "inference": 267 | # In inference mode, use concatenated states for convenient feeding and 268 | # fetching. 269 | tf.concat(axis=1, values=initial_state, name="initial_state") 270 | 271 | # Placeholder for feeding a batch of concatenated states. 272 | state_feed = tf.compat.v1.placeholder(dtype=tf.float32, 273 | shape=[None, sum(lstm_cell.state_size)], 274 | name="state_feed") 275 | state_tuple = tf.split(value=state_feed, num_or_size_splits=2, axis=1) 276 | 277 | # Run a single LSTM step. 278 | lstm_outputs, state_tuple = lstm_cell( 279 | inputs=tf.squeeze(self.seq_embeddings, axis=[1]), 280 | state=state_tuple) 281 | 282 | # Concatentate the resulting state. 283 | tf.concat(axis=1, values=state_tuple, name="state") 284 | else: 285 | # Run the batch of sequence embeddings through the LSTM. 
286 | sequence_length = tf.reduce_sum(self.input_mask, 1) 287 | lstm_outputs, _ = tf.nn.dynamic_rnn(cell=lstm_cell, 288 | inputs=self.seq_embeddings, 289 | sequence_length=sequence_length, 290 | initial_state=initial_state, 291 | dtype=tf.float32, 292 | scope=lstm_scope) 293 | 294 | # Stack batches vertically. 295 | lstm_outputs = tf.reshape(lstm_outputs, [-1, lstm_cell.output_size]) 296 | 297 | with tf.compat.v1.variable_scope("logits") as logits_scope: 298 | logits = slim.layers.fully_connected( 299 | inputs=lstm_outputs, 300 | num_outputs=self.config.vocab_size, 301 | activation_fn=None, 302 | weights_initializer=self.initializer, 303 | scope=logits_scope) 304 | 305 | if self.mode == "inference": 306 | tf.nn.softmax(logits, name="softmax") 307 | else: 308 | targets = tf.reshape(self.target_seqs, [-1]) 309 | weights = tf.to_float(tf.reshape(self.input_mask, [-1])) 310 | 311 | # Compute losses. 312 | losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, 313 | logits=logits) 314 | batch_loss = tf.div(tf.reduce_sum(tf.multiply(losses, weights)), 315 | tf.reduce_sum(weights), 316 | name="batch_loss") 317 | tf.losses.add_loss(batch_loss) 318 | total_loss = tf.losses.get_total_loss() 319 | 320 | # Add summaries. 321 | tf.summary.scalar("losses/batch_loss", batch_loss) 322 | tf.summary.scalar("losses/total_loss", total_loss) 323 | for var in tf.trainable_variables(): 324 | tf.summary.histogram("parameters/" + var.op.name, var) 325 | 326 | self.total_loss = total_loss 327 | self.target_cross_entropy_losses = losses # Used in evaluation. 328 | self.target_cross_entropy_loss_weights = weights # Used in evaluation. 329 | 330 | def setup_inception_initializer(self): 331 | """Sets up the function to restore inception variables from checkpoint.""" 332 | if self.mode != "inference": 333 | # Restore inception variables only. 
334 | saver = tf.train.Saver(self.inception_variables) 335 | 336 | def restore_fn(sess): 337 | tf.logging.info("Restoring Inception variables from checkpoint file %s", 338 | self.config.inception_checkpoint_file) 339 | saver.restore(sess, self.config.inception_checkpoint_file) 340 | 341 | self.init_fn = restore_fn 342 | 343 | def setup_global_step(self): 344 | """Sets up the global step Tensor.""" 345 | global_step = tf.compat.v1.Variable( 346 | initial_value=0, 347 | name="global_step", 348 | trainable=False, 349 | collections=[tf.compat.v1.GraphKeys.GLOBAL_STEP, tf.compat.v1.GraphKeys.GLOBAL_VARIABLES]) 350 | 351 | self.global_step = global_step 352 | 353 | def build(self): 354 | """Creates all ops for training and evaluation.""" 355 | self.build_inputs() 356 | self.build_image_embeddings() 357 | self.build_seq_embeddings() 358 | self.build_model() 359 | self.setup_inception_initializer() 360 | self.setup_global_step() 361 | -------------------------------------------------------------------------------- /docs/deploy-max-to-ibm-cloud-with-kubernetes-button.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/a2286854a6d8bc2d3de8d698f44fdf835e24d110/docs/deploy-max-to-ibm-cloud-with-kubernetes-button.png -------------------------------------------------------------------------------- /docs/swagger-screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/a2286854a6d8bc2d3de8d698f44fdf835e24d110/docs/swagger-screenshot.png -------------------------------------------------------------------------------- /max-image-caption-generator.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | apiVersion: v1 18 | kind: Service 19 | metadata: 20 | name: max-image-caption-generator 21 | spec: 22 | selector: 23 | app: max-image-caption-generator 24 | ports: 25 | - port: 5000 26 | type: NodePort 27 | --- 28 | apiVersion: apps/v1 29 | kind: Deployment 30 | metadata: 31 | name: max-image-caption-generator 32 | labels: 33 | app: max-image-caption-generator 34 | spec: 35 | selector: 36 | matchLabels: 37 | app: max-image-caption-generator 38 | replicas: 1 39 | template: 40 | metadata: 41 | labels: 42 | app: max-image-caption-generator 43 | spec: 44 | containers: 45 | - name: max-image-caption-generator 46 | image: quay.io/codait/max-image-caption-generator:latest 47 | ports: 48 | - containerPort: 5000 49 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | pytest==6.1.2 2 | requests==2.25.0 3 | flake8==3.8.4 4 | bandit==1.6.2 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==2.6.0 2 | tf_slim==1.1.0 3 | Pillow==8.3.1 4 | # NumPy version is dictated by tensorflow 5 | -------------------------------------------------------------------------------- /samples/README.md: -------------------------------------------------------------------------------- 1 | # Sample Details 2 | 3 | ## Images 4 | 5 | All test images are from [Pexels](https://www.pexels.com) and licensed under the [CC0 License](https://creativecommons.org/publicdomain/zero/1.0/). 6 | 7 | * [`plane.jpg`](https://www.pexels.com/photo/flight-sky-clouds-aircraft-8394/) 8 | * [`soccer.jpeg`](https://www.pexels.com/photo/action-athletes-ball-blur-274422/) 9 | * [`surfing.jpeg`](https://www.pexels.com/photo/action-beach-fun-leisure-416676/) -------------------------------------------------------------------------------- /samples/plane.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/a2286854a6d8bc2d3de8d698f44fdf835e24d110/samples/plane.jpg -------------------------------------------------------------------------------- /samples/soccer.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/a2286854a6d8bc2d3de8d698f44fdf835e24d110/samples/soccer.jpg -------------------------------------------------------------------------------- /samples/surfing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/a2286854a6d8bc2d3de8d698f44fdf835e24d110/samples/surfing.jpg -------------------------------------------------------------------------------- /sha512sums.txt: -------------------------------------------------------------------------------- 1 | dd6bcd865046b826bfb9559253c8114aa5db0b07ec6c329d25df712a7a2aaab019d5e82f3b7807dd162ad034ce8e6a0bdd3c54ab38868172cdd1a881faa030d1 assets/checkpoint/checkpoint 2 | eda91f4245224b43dd6c36f16dd49c30bb65d0aea51ca84a13a710bd1b18454e590cd2ff7822afe9e8501a7fa1c0ff9ec98b3dca55ed8415436d95e0f1a48c52 assets/checkpoint/model2.ckpt-2000000.data-00000-of-00001 3 | 39293d83b5d1988537e44810027bb32f7d5f09148da086f4dde1ce1f18f8a53cc8a67cb415e3ae151d618f15975af796dec2cfbf769bb36168ced81494f90dee 
assets/checkpoint/model2.ckpt-2000000.index 4 | acc365f7f48a67a0eee1815dd64bbdad960a7d371cee10e99b295510cae95932567e6af4faafcfb01092ff63c205d9ccf5702fbbc2d2c9c7f59445f5f098e60d assets/checkpoint/model2.ckpt-2000000.meta 5 | -------------------------------------------------------------------------------- /tests/surfing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/a2286854a6d8bc2d3de8d698f44fdf835e24d110/tests/surfing.jpg -------------------------------------------------------------------------------- /tests/surfing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/MAX-Image-Caption-Generator/a2286854a6d8bc2d3de8d698f44fdf835e24d110/tests/surfing.png -------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018-2019 IBM Corp. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import pytest 18 | import requests 19 | 20 | 21 | def test_swagger(): 22 | 23 | model_endpoint = 'http://localhost:5000/swagger.json' 24 | 25 | r = requests.get(url=model_endpoint) 26 | assert r.status_code == 200 27 | assert r.headers['Content-Type'] == 'application/json' 28 | 29 | json = r.json() 30 | assert 'swagger' in json 31 | assert json.get('info') and json.get('info').get('title') == 'MAX Image Caption Generator' 32 | 33 | 34 | def test_metadata(): 35 | 36 | model_endpoint = 'http://localhost:5000/model/metadata' 37 | 38 | r = requests.get(url=model_endpoint) 39 | assert r.status_code == 200 40 | 41 | metadata = r.json() 42 | assert metadata['id'] == 'max-image-caption-generator' 43 | assert metadata['name'] == 'MAX Image Caption Generator' 44 | assert metadata['description'] == 'im2txt TensorFlow model trained on MSCOCO' 45 | assert metadata['license'] == 'Apache 2.0' 46 | assert metadata['type'] == 'Image-to-Text Translation' 47 | assert 'max-image-caption-generator' in metadata['source'] 48 | 49 | 50 | def _check_response(r): 51 | caption_text = 'a man riding a wave on top of a surfboard .' 52 | assert r.status_code == 200 53 | response = r.json() 54 | assert response['status'] == 'ok' 55 | assert response['predictions'][0]['caption'] == caption_text 56 | 57 | 58 | def test_predict(): 59 | model_endpoint = 'http://localhost:5000/model/predict' 60 | formats = ['jpg', 'png'] 61 | file_path = 'tests/surfing.{}' 62 | 63 | for f in formats: 64 | p = file_path.format(f) 65 | with open(p, 'rb') as file: 66 | file_form = {'image': (p, file, 'image/{}'.format(f))} 67 | r = requests.post(url=model_endpoint, files=file_form) 68 | _check_response(r) 69 | 70 | 71 | if __name__ == '__main__': 72 | pytest.main([__file__]) 73 | --------------------------------------------------------------------------------
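Example: querying the running microservice (an illustrative sketch, not a file in this repository). The endpoint URL, multipart field name, and response fields mirror tests/test.py; the host and port assume a local instance listening on port 5000, such as the container described in max-image-caption-generator.yaml.

import requests

MODEL_ENDPOINT = 'http://localhost:5000/model/predict'

# Send a sample image as multipart form data, the same way tests/test.py does.
with open('samples/surfing.jpg', 'rb') as image_file:
    file_form = {'image': ('surfing.jpg', image_file, 'image/jpeg')}
    r = requests.post(url=MODEL_ENDPOINT, files=file_form)

r.raise_for_status()
response = r.json()
assert response['status'] == 'ok'
for prediction in response['predictions']:
    # Each prediction carries at least a 'caption' string
    # (see _check_response in tests/test.py).
    print(prediction['caption'])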
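Example: running the model in-process (a minimal sketch, assuming the checkpoint and vocabulary assets referenced in config.py and sha512sums.txt are present at their expected paths). It drives the ModelWrapper._predict() method shown in core/model.py directly; in the service itself this method is invoked through the maxfw framework rather than called by hand.

from core.model import ModelWrapper

# Builds the inference graph and restores the checkpoint from DEFAULT_MODEL_PATH.
wrapper = ModelWrapper()

# Read the raw encoded image bytes; feed_image() expects the encoded string.
with open('samples/plane.jpg', 'rb') as f:
    image_data = f.read()

# _predict() runs beam search over the image and returns
# (index, caption, probability) tuples.
for index, caption, probability in wrapper._predict(image_data):
    print('%d) %s (p=%f)' % (index, caption, probability))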