├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md └── mmt_release.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | # Pull Requests 4 | 5 | Please send in fixes or feature additions through Pull Requests. 6 | 7 | ## Contributor License Agreement 8 | 9 | Contributions to this project must be accompanied by a Contributor License 10 | Agreement. You (or your employer) retain the copyright to your contribution, 11 | this simply gives us permission to use and redistribute your contributions as 12 | part of the project. Head over to to see 13 | your current agreements on file or to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multimodal Transformers 2 | 3 | This code runs inference with the multimodal transformer models described in "Decoupling the Role of Data, 4 | Attention, and Losses in Multimodal Transformers". Our models can be used to 5 | score if an image-text pair match. Please see our paper for more details. 
6 | This code release consists of a colab to extract image and language features 7 | and input them into our transformer models. Transformer models are stored on 8 | tfhub. 9 | 10 | 11 | Please see the tables below for details of models which we have released via tfhub. 12 | 13 | Name | Training Dataset | ITM | MRM | MLM | Heads | Layers | Att. Type | FineTuned | Notes 14 | ------------------------------------ | ----------------------------------- | -------------- | --- | --- | ----- | ------ | --------------------- | --------- | ----- 15 | data_cc (base) | Conceptual Captions | Classification | Y | Y | 12 | 6 | Merged | N | 16 | data_sbu | SBU | Classification | Y | Y | 12 | 6 | Merged | N | 17 | data_vg | Visual Genome | Classification | Y | Y | 12 | 6 | Merged | N | 18 | data_mscoco | MSCOCO | Classification | Y | Y | 12 | 6 | Merged | N | 19 | data_mscoco-narratives | MSCOCO Narratives | Classification | Y | Y | 12 | 6 | Merged | N | 20 | data_oi-narratives | OI Narratives | Classification | Y | Y | 12 | 6 | Merged | N | 21 | data_combined-instance | All (instance sampling) | Classification | Y | Y | 12 | 6 | Merged | N | 22 | data_combined-dataset | All (dataset sampling) | Classification | Y | Y | 12 | 6 | Merged | N | 23 | data_uniter-instance | Uniter datasets (instance sampling) | Classification | Y | Y | 12 | 6 | Merged | N | 24 | data_uniter-dataset | Uniter datasets (dataset sampling) | Classification | Y | Y | 12 | 6 | Merged | N | 25 | data_cc-with-bert | Conceptual Captions | Classification | Y | Y | 12 | 6 | Merged | N | Language initialised with BERT 26 | loss_itm_mrm | Conceptual Captions | Classification | Y | N | 12 | 6 | Merged | N | 27 | loss_itm_mlm | Conceptual Captions | Classification | N | Y | 12 | 6 | Merged | N | 28 | loss_single-modality-contrastive32 | Conceptual Captions | Contrastive | Y | Y | 12 | 6 | Sing. Modality | N | 29 | loss_single-modality-contrastive1024 | Conceptual Captions | Contrastive | Y | Y | 12 | 6 | Sing. Modality | N | 30 | loss_v1-contrastive32 | Conceptual Captions | Contrastive | Y | Y | 12 | 1 | Merged | N | 31 | architecture_heads1-768 | Conceptual Captions | Classification | Y | Y | 1 | 6 | Merged | N | 32 | architecture_heads3-256 | Conceptual Captions | Classification | Y | Y | 3 | 6 | Merged | N | 33 | architecture_heads6-64 | Conceptual Captions | Classification | Y | Y | 6 | 6 | Merged | N | 34 | architecture_heads18-64 | Conceptual Captions | Classification | Y | Y | 18 | 6 | Merged | N | 35 | architecture_vilbert-1block | Conceptual Captions | Classification | Y | Y | 12 | 1 | Merged | N | 36 | architecture_vilbert-2block | Conceptual Captions | Classification | Y | Y | 12 | 2 | Merged | N | 37 | architecture_vilbert-4block | Conceptual Captions | Classification | Y | Y | 12 | 4 | Merged | N | 38 | architecture_vilbert-12block | Conceptual Captions | Classification | Y | Y | 12 | 12 | Merged | N | 39 | architecture_single-modality | Conceptual Captions | Classification | Y | Y | 12 | 6 | Sing. 
Modality | N | 40 | architecture_mixed-modality | Conceptual Captions | Classification | Y | Y | 12 | 6 | Mix Modality | N | 5 single modality layers and 1 merged layer 41 | architecture_single-stream | Conceptual Captions | Classification | Y | Y | 12 | 6 | Single Stream | N | 42 | architecture_language-q-12 | Conceptual Captions | Classification | Y | Y | 12 | 6 | Asymmetric (language) | N | 43 | architecture_image-q-12 | Conceptual Captions | Classification | Y | Y | 12 | 6 | Asymmetric (image) | N | 44 | architecture_language-q-24 | Conceptual Captions | Classification | Y | Y | 24 | 6 | Asymmetric (language) | N | 45 | architecture_image-q-24 | Conceptual Captions | Classification | Y | Y | 24 | 6 | Asymmetric (image) | N | 46 | architecture_single-modality-hloss | Conceptual Captions | Classification | Y | Y | 12 | 6 | Single modality | N | Includes ITM loss after every layer 47 | data-ft_sbu | SBU | Classification | Y | Y | 12 | 6 | Merged | Y | 48 | data-ft_vg | Visual Genome | Classification | Y | Y | 12 | 6 | Merged | Y | 49 | data-ft_mscoco | MSCOCO | Classification | Y | Y | 12 | 6 | Merged | Y | 50 | data-ft_mscoco-narratives | MSCOCO Narratives | Classification | Y | Y | 12 | 6 | Merged | Y | 51 | data-ft_oi-narratives | OI Narratives | Classification | Y | Y | 12 | 6 | Merged | Y | 52 | data-ft_cc | Conceptual Captions | Classification | Y | Y | 12 | 6 | Merged | Y | 53 | data-ft_combined-instance | All (instance sampling) | Classification | Y | Y | 12 | 6 | Merged | Y | 54 | data-ft_combined-dataset | All (dataset sampling) | Classification | Y | Y | 12 | 6 | Merged | Y | 55 | data-ft_uniter-instance | Uniter datasets (instance sampling) | Classification | Y | Y | 12 | 6 | Merged | Y | 56 | data-ft_uniter-dataset | Uniter datasets (dataset sampling) | Classification | Y | Y | 12 | 6 | Merged | Y | 57 | architecture-ft_single-modality | Conceptual Captions | Classification | Y | Y | 12 | 6 | Sing. Modality | Y | 58 | architecture-ft_single-stream | Conceptual Captions | Classification | Y | Y | 12 | 6 | Single Stream | Y | 59 | architecture-ft_language-q-12 | Conceptual Captions | Classification | Y | Y | 12 | 6 | Asymmetric (language) | Y | 60 | architecture-ft_image-q-12 | Conceptual Captions | Classification | Y | Y | 12 | 6 | Asymmetric (image) | Y | 61 | architecture-ft_language-q-24 | Conceptual Captions | Classification | Y | Y | 24 | 6 | Asymmetric (language) | Y | 62 | architecture-ft_image-q-24 | Conceptual Captions | Classification | Y | Y | 24 | 6 | Asymmetric (image) | Y | 63 | 64 | In addition to our transformer models, we also release our baseline models. See details of our baseline models in the chart below: 65 | 66 | | Name | ITM | Bert Initialisation | FineTuned | 67 | |---------------------------------------|----------------|---------------------|-----------| 68 | | baseline_baseline | Contrastive | Yes | N | 69 | | baseline_baseline-cls | Classification | No | N | 70 | | baseline_baseline-no-bert-transfer | Contrastive | No | N | 71 | | baseline-ft_baseline | Contrastive | Yes | Y | 72 | | baseline-ft_baseline-cls | Classification | No | Y | 73 | | baseline-ft_baseline-no-bert-transfer | Contrastive | No | 74 | 75 | ## Installation 76 | 77 | You do not need to install anything! You should be able to run all code 78 | from our released colab. 79 | 80 | ## Usage 81 | 82 | You can run an image and text pair through our module and see if the image and 83 | text pair match. 
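For orientation, the pieces fit together as in the following minimal sketch. It assumes `features` is the dictionary of pre-extracted image and text features produced by the preprocessing steps in our released colab (the expected keys are listed further below), and any model name from the tables above can be substituted into the TF Hub URL:

```python
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub

# Any model name from the tables above can be substituted into the URL.
model = hub.load('https://tfhub.dev/deepmind/mmt/data_cc/1')

# `features` is assumed to be the dictionary of pre-extracted image and text
# features produced by the released colab; tf.expand_dims adds a batch dimension.
inputs = {
    'image/bboxes': tf.expand_dims(features['image/bboxes'], 0),
    'image/detection_features': tf.expand_dims(features['image/detection_features'], 0),
    'image/padding_mask': tf.expand_dims(features['image/padding_mask'], 0),
    'masked_tokens': tf.expand_dims(features['text/token_ids'], 0),
    'text/token_ids': tf.expand_dims(features['text/token_ids'], 0),
    'text/segment_ids': tf.expand_dims(features['text/segment_ids'], 0),
    'text/padding_mask': tf.expand_dims(features['text/padding_mask'], 0),
}

output = model.signatures['default'](**inputs)
score = tf.nn.softmax(output['output']).numpy()[0]
print('match score:', score)
```

The model-loading and inference steps are shown individually below.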
84 | 85 | ```python 86 | import tensorflow.compat.v1 as tf 87 | import tensorflow_hub as hub 88 | model = hub.load('https://tfhub.dev/deepmind/mmt/architecture-ft_image-q-12/1') 89 | ``` 90 | 91 | Inference: 92 | 93 | ```python 94 | output = model.signatures['default'](**inputs) 95 | score = tf.nn.softmax(output['output']).numpy()[0] 96 | ``` 97 | 98 | where `score` indicates whether the image and text match (`1` indicates a perfect 99 | match). `inputs` is a dictionary with the following keys: 100 | 101 | * `image/bboxes`: Coordinates of detected image bounding boxes. 102 | 103 | * `image/detection_features`: Features from the image detector. 104 | 105 | * `image/padding_mask`: Indicates whether image features are padded. 106 | 107 | * `masked_tokens`: Text tokens. 108 | 109 | * `text/segment_ids`: Indicates the sentence segment. (Since we train with one sentence this will always be 0.) 110 | 111 | * `text/token_ids`: Indicates which word each token belongs to. (We use a tokenizer which can break one word into multiple tokens.) 112 | 113 | * `text/padding_mask`: Indicates whether text features are padded. 114 | 115 | Please see our linked colab for details on pre-processing. 116 | You will need to use the detector released in our colab for good results. 117 | 118 | ## Citing this work 119 | 120 | If you use this model in your research please cite: 121 | 122 | [1] Lisa Anne Hendricks, John Mellor, Rosalia Schneider, Jean-Baptiste Alayrac, 123 | and Aida Nematzadeh. 124 | [Decoupling the Role of Data, Attention, and Losses 125 | in Multimodal Transformers](https://arxiv.org/pdf/2102.00529.pdf), 126 | TACL 2021. 127 | 128 | ## Disclaimer 129 | 130 | This is not an official Google product. 131 | -------------------------------------------------------------------------------- /mmt_release.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "y117dIimYVuW" 7 | }, 8 | "source": [ 9 | "This Colab generates language and image features to be used with pretrained image--language transformers. It then allows you to use our released models to determine if an image-text pair matches!\n", 10 | "\n", 11 | "This replicates our retrieval results in our TACL 2021 paper:\n", 12 | "\n", 13 | "[Decoupling the Role of Data, Attention, and Losses in Multimodal Transformers](https://arxiv.org/abs/2102.00529)\n", 14 | "\n", 15 | "Paper Authors: Lisa Anne Hendricks, John Mellor, Rosalia Schneider, Jean-Baptiste Alayrac, and Aida Nematzadeh\n", 16 | "\n", 17 | "We also thank Sebastian Borgeaud and Cyprien de Masson d'Autume for their text preprocessing code." 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "id": "K_oXEX6F6_Rv" 24 | }, 25 | "source": [ 26 | "# Preprocessing Language and Images\n", 27 | "\n", 28 | "First, we use a detector to extract image features and SentencePiece to extract language tokens."
29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": { 35 | "id": "vmooLaXxzPS-" 36 | }, 37 | "outputs": [], 38 | "source": [ 39 | "import os\n", 40 | "import numpy as np\n", 41 | "import tensorflow.compat.v1 as tf\n", 42 | "import tensorflow_hub as hub\n", 43 | "from io import BytesIO as StringIO\n", 44 | "from PIL import Image\n", 45 | "\n", 46 | "%matplotlib inline\n", 47 | "import matplotlib.pyplot as plt\n", 48 | "import matplotlib.image as mpimg\n", 49 | "\n", 50 | "import matplotlib.image as mpimg\n", 51 | "import unicodedata" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": { 58 | "id": "kumubIq2g803" 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "# Install the tensorflow Object Detection API...\n", 63 | "# If you're running this offline, you also might need to install the protobuf-compiler:\n", 64 | "# apt-get install protobuf-compiler\n", 65 | "\n", 66 | "! git clone -n https://github.com/tensorflow/models.git\n", 67 | "%cd models\n", 68 | "!git checkout 461b3587ef38b42cda151fa3b7d37706d77e4244\n", 69 | "%cd research\n", 70 | "! protoc object_detection/protos/*.proto --python_out=.\n", 71 | "\n", 72 | "# Install TensorFlow Object Detection API\n", 73 | "%cp object_detection/packages/tf2/setup.py .\n", 74 | "! python -m pip install --upgrade pip\n", 75 | "! python -m pip install --use-feature=2020-resolver .\n", 76 | "\n", 77 | "# Test the installation\n", 78 | "! python object_detection/builders/model_builder_tf2_test.py" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": { 85 | "id": "cyCSxln6hYDj" 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "from object_detection.utils import visualization_utils as vis_util\n", 90 | "from object_detection.utils import label_map_util\n", 91 | "from object_detection.core import standard_fields as fields" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": { 98 | "id": "0OCoLQ6msv0X" 99 | }, 100 | "outputs": [], 101 | "source": [ 102 | "!wget https://storage.googleapis.com/dm-mmt-models/spiece.model -P '/tmp'" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": { 109 | "id": "3Q9VtFoEJeMY" 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "features = {} # input to our model " 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": { 119 | "id": "oiZXWgUln-nM" 120 | }, 121 | "source": [ 122 | "## Language Preprocessing" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": { 128 | "id": "c6WQsSwvoCwv" 129 | }, 130 | "source": [ 131 | "### Helper Functions for Preprocessing Text" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": { 138 | "id": "z-l7ezeGlN3o" 139 | }, 140 | "outputs": [], 141 | "source": [ 142 | "SPIECE_UNDERLINE = '▁' # pylint: disable=invalid-encoded-data\n", 143 | "\n", 144 | "special_symbols = {\n", 145 | " '\u003ccls\u003e': 3,\n", 146 | " '\u003csep\u003e': 4,\n", 147 | " '\u003cpad\u003e': 5,\n", 148 | " '\u003cmask\u003e': 6,\n", 149 | "}\n", 150 | "CLS_ID = special_symbols['\u003ccls\u003e']\n", 151 | "SEP_ID = special_symbols['\u003csep\u003e']\n", 152 | "PAD_ID = special_symbols['\u003cpad\u003e']\n", 153 | "MASK_ID = special_symbols['\u003cmask\u003e']\n", 154 | "\n", 155 | "def is_start_piece(piece):\n", 156 | " \"\"\"Returns True if the piece is a start piece for a word/symbol.\"\"\"\n", 157 | " 
special_pieces = set(list('!\"#$%\u0026\\\"()*+,-./:;?@[\\\\]^_`{|}~'))\n", 158 | " if piece.startswith(SPIECE_UNDERLINE):\n", 159 | " return True\n", 160 | " if piece.startswith('\u003c'):\n", 161 | " return True\n", 162 | " if piece in special_pieces:\n", 163 | " return True\n", 164 | " return False\n", 165 | "\n", 166 | "\n", 167 | "def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False):\n", 168 | " \"\"\"Preprocess the inputs.\"\"\"\n", 169 | " if remove_space:\n", 170 | " outputs = ' '.join(inputs.strip().split())\n", 171 | " else:\n", 172 | " outputs = inputs\n", 173 | " outputs = outputs.replace('``', '\"').replace('\\'\\'', '\"')\n", 174 | "\n", 175 | " if not keep_accents:\n", 176 | " outputs = unicodedata.normalize('NFKD', outputs)\n", 177 | " outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])\n", 178 | " if lower:\n", 179 | " outputs = outputs.lower()\n", 180 | "\n", 181 | " return outputs\n", 182 | "\n", 183 | "\n", 184 | "def encode_pieces(sp_model, text, sample=False):\n", 185 | " \"\"\"Encode the text to pieces using the given SentencePiece model sp_model.\"\"\"\n", 186 | " if not sample:\n", 187 | " pieces = sp_model.EncodeAsPieces(text)\n", 188 | " else:\n", 189 | " pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)\n", 190 | " new_pieces = []\n", 191 | " for piece in pieces:\n", 192 | " if len(piece) \u003e 1 and piece[-1] == ',' and piece[-2].isdigit():\n", 193 | " cur_pieces = sp_model.EncodeAsPieces(\n", 194 | " piece[:-1].replace(SPIECE_UNDERLINE, ''))\n", 195 | " if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:\n", 196 | " if len(cur_pieces[0]) == 1:\n", 197 | " cur_pieces = cur_pieces[1:]\n", 198 | " else:\n", 199 | " cur_pieces[0] = cur_pieces[0][1:]\n", 200 | " cur_pieces.append(piece[-1])\n", 201 | " new_pieces.extend(cur_pieces)\n", 202 | " else:\n", 203 | " new_pieces.append(piece)\n", 204 | "\n", 205 | " return new_pieces\n", 206 | "\n", 207 | "\n", 208 | "def encode_ids(sp_model, text, sample=False):\n", 209 | " pieces = encode_pieces(sp_model, text, sample=sample)\n", 210 | " ids = [sp_model.PieceToId(piece) for piece in pieces]\n", 211 | " return ids\n", 212 | "\n", 213 | "\n", 214 | "def tokens_to_word_indices(sp_model, tokens, offset=0):\n", 215 | " \"\"\"Compute the word ids for the tokens.\n", 216 | "\n", 217 | " The word indices start at offset, each time a new word is encountered, the\n", 218 | " word id is increased by 1.\n", 219 | "\n", 220 | " Args:\n", 221 | " tokens: `list` of `int` SentencePiece tokens\n", 222 | " offset: `int` start index\n", 223 | "\n", 224 | " Returns:\n", 225 | " A `list` of increasing integers. 
If element i and j are identical, then\n", 226 | " tokens[i] and tokens[j] are part of the same word.\n", 227 | " \"\"\"\n", 228 | " word_indices = []\n", 229 | " current_index = offset\n", 230 | " for i, token in enumerate(tokens):\n", 231 | " token_piece = sp_model.IdToPiece(token)\n", 232 | " if i \u003e 0 and is_start_piece(token_piece):\n", 233 | " current_index += 1\n", 234 | " word_indices.append(current_index)\n", 235 | "\n", 236 | " return word_indices" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": { 242 | "id": "um98ejShoRNB" 243 | }, 244 | "source": [ 245 | "### Load the SentencePiece Model" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": { 252 | "id": "ENAioP-BoOgZ" 253 | }, 254 | "outputs": [], 255 | "source": [ 256 | "import sentencepiece as sp\n", 257 | "spm_path = '/tmp/spiece.model'\n", 258 | "spm = sp.SentencePieceProcessor()\n", 259 | "spm.Load(spm_path)" 260 | ] 261 | }, 262 | { 263 | "cell_type": "markdown", 264 | "metadata": { 265 | "id": "yOIiWM_WoP5K" 266 | }, 267 | "source": [ 268 | "### Preprocessing Captions " 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": null, 274 | "metadata": { 275 | "id": "fA-h8RzDKRqT" 276 | }, 277 | "outputs": [], 278 | "source": [ 279 | "def create_sentence_features(seq_len, spm, captions, max_sentence_number=1):\n", 280 | " def _add_sentence_pad():\n", 281 | " pad_number = max_sentence_number - len(captions)\n", 282 | " for _ in range(pad_number):\n", 283 | " all_sents['tokens'] += [MASK_ID] * seq_len\n", 284 | " all_sents['segment_ids'] += [0] * seq_len\n", 285 | " all_sents['padding_mask'] += [1] * seq_len\n", 286 | " all_sents['word_ids'] += [-2] * seq_len\n", 287 | "\n", 288 | " # Limit the sentence length to seq_len\n", 289 | " # We concatenate all sentences after checking the seq len and adding\n", 290 | " # padding\n", 291 | " all_sents = {}\n", 292 | " for k in ['tokens', 'segment_ids', 'padding_mask', 'word_ids']:\n", 293 | " all_sents[k] = []\n", 294 | "\n", 295 | " for sentence in captions:\n", 296 | " sentence = preprocess_text(sentence, remove_space=True, lower=True, keep_accents=False)\n", 297 | "\n", 298 | " tokens = encode_ids(spm, sentence)\n", 299 | " if len(tokens) \u003e= seq_len - 2:\n", 300 | " tokens = tokens[:seq_len - 2] # since we add two symbols\n", 301 | "\n", 302 | " word_ids = tokens_to_word_indices(spm, tokens)\n", 303 | " word_ids = ([-1] + word_ids + [-1])\n", 304 | " # Need to create segment ids before adding special symbols to tokens\n", 305 | " segment_ids = ([0] + # SEP\n", 306 | " [0] * len(tokens) + [2] # CLS\n", 307 | " )\n", 308 | " tokens = ([SEP_ID] + tokens + [CLS_ID])\n", 309 | " padding_mask = [0] * len(tokens)\n", 310 | " # Note, we add padding at the start so that the last token is always [CLS]\n", 311 | "\n", 312 | " if len(tokens) \u003c seq_len:\n", 313 | " padding_len = seq_len - len(tokens)\n", 314 | " tokens = [MASK_ID] * padding_len + tokens\n", 315 | " \n", 316 | " segment_ids = [0] * padding_len + segment_ids\n", 317 | " padding_mask = [1] * padding_len + padding_mask\n", 318 | " word_ids = [-2] * padding_len + word_ids\n", 319 | "\n", 320 | "\n", 321 | " assert len(tokens) == seq_len\n", 322 | " assert len(segment_ids) == seq_len\n", 323 | " assert len(padding_mask) == seq_len\n", 324 | " assert len(word_ids) == seq_len\n", 325 | "\n", 326 | " all_sents['tokens'] += tokens\n", 327 | " all_sents['segment_ids'] += segment_ids\n", 328 | " all_sents['padding_mask'] += 
padding_mask\n", 329 | " all_sents['word_ids'] += word_ids\n", 330 | "\n", 331 | " # Add padding sentences to the end so that each example has\n", 332 | " # max_sentence_number\n", 333 | " if len(captions) \u003c max_sentence_number:\n", 334 | " _add_sentence_pad()\n", 335 | "\n", 336 | " return {\n", 337 | " 'text/token_ids': np.array(all_sents['tokens'], dtype=np.int32),\n", 338 | " 'text/segment_ids': np.array(all_sents['segment_ids'], dtype=np.int32),\n", 339 | " 'text/padding_mask': np.array(all_sents['padding_mask'], dtype=np.int32),\n", 340 | " 'text/word_ids': np.array(all_sents['word_ids'], dtype=np.int32),\n", 341 | " 'text/sentence_num': len(captions),\n", 342 | " }\n", 343 | " \n" 344 | ] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "metadata": { 349 | "id": "dlV7HY-SLEj4" 350 | }, 351 | "source": [ 352 | "Get features for an example caption." 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": null, 358 | "metadata": { 359 | "id": "VXYlQ_rANxtZ" 360 | }, 361 | "outputs": [], 362 | "source": [ 363 | "features = create_sentence_features(seq_len=25, spm=spm, captions=['A man with a backpack holding a kitten.'])\n", 364 | "print(features)" 365 | ] 366 | }, 367 | { 368 | "cell_type": "markdown", 369 | "metadata": { 370 | "id": "2_uRSZXmodsj" 371 | }, 372 | "source": [ 373 | "## Image Preprocessing" 374 | ] 375 | }, 376 | { 377 | "cell_type": "markdown", 378 | "metadata": { 379 | "id": "PczZB4Z4FvUt" 380 | }, 381 | "source": [ 382 | "###Load the Pretrained Object Detector " 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": null, 388 | "metadata": { 389 | "id": "-xSz_GEZFoH0" 390 | }, 391 | "outputs": [], 392 | "source": [ 393 | "def LoadInferenceGraph(inference_graph_path):\n", 394 | " \"\"\"Loads inference graph into tensorflow Graph object.\n", 395 | "\n", 396 | " Args:\n", 397 | " inference_graph_path: Path to inference graph.\n", 398 | "\n", 399 | " Returns:\n", 400 | " a tf.Graph object.\n", 401 | " \"\"\"\n", 402 | " od_graph = tf.Graph()\n", 403 | " with od_graph.as_default():\n", 404 | " od_graph_def = tf.GraphDef()\n", 405 | " with open(inference_graph_path, 'rb') as fid:\n", 406 | " serialized_graph = fid.read()\n", 407 | " od_graph_def.ParseFromString(serialized_graph)\n", 408 | " tf.import_graph_def(od_graph_def, name='')\n", 409 | " return od_graph" 410 | ] 411 | }, 412 | { 413 | "cell_type": "markdown", 414 | "metadata": { 415 | "id": "CoOK2n0ZLwVi" 416 | }, 417 | "source": [ 418 | "Download the pretrained object detector." 
419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": null, 424 | "metadata": { 425 | "id": "zZaA3ddA4-IW" 426 | }, 427 | "outputs": [], 428 | "source": [ 429 | "!wget --no-check-certificate https://storage.googleapis.com/dm-mmt-models/frozen_inference_graph.pb -P '/tmp'" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "metadata": { 436 | "id": "3qocil4KiP4K" 437 | }, 438 | "outputs": [], 439 | "source": [ 440 | "detection_graph = LoadInferenceGraph('/tmp/frozen_inference_graph.pb')\n", 441 | "print ('Successfully loaded frozen model from {}'.format('https://storage.googleapis.com/dm-mmt-models/frozen_inference_graph.pb'))" 442 | ] 443 | }, 444 | { 445 | "cell_type": "markdown", 446 | "metadata": { 447 | "id": "_YkafsKflGkk" 448 | }, 449 | "source": [ 450 | "### Load an Example Image" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": null, 456 | "metadata": { 457 | "id": "e0hSW3qdmHXC" 458 | }, 459 | "outputs": [], 460 | "source": [ 461 | "def LoadImageIntoNumpyArray(path):\n", 462 | "\n", 463 | " with open(path, 'rb') as img_file: \n", 464 | " img = mpimg.imread(img_file)\n", 465 | " (im_width, im_height) = img.shape[:2]\n", 466 | " return img[:,:,:3].astype(np.uint8)" 467 | ] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "execution_count": null, 472 | "metadata": { 473 | "id": "SfAsIi_z5OY4" 474 | }, 475 | "outputs": [], 476 | "source": [ 477 | "# Download the image\n", 478 | "!wget --no-check-certificate https://storage.googleapis.com/dm-mmt-models/COCO_val2014_000000570107.jpeg -P '/tmp/' \n", 479 | "image_np = LoadImageIntoNumpyArray('/tmp/COCO_val2014_000000570107.jpeg')\n", 480 | "\n", 481 | "print('image type: %s' % str(image_np.dtype))\n", 482 | "print('image shape: %s' % str(image_np.shape))\n" 483 | ] 484 | }, 485 | { 486 | "cell_type": "markdown", 487 | "metadata": { 488 | "id": "MpVO4bCPPcrt" 489 | }, 490 | "source": [ 491 | "###Preprocessing Images" 492 | ] 493 | }, 494 | { 495 | "cell_type": "markdown", 496 | "metadata": { 497 | "id": "7r0qXvmzQDmG" 498 | }, 499 | "source": [ 500 | "Loading the object-label mappings for the dectector." 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": null, 506 | "metadata": { 507 | "id": "B_Fa0dI14Z5W" 508 | }, 509 | "outputs": [], 510 | "source": [ 511 | "!wget --no-check-certificate https://storage.googleapis.com/dm-mmt-models/objatt_labelmap.txt -P '/tmp/' \n", 512 | "label_map_path = '/tmp/objatt_labelmap.txt'\n", 513 | "categories = label_map_util.create_categories_from_labelmap(label_map_path, use_display_name=True)\n", 514 | "category_index = label_map_util.create_category_index(categories)" 515 | ] 516 | }, 517 | { 518 | "cell_type": "markdown", 519 | "metadata": { 520 | "id": "RrxWiy6hPlfy" 521 | }, 522 | "source": [ 523 | "Running inference on the object detector for a single image." 
524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "execution_count": null, 529 | "metadata": { 530 | "id": "O7LeGadUzgeE" 531 | }, 532 | "outputs": [], 533 | "source": [ 534 | "def RunInferenceSingleImage(image, graph):\n", 535 | " \"\"\"Run single image through tensorflow object detection graph.\n", 536 | "\n", 537 | " This function runs an inference graph (frozen using the functions provided\n", 538 | " in this file) on a (single) provided image and returns inference results in\n", 539 | " numpy arrays.\n", 540 | "\n", 541 | " Args:\n", 542 | " image: uint8 numpy array with shape (img_height, img_width, 3)\n", 543 | " graph: tensorflow graph object holding loaded model. This graph can be\n", 544 | " obtained by running the LoadInferenceGraph function above.\n", 545 | "\n", 546 | " Returns:\n", 547 | " output_dict: a dictionary holding the following entries:\n", 548 | " `num_detections`: an integer\n", 549 | " `detection_boxes`: a numpy (float32) array of shape [N, 4]\n", 550 | " `detection_classes`: a numpy (uint8) array of shape [N]\n", 551 | " `detection_scores`: a numpy (float32) array of shape [N]\n", 552 | " `detection_masks`: a numpy (uint8) array of shape\n", 553 | " [N, image_height, image_width] with values in {0, 1}\n", 554 | " `detection_keypoints`: a numpy (float32) array of shape\n", 555 | " [N, num_keypoints, 2]\n", 556 | " \"\"\"\n", 557 | " with graph.as_default():\n", 558 | " with tf.Session() as sess:\n", 559 | " # Get handles to input and output tensors\n", 560 | " ops = tf.get_default_graph().get_operations()\n", 561 | " all_tensor_names = {output.name for op in ops for output in op.outputs}\n", 562 | " tensor_dict = {}\n", 563 | " detection_fields = fields.DetectionResultFields\n", 564 | " for key in [\n", 565 | " v for k, v in vars(detection_fields).items()\n", 566 | " if not k.startswith('__')\n", 567 | " ]:\n", 568 | " tensor_name = key + ':0'\n", 569 | " if tensor_name in all_tensor_names:\n", 570 | " tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(\n", 571 | " tensor_name)\n", 572 | " image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')\n", 573 | "\n", 574 | " # Run inference\n", 575 | " output_dict = sess.run(tensor_dict,\n", 576 | " feed_dict={image_tensor: np.expand_dims(image, 0)})\n", 577 | "\n", 578 | " # all outputs are float32 numpy arrays, so convert types as appropriate\n", 579 | " output_dict['num_detections'] = int(output_dict['num_detections'][0])\n", 580 | " output_dict['detection_classes'] = output_dict[\n", 581 | " 'detection_classes'][0].astype(np.uint8)\n", 582 | " output_dict['detection_boxes'] = output_dict['detection_boxes'][0]\n", 583 | " output_dict['detection_scores'] = output_dict['detection_scores'][0]\n", 584 | " if 'detection_masks' in output_dict:\n", 585 | " output_dict['detection_masks'] = output_dict['detection_masks'][0]\n", 586 | " if 'detection_keypoints' in output_dict:\n", 587 | " output_dict['detection_keypoints'] = output_dict['detection_keypoints'][\n", 588 | " 0]\n", 589 | " return output_dict" 590 | ] 591 | }, 592 | { 593 | "cell_type": "markdown", 594 | "metadata": { 595 | "id": "mdbzypy9MtLk" 596 | }, 597 | "source": [ 598 | "Pass an Image Through the Detector." 
599 | ] 600 | }, 601 | { 602 | "cell_type": "code", 603 | "execution_count": null, 604 | "metadata": { 605 | "id": "5e23DVGLlc6o" 606 | }, 607 | "outputs": [], 608 | "source": [ 609 | "# Run inference\n", 610 | "output_dict = RunInferenceSingleImage(image_np, detection_graph)\n", 611 | "output_dict['detection_features'].shape" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": null, 617 | "metadata": { 618 | "id": "G-e9VRV_8NI5" 619 | }, 620 | "outputs": [], 621 | "source": [ 622 | "output_dict.keys()" 623 | ] 624 | }, 625 | { 626 | "cell_type": "markdown", 627 | "metadata": { 628 | "id": "Yni8oUuyQYuS" 629 | }, 630 | "source": [ 631 | "Preprocessing the output of the detector to be readable by our models." 632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": null, 637 | "metadata": { 638 | "id": "hyoXWQYKPQ91" 639 | }, 640 | "outputs": [], 641 | "source": [ 642 | "image_seq_num = 100 \n", 643 | "image_feat = {}\n", 644 | "image_feat['height'] = image_np.shape[0]\n", 645 | "image_feat['width'] = image_np.shape[1]\n", 646 | "\n", 647 | "raw_feats = np.mean(np.mean(output_dict['detection_features'], axis=-2), axis=-2).squeeze() # image_feat['detection_features']\n", 648 | "num_detections = output_dict['num_detections']\n", 649 | "\n", 650 | "raw_scores = output_dict['detection_multiclass_scores'][:, :num_detections, ...].squeeze()\n", 651 | "\n", 652 | "# Find regions with highest class scores\n", 653 | "sorted_score_idxs = np.argsort(np.max(raw_scores[:, 1:], axis=-1))[::-1]\n", 654 | "\n", 655 | "# Collect features, boxes, and scores for highest scoring regions\n", 656 | "detection_feats = np.zeros((image_seq_num + 1, raw_feats.shape[-1]))\n", 657 | "detection_scores = np.zeros((image_seq_num + 1, raw_scores.shape[-1]))\n", 658 | "bbox_feats = np.zeros((image_seq_num + 1, 5))\n", 659 | "image_padding = np.ones((image_seq_num + 1,))\n", 660 | "padding_offset = max(image_seq_num + 1 - sorted_score_idxs.shape[0], 0)\n", 661 | "\n", 662 | "for i, index in enumerate(sorted_score_idxs[:image_seq_num]):\n", 663 | " padded_index = i + padding_offset\n", 664 | " detection_feats[padded_index, :] = raw_feats[index, :]\n", 665 | " detection_scores[padded_index, :] = raw_scores[index, :]\n", 666 | " # index 0 is 'background'\n", 667 | " bbox_feats[padded_index, :4] = output_dict['detection_boxes'][index, :]\n", 668 | " bbox_w = (output_dict['detection_boxes'][index, 3] -\n", 669 | " output_dict['detection_boxes'][index, 1]) * image_feat['width']\n", 670 | " bbox_h = (output_dict['detection_boxes'][index, 2] -\n", 671 | " output_dict['detection_boxes'][index, 0]) * image_feat['height']\n", 672 | " bbox_area = (bbox_w * bbox_h) / (image_feat['height'] * image_feat['width'])\n", 673 | " bbox_feats[padded_index, -1] = bbox_area\n", 674 | " image_padding[padded_index] = 0\n", 675 | "\n", 676 | "# Add in global image feature\n", 677 | "detection_feats[-1, :]= np.mean(detection_feats[padding_offset:-1, ...], axis=0).squeeze()\n", 678 | "bbox_feats[-1, :] = [0, 0, 1, 1, 1]\n", 679 | "image_padding[-1] = 0\n", 680 | "\n", 681 | "features.update(\n", 682 | " {'image/bboxes': bbox_feats.astype(np.float32),\n", 683 | " 'image/padding_mask': image_padding.astype(np.int32), \n", 684 | " 'image/detection_features': detection_feats.astype(np.float32),\n", 685 | " 'image/detection_scores': detection_scores.astype(np.float32)}) \n" 686 | ] 687 | }, 688 | { 689 | "cell_type": "code", 690 | "execution_count": null, 691 | "metadata": { 692 | "id": "dsS_lflb3lO5" 693 | }, 
694 | "outputs": [], 695 | "source": [ 696 | "print(features['image/bboxes'].shape)\n", 697 | "print(features['image/detection_features'].shape)" 698 | ] 699 | }, 700 | { 701 | "cell_type": "markdown", 702 | "metadata": { 703 | "id": "tgqt6SFBqEeH" 704 | }, 705 | "source": [ 706 | "### Visualizing the Detector Regions " 707 | ] 708 | }, 709 | { 710 | "cell_type": "code", 711 | "execution_count": null, 712 | "metadata": { 713 | "id": "5bIAo59YyOu0" 714 | }, 715 | "outputs": [], 716 | "source": [ 717 | "%matplotlib inline\n", 718 | "\n", 719 | "detection_classes = []\n", 720 | "detection_scores = []\n", 721 | "tuplet_index = {}\n", 722 | "\n", 723 | "for i in range(100):\n", 724 | " raw_detection_scores_obj = output_dict['detection_multiclass_scores'][:,i,1:1600][0,:]\n", 725 | " raw_detection_scores_att = output_dict['detection_multiclass_scores'][:,i,1600:][0,:]\n", 726 | " max_obj = np.argmax(raw_detection_scores_obj)\n", 727 | " max_att = np.argmax(raw_detection_scores_att)\n", 728 | " tuplet_index[i] = {}\n", 729 | " tuplet_index[i]['name'] = '%s %s' %(category_index[max_att+1600]['name'],\n", 730 | " category_index[max_obj+1]['name'])\n", 731 | " detection_classes.append(i)\n", 732 | " detection_scores.append(raw_detection_scores_obj[max_obj] +\n", 733 | " raw_detection_scores_att[max_att])\n", 734 | "\n", 735 | "# Create detections visualization\n", 736 | "bboxes = vis_util.visualize_boxes_and_labels_on_image_array(\n", 737 | " image_np.copy(),\n", 738 | " output_dict['detection_boxes'],\n", 739 | " np.array(detection_classes),\n", 740 | " detection_scores,\n", 741 | " tuplet_index,\n", 742 | " instance_masks=None,\n", 743 | " use_normalized_coordinates=True,\n", 744 | " max_boxes_to_draw=15,\n", 745 | " min_score_thresh=.05,\n", 746 | " agnostic_mode=False)\n", 747 | "\n", 748 | "fig = plt.gcf()\n", 749 | "fig.set_size_inches(18.5, 10.5)\n", 750 | "_ = plt.imshow(bboxes)\n", 751 | "plt.axis('off')" 752 | ] 753 | }, 754 | { 755 | "cell_type": "markdown", 756 | "metadata": { 757 | "id": "3nlv5_IyX53t" 758 | }, 759 | "source": [ 760 | "# Running Image-Text Pairs through the MMT\n", 761 | "\n", 762 | "Now that we have extracted our image and text features we can run them through our MMT model." 
763 | ] 764 | }, 765 | { 766 | "cell_type": "markdown", 767 | "metadata": { 768 | "id": "TJpau4aEYIqO" 769 | }, 770 | "source": [ 771 | "## Use features extracted in colab" 772 | ] 773 | }, 774 | { 775 | "cell_type": "code", 776 | "execution_count": null, 777 | "metadata": { 778 | "id": "Y5q0zjE3oY9_" 779 | }, 780 | "outputs": [], 781 | "source": [ 782 | "# Select a model\n", 783 | "\n", 784 | "#@title Category-conditional sampling { display-mode: \"form\", run: \"auto\" }\n", 785 | "\n", 786 | "tags = ['architecture-ft_image-q-12',\n", 787 | " 'architecture-ft_image-q-24',\n", 788 | " 'architecture-ft_language-q-12',\n", 789 | " 'architecture-ft_language-q-24',\n", 790 | " 'architecture-ft_single-modality',\n", 791 | " 'architecture-ft_single-stream',\n", 792 | " 'architecture_heads1-768',\n", 793 | " 'architecture_heads18-64',\n", 794 | " 'architecture_heads3-256',\n", 795 | " 'architecture_heads6-64',\n", 796 | " 'architecture_image-q-12',\n", 797 | " 'architecture_image-q-24',\n", 798 | " 'architecture_language-q-12',\n", 799 | " 'architecture_language-q-24',\n", 800 | " 'architecture_mixed-modality',\n", 801 | " 'architecture_single-modality',\n", 802 | " 'architecture_single-modality-hloss',\n", 803 | " 'architecture_single-stream',\n", 804 | " 'architecture_vilbert-12block',\n", 805 | " 'architecture_vilbert-1block',\n", 806 | " 'architecture_vilbert-2block',\n", 807 | " 'architecture_vilbert-4block',\n", 808 | " 'baseline-ft_baseline',\n", 809 | " 'baseline-ft_baseline-cls',\n", 810 | " 'baseline-ft_baseline-no-bert-transfer',\n", 811 | " 'baseline_baseline',\n", 812 | " 'baseline_baseline-cls',\n", 813 | " 'baseline_baseline-no-bert-transfer',\n", 814 | " 'data-ft_cc',\n", 815 | " 'data-ft_combined-dataset',\n", 816 | " 'data-ft_combined-instance',\n", 817 | " 'data-ft_mscoco',\n", 818 | " 'data-ft_mscoco-narratives',\n", 819 | " 'data-ft_oi-narratives',\n", 820 | " 'data-ft_sbu',\n", 821 | " 'data-ft_uniter-dataset',\n", 822 | " 'data-ft_uniter-instance',\n", 823 | " 'data-ft_vg',\n", 824 | " 'data_cc',\n", 825 | " 'data_cc-with-bert',\n", 826 | " 'data_combined-dataset',\n", 827 | " 'data_combined-instance',\n", 828 | " 'data_mscoco',\n", 829 | " 'data_mscoco-narratives',\n", 830 | " 'data_oi-narratives',\n", 831 | " 'data_sbu',\n", 832 | " 'data_uniter-dataset',\n", 833 | " 'data_uniter-instance',\n", 834 | " 'data_vg',\n", 835 | " 'loss_itm+mrm',\n", 836 | " 'loss_itm_mrm',\n", 837 | " 'loss_single-modality-contrastive1024',\n", 838 | " 'loss_single-modality-contrastive32',\n", 839 | " 'loss_v1-contrastive32',\n", 840 | " 'pixel_vilbert_cc-full-image']\n", 841 | "\n", 842 | "model = \"data_cc\" #@param [\"architecture-ft_image-q-12\", \"architecture-ft_image-q-24\", \"architecture-ft_language-q-12\", \"architecture-ft_language-q-24\", \"architecture-ft_single-modality\", \"architecture-ft_single-stream\", \"architecture_heads1-768\", \"architecture_heads18-64\", \"architecture_heads3-256\", \"architecture_heads6-64\", \"architecture_image-q-12\", \"architecture_image-q-24\", \"architecture_language-q-12\", \"architecture_language-q-24\", \"architecture_mixed-modality\", \"architecture_single-modality\", \"architecture_single-modality-hloss\", \"architecture_single-stream\", \"architecture_vilbert-12block\", \"architecture_vilbert-1block\", \"architecture_vilbert-2block\", \"architecture_vilbert-4block\", \"baseline-ft_baseline\", \"baseline-ft_baseline-cls\", \"baseline-ft_baseline-no-bert-transfer\", \"baseline_baseline\", \"baseline_baseline-cls\", 
\"baseline_baseline-no-bert-transfer\", \"data-ft_cc\", \"data-ft_combined-dataset\", \"data-ft_combined-instance\", \"data-ft_mscoco\", \"data-ft_mscoco-narratives\", \"data-ft_oi-narratives\", \"data-ft_sbu\", \"data-ft_uniter-dataset\", \"data-ft_uniter-instance\", \"data-ft_vg\", \"data_cc\", \"data_cc-with-bert\", \"data_combined-dataset\", \"data_combined-instance\", \"data_mscoco\", \"data_mscoco-narratives\", \"data_oi-narratives\", \"data_sbu\", \"data_uniter-dataset\", \"data_uniter-instance\", \"data_vg\", \"loss_itm+mrm\", \"loss_itm_mrm\", \"loss_single-modality-contrastive1024\", \"loss_single-modality-contrastive32\", \"loss_v1-contrastive32\"]\n", 843 | "\n", 844 | "tfhub_link = \"https://tfhub.dev/deepmind/mmt/%s/1\" %model" 845 | ] 846 | }, 847 | { 848 | "cell_type": "code", 849 | "execution_count": null, 850 | "metadata": { 851 | "id": "5ugmFX5OYOfi" 852 | }, 853 | "outputs": [], 854 | "source": [ 855 | "model = hub.load(tfhub_link)" 856 | ] 857 | }, 858 | { 859 | "cell_type": "code", 860 | "execution_count": null, 861 | "metadata": { 862 | "id": "Bhcy1XN3YvIw" 863 | }, 864 | "outputs": [], 865 | "source": [ 866 | "inputs={'image/bboxes': tf.expand_dims(features['image/bboxes'], 0),\n", 867 | " 'text/padding_mask': tf.expand_dims(features['text/padding_mask'], 0),\n", 868 | " 'image/padding_mask': tf.expand_dims(features['image/padding_mask'], 0),\n", 869 | " 'masked_tokens': tf.expand_dims(features['text/token_ids'], 0),\n", 870 | " 'text/segment_ids': tf.expand_dims(features['text/segment_ids'], 0),\n", 871 | " 'image/detection_features': tf.expand_dims(features['image/detection_features'], 0),\n", 872 | " 'text/token_ids': tf.expand_dims(features['text/token_ids'], 0)\n", 873 | " }\n", 874 | "\n", 875 | "output = model.signatures['default'](**inputs)\n", 876 | "score = tf.nn.softmax(output['output']).numpy()[0]\n", 877 | "\n", 878 | "if score \u003e 0.5:\n", 879 | " print('The text and image match! (score: %0.03f)' %score)\n", 880 | "else: \n", 881 | " print('The text and image do not match :( (score: %0.03f)' %score) " 882 | ] 883 | }, 884 | { 885 | "cell_type": "markdown", 886 | "metadata": { 887 | "id": "KwlfRwMaVge6" 888 | }, 889 | "source": [ 890 | "# Running with Pre-Extracted Features\n", 891 | "\n", 892 | "We have pre-extracted MSCOCO and Flickr image features. You can uset these pre-extracted features to do retrieval." 
893 | ] 894 | }, 895 | { 896 | "cell_type": "markdown", 897 | "metadata": { 898 | "id": "3cEyRvb-YO8K" 899 | }, 900 | "source": [ 901 | "## Use Precomputed features" 902 | ] 903 | }, 904 | { 905 | "cell_type": "code", 906 | "execution_count": null, 907 | "metadata": { 908 | "id": "39CJERs_PJfd" 909 | }, 910 | "outputs": [], 911 | "source": [ 912 | "import pickle as pkl\n", 913 | "\n", 914 | "!wget --no-check-certificate https://storage.googleapis.com/dm-mmt-models/features/coco_test/570107.pkl -P '/tmp/' \n", 915 | "with open('/tmp/570107.pkl', 'rb') as f:\n", 916 | " im_feats = pkl.load(f)" 917 | ] 918 | }, 919 | { 920 | "cell_type": "code", 921 | "execution_count": null, 922 | "metadata": { 923 | "id": "H_41h80NYH3o" 924 | }, 925 | "outputs": [], 926 | "source": [ 927 | "inputs={'image/bboxes': tf.expand_dims(features['image/bboxes'], 0),\n", 928 | " 'text/padding_mask': tf.expand_dims(features['text/padding_mask'], 0),\n", 929 | " 'image/padding_mask': tf.expand_dims(im_feats['image/padding_mask'], 0),\n", 930 | " 'masked_tokens': tf.expand_dims(features['text/token_ids'], 0),\n", 931 | " 'text/segment_ids': tf.expand_dims(features['text/segment_ids'], 0),\n", 932 | " 'image/detection_features': tf.expand_dims(im_feats['image/detection_features'], 0),\n", 933 | " 'text/token_ids': tf.expand_dims(features['text/token_ids'], 0)\n", 934 | " }\n", 935 | "\n", 936 | "output = model.signatures['default'](**inputs)\n", 937 | "score = tf.nn.softmax(output['output']).numpy()[0]\n", 938 | "\n", 939 | "if score \u003e 0.5:\n", 940 | " print('The text and image match! (score: %0.03f)' %score)\n", 941 | "else: \n", 942 | " print('The text and image do not match :( (score: %0.03f)' %score) " 943 | ] 944 | } 945 | ], 946 | "metadata": { 947 | "colab": { 948 | "name": "mmt_release.ipynb", 949 | "private_outputs": true, 950 | "provenance": [ 951 | { 952 | "file_id": "1y-QoDky8RQTvM6shgfYguEJs8Bg4F3jg", 953 | "timestamp": 1635456767570 954 | }, 955 | { 956 | "file_id": "1HEb0vz63HwqhY7vafLTetV3r4dAZ7iNV", 957 | "timestamp": 1628272609748 958 | } 959 | ] 960 | }, 961 | "kernelspec": { 962 | "display_name": "Python 3", 963 | "name": "python3" 964 | }, 965 | "language_info": { 966 | "name": "python" 967 | } 968 | }, 969 | "nbformat": 4, 970 | "nbformat_minor": 0 971 | } 972 | --------------------------------------------------------------------------------