├── .github ├── airplane.png ├── cake.png ├── cruise.png ├── girl.png ├── horse.png └── office.png ├── .gitignore ├── LICENSE ├── README.md ├── catr_demo.ipynb ├── configuration.py ├── datasets ├── __init__.py ├── coco.py └── utils.py ├── engine.py ├── finetune.py ├── hubconf.py ├── main.py ├── models ├── __init__.py ├── backbone.py ├── caption.py ├── position_encoding.py ├── transformer.py └── utils.py ├── predict.py └── requirements.txt /.github/airplane.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saahiluppal/catr/fac82f9b4004b1dd39ccf89760b758ad19a2dbee/.github/airplane.png -------------------------------------------------------------------------------- /.github/cake.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saahiluppal/catr/fac82f9b4004b1dd39ccf89760b758ad19a2dbee/.github/cake.png -------------------------------------------------------------------------------- /.github/cruise.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saahiluppal/catr/fac82f9b4004b1dd39ccf89760b758ad19a2dbee/.github/cruise.png -------------------------------------------------------------------------------- /.github/girl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saahiluppal/catr/fac82f9b4004b1dd39ccf89760b758ad19a2dbee/.github/girl.png -------------------------------------------------------------------------------- /.github/horse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saahiluppal/catr/fac82f9b4004b1dd39ccf89760b758ad19a2dbee/.github/horse.png -------------------------------------------------------------------------------- /.github/office.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saahiluppal/catr/fac82f9b4004b1dd39ccf89760b758ad19a2dbee/.github/office.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 4 | 5 | ### Python ### 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # pytype static type analyzer 137 | .pytype/ 138 | 139 | # End of https://www.toptal.com/developers/gitignore/api/python 140 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **CA⫶TR**: Image Captioning with Transformers 2 | ======== 3 | PyTorch training code and pretrained models for **CATR** (**CA**ption **TR**ansformer). 4 | 5 | The models are also available via torch hub; 6 | to load a model with pretrained weights, simply do: 7 | ```python 8 | model = torch.hub.load('saahiluppal/catr', 'v3', pretrained=True)  # you can choose between v1, v2 and v3 9 | ``` 10 | ### Samples: 11 | 12 |

13 | <!-- Sample images annotated by CATR are embedded here: .github/airplane.png, cake.png, cruise.png, girl.png, horse.png, office.png --> 14 | 15 | 16 | 17 | 18 |

19 | 20 | All these images have been annotated by CATR. 21 | 22 | Test it with your own images: 23 | ````bash 24 | $ python predict.py --path /path/to/image --v v2  # You can choose between v1, v2 and v3 [default is v3] 25 | ```` 26 | Or try it out in the Colab [notebook](catr_demo.ipynb). 27 | 28 | # Usage 29 | There are no extra compiled components in CATR and the package dependencies are minimal, 30 | so the code is very simple to use. We provide instructions on how to install the dependencies. 31 | First, clone the repository locally: 32 | ``` 33 | $ git clone https://github.com/saahiluppal/catr.git 34 | ``` 35 | Then, install PyTorch 1.5+ and torchvision 0.6+ along with the remaining dependencies: 36 | ``` 37 | $ pip install -r requirements.txt 38 | ``` 39 | That's it; you should now be able to train and test caption models. 40 | 41 | ## Data preparation 42 | 43 | Download and extract the COCO 2017 train and val images with annotations from 44 | [http://cocodataset.org](http://cocodataset.org/#download). 45 | We expect the directory structure to be the following: 46 | ``` 47 | path/to/coco/ 48 | annotations/ # annotation json files 49 | train2017/ # train images 50 | val2017/ # val images 51 | ``` 52 | 53 | ## Training 54 | Tweak the hyperparameters in the configuration file (configuration.py). 55 | 56 | To train the baseline CATR on a single GPU for 30 epochs, run: 57 | ``` 58 | $ python main.py 59 | ``` 60 | We train CATR with AdamW, setting the learning rate to 1e-4 in the transformer and 1e-5 in the backbone. 61 | Horizontal flips, scales and crops are used for augmentation. 62 | Images are rescaled so that their longest side is at most 299 pixels. 63 | The transformer is trained with a dropout of 0.1, and the whole model is trained with gradient clipping at 0.1. 64 | 65 | ## Testing 66 | To test CATR with your own images: 67 | ``` 68 | $ python predict.py --path /path/to/image --v v2  # You can choose between v1, v2 and v3 [default is v3] 69 | ``` 70 | 71 | # License 72 | CATR is released under the Apache 2.0 license. Please see the [LICENSE](LICENSE) file for more information.
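
The torch hub entry points return the bare `Caption` model, so producing a sentence still requires the BERT tokenizer and a decoding loop. Below is a minimal greedy-decoding sketch along the lines of what `predict.py` does; the preprocessing only approximates `val_transform` from `datasets/coco.py`, the 128-token limit mirrors `Config.max_position_embeddings`, and the sample image path is just an example.

```python
# Minimal inference sketch (not the canonical predict.py script).
import torch
import torchvision.transforms as T
from PIL import Image
from transformers import BertTokenizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = torch.hub.load('saahiluppal/catr', 'v3', pretrained=True).to(device).eval()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Preprocess roughly like val_transform: longest side -> 299, normalize to [-1, 1].
img = Image.open('.github/cake.png').convert('RGB')   # example path
scale = 299 / max(img.size)
img = img.resize((round(img.size[0] * scale), round(img.size[1] * scale)))
image = T.Compose([T.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])(img)
image = image.unsqueeze(0).to(device)

max_length = 128                                       # Config.max_position_embeddings
caption = torch.zeros((1, max_length), dtype=torch.long, device=device)
cap_mask = torch.ones((1, max_length), dtype=torch.bool, device=device)
caption[:, 0] = tokenizer.cls_token_id                 # start with [CLS]
cap_mask[:, 0] = False                                 # unmask the start position

with torch.no_grad():
    for i in range(max_length - 1):
        logits = model(image, caption, cap_mask)       # (1, max_length, vocab_size)
        next_id = logits[:, i, :].argmax(dim=-1)       # token predicted for position i + 1
        if next_id.item() == tokenizer.sep_token_id:   # stop at [SEP]
            break
        caption[:, i + 1] = next_id
        cap_mask[:, i + 1] = False

print(tokenizer.decode(caption[0], skip_special_tokens=True))
```

`predict.py` wraps essentially this loop behind the `--path` and `--v` flags shown above.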
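
Since `main.py` builds everything from a `Config` instance and only starts training under its `__main__` guard, the hyperparameters mentioned in the Training section can also be overridden from a small driver script instead of editing `configuration.py` in place. A hypothetical `train_custom.py` in the repository root might look like this:

```python
# Hypothetical driver script: override a few Config fields before training.
from configuration import Config
from main import main

config = Config()
config.dir = '../coco'            # path/to/coco from the Data preparation section
config.batch_size = 16            # default is 32; lower this to fit a smaller GPU
config.epochs = 30
config.lr = 1e-4                  # transformer learning rate
config.lr_backbone = 1e-5         # backbone learning rate
config.checkpoint = './checkpoint.pth'

main(config)
```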
73 | -------------------------------------------------------------------------------- /catr_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "catr-demo.ipynb", 7 | "provenance": [], 8 | "mount_file_id": "1LVtj77AoMmhJmTDCsXmcQjqxUFb7YEJ2", 9 | "authorship_tag": "ABX9TyNQVSIJZ6Hoj2qvhWWnmSyV", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "accelerator": "GPU" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "metadata": { 32 | "id": "N8xyiUNoWmL3", 33 | "colab_type": "code", 34 | "colab": { 35 | "base_uri": "https://localhost:8080/", 36 | "height": 121 37 | }, 38 | "outputId": "6c54de9e-339e-47a6-ff49-50dfe361d627" 39 | }, 40 | "source": [ 41 | "!git clone https://github.com/saahiluppal/catr.git" 42 | ], 43 | "execution_count": 2, 44 | "outputs": [ 45 | { 46 | "output_type": "stream", 47 | "text": [ 48 | "Cloning into 'catr'...\n", 49 | "remote: Enumerating objects: 79, done.\u001b[K\n", 50 | "remote: Counting objects: 1% (1/79)\u001b[K\rremote: Counting objects: 2% (2/79)\u001b[K\rremote: Counting objects: 3% (3/79)\u001b[K\rremote: Counting objects: 5% (4/79)\u001b[K\rremote: Counting objects: 6% (5/79)\u001b[K\rremote: Counting objects: 7% (6/79)\u001b[K\rremote: Counting objects: 8% (7/79)\u001b[K\rremote: Counting objects: 10% (8/79)\u001b[K\rremote: Counting objects: 11% (9/79)\u001b[K\rremote: Counting objects: 12% (10/79)\u001b[K\rremote: Counting objects: 13% (11/79)\u001b[K\rremote: Counting objects: 15% (12/79)\u001b[K\rremote: Counting objects: 16% (13/79)\u001b[K\rremote: Counting objects: 17% (14/79)\u001b[K\rremote: Counting objects: 18% (15/79)\u001b[K\rremote: Counting objects: 20% (16/79)\u001b[K\rremote: Counting objects: 21% (17/79)\u001b[K\rremote: Counting objects: 22% (18/79)\u001b[K\rremote: Counting objects: 24% (19/79)\u001b[K\rremote: Counting objects: 25% (20/79)\u001b[K\rremote: Counting objects: 26% (21/79)\u001b[K\rremote: Counting objects: 27% (22/79)\u001b[K\rremote: Counting objects: 29% (23/79)\u001b[K\rremote: Counting objects: 30% (24/79)\u001b[K\rremote: Counting objects: 31% (25/79)\u001b[K\rremote: Counting objects: 32% (26/79)\u001b[K\rremote: Counting objects: 34% (27/79)\u001b[K\rremote: Counting objects: 35% (28/79)\u001b[K\rremote: Counting objects: 36% (29/79)\u001b[K\rremote: Counting objects: 37% (30/79)\u001b[K\rremote: Counting objects: 39% (31/79)\u001b[K\rremote: Counting objects: 40% (32/79)\u001b[K\rremote: Counting objects: 41% (33/79)\u001b[K\rremote: Counting objects: 43% (34/79)\u001b[K\rremote: Counting objects: 44% (35/79)\u001b[K\rremote: Counting objects: 45% (36/79)\u001b[K\rremote: Counting objects: 46% (37/79)\u001b[K\rremote: Counting objects: 48% (38/79)\u001b[K\rremote: Counting objects: 49% (39/79)\u001b[K\rremote: Counting objects: 50% (40/79)\u001b[K\rremote: Counting objects: 51% (41/79)\u001b[K\rremote: Counting objects: 53% (42/79)\u001b[K\rremote: Counting objects: 54% (43/79)\u001b[K\rremote: Counting objects: 55% (44/79)\u001b[K\rremote: Counting objects: 56% (45/79)\u001b[K\rremote: Counting objects: 58% (46/79)\u001b[K\rremote: Counting objects: 59% (47/79)\u001b[K\rremote: Counting objects: 60% (48/79)\u001b[K\rremote: 
Counting objects: 62% (49/79)\u001b[K\rremote: Counting objects: 63% (50/79)\u001b[K\rremote: Counting objects: 64% (51/79)\u001b[K\rremote: Counting objects: 65% (52/79)\u001b[K\rremote: Counting objects: 67% (53/79)\u001b[K\rremote: Counting objects: 68% (54/79)\u001b[K\rremote: Counting objects: 69% (55/79)\u001b[K\rremote: Counting objects: 70% (56/79)\u001b[K\rremote: Counting objects: 72% (57/79)\u001b[K\rremote: Counting objects: 73% (58/79)\u001b[K\rremote: Counting objects: 74% (59/79)\u001b[K\rremote: Counting objects: 75% (60/79)\u001b[K\rremote: Counting objects: 77% (61/79)\u001b[K\rremote: Counting objects: 78% (62/79)\u001b[K\rremote: Counting objects: 79% (63/79)\u001b[K\rremote: Counting objects: 81% (64/79)\u001b[K\rremote: Counting objects: 82% (65/79)\u001b[K\rremote: Counting objects: 83% (66/79)\u001b[K\rremote: Counting objects: 84% (67/79)\u001b[K\rremote: Counting objects: 86% (68/79)\u001b[K\rremote: Counting objects: 87% (69/79)\u001b[K\rremote: Counting objects: 88% (70/79)\u001b[K\rremote: Counting objects: 89% (71/79)\u001b[K\rremote: Counting objects: 91% (72/79)\u001b[K\rremote: Counting objects: 92% (73/79)\u001b[K\rremote: Counting objects: 93% (74/79)\u001b[K\rremote: Counting objects: 94% (75/79)\u001b[K\rremote: Counting objects: 96% (76/79)\u001b[K\rremote: Counting objects: 97% (77/79)\u001b[K\rremote: Counting objects: 98% (78/79)\u001b[K\rremote: Counting objects: 100% (79/79)\u001b[K\rremote: Counting objects: 100% (79/79), done.\u001b[K\n", 51 | "remote: Compressing objects: 100% (58/58), done.\u001b[K\n", 52 | "remote: Total 79 (delta 29), reused 57 (delta 18), pack-reused 0\u001b[K\n", 53 | "Unpacking objects: 100% (79/79), done.\n" 54 | ], 55 | "name": "stdout" 56 | } 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "metadata": { 62 | "id": "JMzOLu6Tjhqz", 63 | "colab_type": "code", 64 | "colab": { 65 | "base_uri": "https://localhost:8080/", 66 | "height": 34 67 | }, 68 | "outputId": "f2df8dc5-3bce-42c2-f03e-81bf10a0a8ab" 69 | }, 70 | "source": [ 71 | "%cd catr/" 72 | ], 73 | "execution_count": 4, 74 | "outputs": [ 75 | { 76 | "output_type": "stream", 77 | "text": [ 78 | "/content/catr\n" 79 | ], 80 | "name": "stdout" 81 | } 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "metadata": { 87 | "id": "ojT0Ln7GjayV", 88 | "colab_type": "code", 89 | "colab": { 90 | "base_uri": "https://localhost:8080/", 91 | "height": 697 92 | }, 93 | "outputId": "618844dd-0459-4d3d-b71b-d5ed11ec8bad" 94 | }, 95 | "source": [ 96 | "!pip install -r requirements.txt" 97 | ], 98 | "execution_count": 5, 99 | "outputs": [ 100 | { 101 | "output_type": "stream", 102 | "text": [ 103 | "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 1)) (1.5.1+cu101)\n", 104 | "Requirement already satisfied: torchvision in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 2)) (0.6.1+cu101)\n", 105 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 3)) (1.18.5)\n", 106 | "Collecting transformers\n", 107 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/27/3c/91ed8f5c4e7ef3227b4119200fc0ed4b4fd965b1f0172021c25701087825/transformers-3.0.2-py3-none-any.whl (769kB)\n", 108 | "\r\u001b[K |▍ | 10kB 24.9MB/s eta 0:00:01\r\u001b[K |▉ | 20kB 3.3MB/s eta 0:00:01\r\u001b[K |█▎ | 30kB 4.4MB/s eta 0:00:01\r\u001b[K |█▊ | 40kB 4.6MB/s eta 0:00:01\r\u001b[K |██▏ | 51kB 3.8MB/s eta 0:00:01\r\u001b[K |██▋ | 61kB 4.3MB/s eta 
0:00:01\r\u001b[K |███ | 71kB 4.7MB/s eta 0:00:01\r\u001b[K |███▍ | 81kB 4.9MB/s eta 0:00:01\r\u001b[K |███▉ | 92kB 5.3MB/s eta 0:00:01\r\u001b[K |████▎ | 102kB 5.3MB/s eta 0:00:01\r\u001b[K |████▊ | 112kB 5.3MB/s eta 0:00:01\r\u001b[K |█████▏ | 122kB 5.3MB/s eta 0:00:01\r\u001b[K |█████▌ | 133kB 5.3MB/s eta 0:00:01\r\u001b[K |██████ | 143kB 5.3MB/s eta 0:00:01\r\u001b[K |██████▍ | 153kB 5.3MB/s eta 0:00:01\r\u001b[K |██████▉ | 163kB 5.3MB/s eta 0:00:01\r\u001b[K |███████▎ | 174kB 5.3MB/s eta 0:00:01\r\u001b[K |███████▊ | 184kB 5.3MB/s eta 0:00:01\r\u001b[K |████████ | 194kB 5.3MB/s eta 0:00:01\r\u001b[K |████████▌ | 204kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████ | 215kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████▍ | 225kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████▉ | 235kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████▎ | 245kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████▋ | 256kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████ | 266kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████▌ | 276kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████ | 286kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████▍ | 296kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████▉ | 307kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████▏ | 317kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████▋ | 327kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████ | 337kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████▌ | 348kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████ | 358kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████▍ | 368kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████▊ | 378kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████▏ | 389kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████▋ | 399kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████████ | 409kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████████▌ | 419kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████ | 430kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████▎ | 440kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████▊ | 450kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████████▏ | 460kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████████▋ | 471kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████████ | 481kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████████▌ | 491kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████████▉ | 501kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████████████▎ | 512kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████████████▊ | 522kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████████▏ | 532kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████████▋ | 542kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████████████ | 552kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████████████▍ | 563kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████████████▉ | 573kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████████████▎ | 583kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████████████▊ | 593kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████████████████▏ | 604kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████████████████▋ | 614kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████████████ | 624kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████████████▍ | 634kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████████████▉ | 645kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████████████████▎ | 655kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████████████████▊ | 665kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████████████████▏ | 675kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████████████████▌ | 686kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████████████████████ | 696kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▍ | 706kB 
5.3MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▉ | 716kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▎ | 727kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▊ | 737kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████████████████████ | 747kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▌| 757kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 768kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 778kB 5.3MB/s \n", 109 | "\u001b[?25hRequirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 5)) (4.41.1)\n", 110 | "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->-r requirements.txt (line 1)) (0.16.0)\n", 111 | "Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision->-r requirements.txt (line 2)) (7.0.0)\n", 112 | "Collecting tokenizers==0.8.1.rc1\n", 113 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/40/d0/30d5f8d221a0ed981a186c8eb986ce1c94e3a6e87f994eae9f4aa5250217/tokenizers-0.8.1rc1-cp36-cp36m-manylinux1_x86_64.whl (3.0MB)\n", 114 | "\u001b[K |████████████████████████████████| 3.0MB 24.6MB/s \n", 115 | "\u001b[?25hRequirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from transformers->-r requirements.txt (line 4)) (0.7)\n", 116 | "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers->-r requirements.txt (line 4)) (2.23.0)\n", 117 | "Collecting sacremoses\n", 118 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)\n", 119 | "\u001b[K |████████████████████████████████| 890kB 31.3MB/s \n", 120 | "\u001b[?25hCollecting sentencepiece!=0.1.92\n", 121 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)\n", 122 | "\u001b[K |████████████████████████████████| 1.1MB 49.5MB/s \n", 123 | "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers->-r requirements.txt (line 4)) (20.4)\n", 124 | "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers->-r requirements.txt (line 4)) (2019.12.20)\n", 125 | "Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers->-r requirements.txt (line 4)) (3.0.12)\n", 126 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers->-r requirements.txt (line 4)) (2.10)\n", 127 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers->-r requirements.txt (line 4)) (1.24.3)\n", 128 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers->-r requirements.txt (line 4)) (2020.6.20)\n", 129 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers->-r requirements.txt (line 4)) (3.0.4)\n", 130 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers->-r requirements.txt (line 4)) 
(1.15.0)\n", 131 | "Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers->-r requirements.txt (line 4)) (7.1.2)\n", 132 | "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers->-r requirements.txt (line 4)) (0.16.0)\n", 133 | "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->transformers->-r requirements.txt (line 4)) (2.4.7)\n", 134 | "Building wheels for collected packages: sacremoses\n", 135 | " Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 136 | " Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893260 sha256=54b093ea3f7db7b5b98464934131d04f8580c06b274673e0cdf946240bc124f0\n", 137 | " Stored in directory: /root/.cache/pip/wheels/29/3c/fd/7ce5c3f0666dab31a50123635e6fb5e19ceb42ce38d4e58f45\n", 138 | "Successfully built sacremoses\n", 139 | "Installing collected packages: tokenizers, sacremoses, sentencepiece, transformers\n", 140 | "Successfully installed sacremoses-0.0.43 sentencepiece-0.1.91 tokenizers-0.8.1rc1 transformers-3.0.2\n" 141 | ], 142 | "name": "stdout" 143 | } 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "metadata": { 149 | "id": "KesOqe40jdak", 150 | "colab_type": "code", 151 | "colab": { 152 | "base_uri": "https://localhost:8080/", 153 | "height": 176 154 | }, 155 | "outputId": "54e430ea-f037-4541-d0d4-ac71ee457ee3" 156 | }, 157 | "source": [ 158 | "!python predict.py --path .github/cake.png" 159 | ], 160 | "execution_count": 6, 161 | "outputs": [ 162 | { 163 | "output_type": "stream", 164 | "text": [ 165 | "2020-07-30 11:50:09.260991: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1\n", 166 | "Downloading: \"https://github.com/saahiluppal/catr/archive/master.zip\" to /root/.cache/torch/hub/master.zip\n", 167 | "Downloading: \"https://download.pytorch.org/models/resnet101-5d3b4d8f.pth\" to /root/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth\n", 168 | "100% 170M/170M [00:00<00:00, 227MB/s]\n", 169 | "Downloading: \"https://github.com/saahiluppal/catr/releases/download/0.1/weights_9348032.pth\" to /root/.cache/torch/hub/checkpoints/weights_9348032.pth\n", 170 | "100% 322M/322M [00:07<00:00, 45.0MB/s]\n", 171 | "Downloading: 100% 232k/232k [00:00<00:00, 916kB/s]\n", 172 | "A person cutting a cake with a knife.\n" 173 | ], 174 | "name": "stdout" 175 | } 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "metadata": { 181 | "id": "8VAr2dFajs69", 182 | "colab_type": "code", 183 | "colab": { 184 | "base_uri": "https://localhost:8080/", 185 | "height": 89 186 | }, 187 | "outputId": "7c985740-9d46-43a9-bcaf-822a6952b268" 188 | }, 189 | "source": [ 190 | "!python predict.py --path .github/girl.png" 191 | ], 192 | "execution_count": 7, 193 | "outputs": [ 194 | { 195 | "output_type": "stream", 196 | "text": [ 197 | "2020-07-30 11:50:48.364873: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1\n", 198 | "Using cache found in /root/.cache/torch/hub/saahiluppal_catr_master\n", 199 | "A woman sitting on a curb holding a pink umbrella.\n" 200 | ], 201 | "name": "stdout" 202 | } 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "metadata": { 208 | "id": "_evy1rElj3JK", 209 | "colab_type": "code", 210 | "colab": { 211 | "base_uri": "https://localhost:8080/", 212 | "height": 89 213 | }, 214 | 
"outputId": "37eeaf43-3d7c-41db-bb59-287bd7632da9" 215 | }, 216 | "source": [ 217 | "!python predict.py --path .github/office.png" 218 | ], 219 | "execution_count": 8, 220 | "outputs": [ 221 | { 222 | "output_type": "stream", 223 | "text": [ 224 | "2020-07-30 11:51:07.946459: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1\n", 225 | "Using cache found in /root/.cache/torch/hub/saahiluppal_catr_master\n", 226 | "A group of people sitting at a table with laptops.\n" 227 | ], 228 | "name": "stdout" 229 | } 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "metadata": { 235 | "id": "5uqoo0NCj6h-", 236 | "colab_type": "code", 237 | "colab": { 238 | "base_uri": "https://localhost:8080/", 239 | "height": 89 240 | }, 241 | "outputId": "89df63ee-77e6-4acc-e9c3-9e6e05a9f517" 242 | }, 243 | "source": [ 244 | "!python predict.py --path .github/horse.png" 245 | ], 246 | "execution_count": 9, 247 | "outputs": [ 248 | { 249 | "output_type": "stream", 250 | "text": [ 251 | "2020-07-30 11:51:26.572680: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1\n", 252 | "Using cache found in /root/.cache/torch/hub/saahiluppal_catr_master\n", 253 | "A man riding on the back of a brown horse.\n" 254 | ], 255 | "name": "stdout" 256 | } 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "metadata": { 262 | "id": "CCnUN_H4j7x9", 263 | "colab_type": "code", 264 | "colab": { 265 | "base_uri": "https://localhost:8080/", 266 | "height": 89 267 | }, 268 | "outputId": "c095a09c-eebd-41ce-b7a2-df9af8100dc0" 269 | }, 270 | "source": [ 271 | "!python predict.py --path .github/airplane.png" 272 | ], 273 | "execution_count": 10, 274 | "outputs": [ 275 | { 276 | "output_type": "stream", 277 | "text": [ 278 | "2020-07-30 11:51:44.833782: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1\n", 279 | "Using cache found in /root/.cache/torch/hub/saahiluppal_catr_master\n", 280 | "A large jetliner sitting on top of an airport runway.\n" 281 | ], 282 | "name": "stdout" 283 | } 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "metadata": { 289 | "id": "x8YuZRMokBqT", 290 | "colab_type": "code", 291 | "colab": {} 292 | }, 293 | "source": [ 294 | "" 295 | ], 296 | "execution_count": null, 297 | "outputs": [] 298 | } 299 | ] 300 | } -------------------------------------------------------------------------------- /configuration.py: -------------------------------------------------------------------------------- 1 | class Config(object): 2 | def __init__(self): 3 | 4 | # Learning Rates 5 | self.lr_backbone = 1e-5 6 | self.lr = 1e-4 7 | 8 | # Epochs 9 | self.epochs = 30 10 | self.lr_drop = 20 11 | self.start_epoch = 0 12 | self.weight_decay = 1e-4 13 | 14 | # Backbone 15 | self.backbone = 'resnet101' 16 | self.position_embedding = 'sine' 17 | self.dilation = True 18 | 19 | # Basic 20 | self.device = 'cuda' 21 | self.seed = 42 22 | self.batch_size = 32 23 | self.num_workers = 8 24 | self.checkpoint = './checkpoint.pth' 25 | self.clip_max_norm = 0.1 26 | 27 | # Transformer 28 | self.hidden_dim = 256 29 | self.pad_token_id = 0 30 | self.max_position_embeddings = 128 31 | self.layer_norm_eps = 1e-12 32 | self.dropout = 0.1 33 | self.vocab_size = 30522 34 | 35 | self.enc_layers = 6 36 | self.dec_layers = 6 37 | self.dim_feedforward = 2048 38 | self.nheads = 8 39 | self.pre_norm = True 40 | 41 | # Dataset 42 | self.dir = 
'../coco' 43 | self.limit = -1 -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saahiluppal/catr/fac82f9b4004b1dd39ccf89760b758ad19a2dbee/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/coco.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torchvision.transforms.functional as TF 3 | import torchvision as tv 4 | 5 | from PIL import Image 6 | import numpy as np 7 | import random 8 | import os 9 | 10 | from transformers import BertTokenizer 11 | 12 | from .utils import nested_tensor_from_tensor_list, read_json 13 | 14 | MAX_DIM = 299 15 | 16 | 17 | def under_max(image): 18 | if image.mode != 'RGB': 19 | image = image.convert("RGB") 20 | 21 | shape = np.array(image.size, dtype=np.float) 22 | long_dim = max(shape) 23 | scale = MAX_DIM / long_dim 24 | 25 | new_shape = (shape * scale).astype(int) 26 | image = image.resize(new_shape) 27 | 28 | return image 29 | 30 | 31 | class RandomRotation: 32 | def __init__(self, angles=[0, 90, 180, 270]): 33 | self.angles = angles 34 | 35 | def __call__(self, x): 36 | angle = random.choice(self.angles) 37 | return TF.rotate(x, angle, expand=True) 38 | 39 | 40 | train_transform = tv.transforms.Compose([ 41 | RandomRotation(), 42 | tv.transforms.Lambda(under_max), 43 | tv.transforms.ColorJitter(brightness=[0.5, 1.3], contrast=[ 44 | 0.8, 1.5], saturation=[0.2, 1.5]), 45 | tv.transforms.RandomHorizontalFlip(), 46 | tv.transforms.ToTensor(), 47 | tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 48 | ]) 49 | 50 | val_transform = tv.transforms.Compose([ 51 | tv.transforms.Lambda(under_max), 52 | tv.transforms.ToTensor(), 53 | tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 54 | ]) 55 | 56 | 57 | class CocoCaption(Dataset): 58 | def __init__(self, root, ann, max_length, limit, transform=train_transform, mode='training'): 59 | super().__init__() 60 | 61 | self.root = root 62 | self.transform = transform 63 | self.annot = [(self._process(val['image_id']), val['caption']) 64 | for val in ann['annotations']] 65 | if mode == 'validation': 66 | self.annot = self.annot 67 | if mode == 'training': 68 | self.annot = self.annot[: limit] 69 | 70 | self.tokenizer = BertTokenizer.from_pretrained( 71 | 'bert-base-uncased', do_lower=True) 72 | self.max_length = max_length + 1 73 | 74 | def _process(self, image_id): 75 | val = str(image_id).zfill(12) 76 | return val + '.jpg' 77 | 78 | def __len__(self): 79 | return len(self.annot) 80 | 81 | def __getitem__(self, idx): 82 | image_id, caption = self.annot[idx] 83 | image = Image.open(os.path.join(self.root, image_id)) 84 | 85 | if self.transform: 86 | image = self.transform(image) 87 | image = nested_tensor_from_tensor_list(image.unsqueeze(0)) 88 | 89 | caption_encoded = self.tokenizer.encode_plus( 90 | caption, max_length=self.max_length, pad_to_max_length=True, return_attention_mask=True, return_token_type_ids=False, truncation=True) 91 | 92 | caption = np.array(caption_encoded['input_ids']) 93 | cap_mask = ( 94 | 1 - np.array(caption_encoded['attention_mask'])).astype(bool) 95 | 96 | return image.tensors.squeeze(0), image.mask.squeeze(0), caption, cap_mask 97 | 98 | 99 | def build_dataset(config, mode='training'): 100 | if mode == 'training': 101 | train_dir = os.path.join(config.dir, 
'train2017') 102 | train_file = os.path.join( 103 | config.dir, 'annotations', 'captions_train2017.json') 104 | data = CocoCaption(train_dir, read_json( 105 | train_file), max_length=config.max_position_embeddings, limit=config.limit, transform=train_transform, mode='training') 106 | return data 107 | 108 | elif mode == 'validation': 109 | val_dir = os.path.join(config.dir, 'val2017') 110 | val_file = os.path.join( 111 | config.dir, 'annotations', 'captions_val2017.json') 112 | data = CocoCaption(val_dir, read_json( 113 | val_file), max_length=config.max_position_embeddings, limit=config.limit, transform=val_transform, mode='validation') 114 | return data 115 | 116 | else: 117 | raise NotImplementedError(f"{mode} not supported") 118 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, List 3 | from torch import Tensor 4 | 5 | import json 6 | import os 7 | 8 | MAX_DIM = 299 9 | 10 | def read_json(file_name): 11 | with open(file_name) as handle: 12 | out = json.load(handle) 13 | return out 14 | 15 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 16 | # TODO make this more general 17 | if tensor_list[0].ndim == 3: 18 | # TODO make it support different-sized images 19 | max_size = [3, MAX_DIM, MAX_DIM] 20 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 21 | batch_shape = [len(tensor_list)] + max_size 22 | b, c, h, w = batch_shape 23 | dtype = tensor_list[0].dtype 24 | device = tensor_list[0].device 25 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 26 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 27 | for img, pad_img, m in zip(tensor_list, tensor, mask): 28 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 29 | m[: img.shape[1], :img.shape[2]] = False 30 | else: 31 | raise ValueError('not supported') 32 | return NestedTensor(tensor, mask) 33 | 34 | 35 | class NestedTensor(object): 36 | def __init__(self, tensors, mask: Optional[Tensor]): 37 | self.tensors = tensors 38 | self.mask = mask 39 | 40 | def to(self, device): 41 | # type: (Device) -> NestedTensor # noqa 42 | cast_tensor = self.tensors.to(device) 43 | mask = self.mask 44 | if mask is not None: 45 | assert mask is not None 46 | cast_mask = mask.to(device) 47 | else: 48 | cast_mask = None 49 | return NestedTensor(cast_tensor, cast_mask) 50 | 51 | def decompose(self): 52 | return self.tensors, self.mask 53 | 54 | def __repr__(self): 55 | return str(self.tensors) 56 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import torch 3 | 4 | import math 5 | import sys 6 | import tqdm 7 | 8 | from models import utils 9 | 10 | 11 | def train_one_epoch(model, criterion, data_loader, 12 | optimizer, device, epoch, max_norm): 13 | model.train() 14 | criterion.train() 15 | 16 | epoch_loss = 0.0 17 | total = len(data_loader) 18 | 19 | with tqdm.tqdm(total=total) as pbar: 20 | for images, masks, caps, cap_masks in data_loader: 21 | samples = utils.NestedTensor(images, masks).to(device) 22 | caps = caps.to(device) 23 | cap_masks = cap_masks.to(device) 24 | 25 | outputs = model(samples, caps[:, :-1], cap_masks[:, :-1]) 26 | loss = criterion(outputs.permute(0, 2, 1), caps[:, 1:]) 27 | loss_value = loss.item() 28 | epoch_loss += loss_value 29 | 30 | if not math.isfinite(loss_value): 31 | print(f'Loss is {loss_value}, stopping training') 32 | sys.exit(1) 33 | 34 | optimizer.zero_grad() 35 | loss.backward() 36 | if max_norm > 0: 37 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 38 | optimizer.step() 39 | 40 | pbar.update(1) 41 | 42 | return epoch_loss / total 43 | 44 | @torch.no_grad() 45 | def evaluate(model, criterion, data_loader, device): 46 | model.eval() 47 | criterion.eval() 48 | 49 | validation_loss = 0.0 50 | total = len(data_loader) 51 | 52 | with tqdm.tqdm(total=total) as pbar: 53 | for images, masks, caps, cap_masks in data_loader: 54 | samples = utils.NestedTensor(images, masks).to(device) 55 | caps = caps.to(device) 56 | cap_masks = cap_masks.to(device) 57 | 58 | outputs = model(samples, caps[:, :-1], cap_masks[:, :-1]) 59 | loss = criterion(outputs.permute(0, 2, 1), caps[:, 1:]) 60 | 61 | validation_loss += loss.item() 62 | 63 | pbar.update(1) 64 | 65 | return validation_loss / total -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | 4 | import numpy as np 5 | import time 6 | import sys 7 | import os 8 | 9 | from models import utils, caption 10 | from datasets import coco 11 | from configuration import Config 12 | from engine import train_one_epoch, evaluate 13 | 14 | 15 | def finetune(config): 16 | device = torch.device(config.device) 17 | print(f'Initializing Device: {device}') 18 | 19 | seed = config.seed + utils.get_rank() 20 | torch.manual_seed(seed) 21 | np.random.seed(seed) 22 | 23 | model, criterion = caption.build_model(config) 24 | checkpoint = torch.hub.load_state_dict_from_url( 25 | url="https://github.com/saahiluppal/catr/releases/download/0.2/weight493084032.pth", 26 | map_location=device 27 | ) 28 | model.to(device) 29 | model.load_state_dict(checkpoint['model']) 30 | 31 | config.lr = 1e-5 32 | config.epochs = 10 33 | config.lr_drop = 8 34 | 35 | n_parameters = sum(p.numel() 36 | for p in model.parameters() if p.requires_grad) 37 | print(f"Number of params: {n_parameters}") 38 | 39 | param_dicts = [ 40 | {"params": [p for n, p in model.named_parameters( 41 | ) if "backbone" not in n and p.requires_grad]}, 42 | { 43 | "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad], 44 | "lr": config.lr_backbone, 45 | }, 46 | ] 47 | optimizer = torch.optim.AdamW( 48 | param_dicts, lr=config.lr, weight_decay=config.weight_decay) 49 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, config.lr_drop) 50 | 51 | dataset_train = coco.build_dataset(config, mode='training') 52 | dataset_val = coco.build_dataset(config, mode='validation') 53 
| print(f"Train: {len(dataset_train)}") 54 | print(f"Valid: {len(dataset_val)}") 55 | 56 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 57 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 58 | 59 | batch_sampler_train = torch.utils.data.BatchSampler( 60 | sampler_train, config.batch_size, drop_last=True 61 | ) 62 | 63 | data_loader_train = DataLoader( 64 | dataset_train, batch_sampler=batch_sampler_train, num_workers=config.num_workers) 65 | data_loader_val = DataLoader(dataset_val, config.batch_size, 66 | sampler=sampler_val, drop_last=False, num_workers=config.num_workers) 67 | 68 | if os.path.exists(config.checkpoint): 69 | print("Loading Checkpoint...") 70 | checkpoint = torch.load(config.checkpoint, map_location='cpu') 71 | model.load_state_dict(checkpoint['model']) 72 | optimizer.load_state_dict(checkpoint['optimizer']) 73 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 74 | config.start_epoch = checkpoint['epoch'] + 1 75 | 76 | print("Start Training..") 77 | for epoch in range(config.start_epoch, config.epochs): 78 | print(f"Epoch: {epoch}") 79 | epoch_loss = train_one_epoch( 80 | model, criterion, data_loader_train, optimizer, device, epoch, config.clip_max_norm) 81 | lr_scheduler.step() 82 | print(f"Training Loss: {epoch_loss}") 83 | 84 | torch.save({ 85 | 'model': model.state_dict(), 86 | 'optimizer': optimizer.state_dict(), 87 | 'lr_scheduler': lr_scheduler.state_dict(), 88 | 'epoch': epoch, 89 | }, config.checkpoint) 90 | 91 | validation_loss = evaluate(model, criterion, data_loader_val, device) 92 | print(f"Validation Loss: {validation_loss}") 93 | 94 | print() 95 | 96 | 97 | if __name__ == "__main__": 98 | config = Config() 99 | finetune(config) 100 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from models import caption 4 | from configuration import Config 5 | 6 | dependencies = ['torch', 'torchvision'] 7 | 8 | def v1(pretrained=False): 9 | config = Config() 10 | model, _ = caption.build_model(config) 11 | 12 | if pretrained: 13 | checkpoint = torch.hub.load_state_dict_from_url( 14 | url='https://github.com/saahiluppal/catr/releases/download/0.1/weights_9348032.pth', 15 | map_location='cpu' 16 | ) 17 | model.load_state_dict(checkpoint['model']) 18 | 19 | return model 20 | 21 | def v2(pretrained=False): 22 | config = Config() 23 | model, _ = caption.build_model(config) 24 | 25 | if pretrained: 26 | checkpoint = torch.hub.load_state_dict_from_url( 27 | url='https://github.com/saahiluppal/catr/releases/download/0.2/weight389123791.pth', 28 | map_location='cpu' 29 | ) 30 | model.load_state_dict(checkpoint['model']) 31 | 32 | return model 33 | 34 | def v3(pretrained=False): 35 | config = Config() 36 | model, _ = caption.build_model(config) 37 | 38 | if pretrained: 39 | checkpoint = torch.hub.load_state_dict_from_url( 40 | url='https://github.com/saahiluppal/catr/releases/download/0.2/weight493084032.pth', 41 | map_location='cpu' 42 | ) 43 | model.load_state_dict(checkpoint['model']) 44 | 45 | return model -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | 4 | import numpy as np 5 | import time 6 | import sys 7 | import os 8 | 9 | from models import utils, caption 10 | from datasets import coco 11 | 
from configuration import Config 12 | from engine import train_one_epoch, evaluate 13 | 14 | 15 | def main(config): 16 | device = torch.device(config.device) 17 | print(f'Initializing Device: {device}') 18 | 19 | seed = config.seed + utils.get_rank() 20 | torch.manual_seed(seed) 21 | np.random.seed(seed) 22 | 23 | model, criterion = caption.build_model(config) 24 | model.to(device) 25 | 26 | n_parameters = sum(p.numel() 27 | for p in model.parameters() if p.requires_grad) 28 | print(f"Number of params: {n_parameters}") 29 | 30 | param_dicts = [ 31 | {"params": [p for n, p in model.named_parameters( 32 | ) if "backbone" not in n and p.requires_grad]}, 33 | { 34 | "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad], 35 | "lr": config.lr_backbone, 36 | }, 37 | ] 38 | optimizer = torch.optim.AdamW( 39 | param_dicts, lr=config.lr, weight_decay=config.weight_decay) 40 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, config.lr_drop) 41 | 42 | dataset_train = coco.build_dataset(config, mode='training') 43 | dataset_val = coco.build_dataset(config, mode='validation') 44 | print(f"Train: {len(dataset_train)}") 45 | print(f"Valid: {len(dataset_val)}") 46 | 47 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 48 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 49 | 50 | batch_sampler_train = torch.utils.data.BatchSampler( 51 | sampler_train, config.batch_size, drop_last=True 52 | ) 53 | 54 | data_loader_train = DataLoader( 55 | dataset_train, batch_sampler=batch_sampler_train, num_workers=config.num_workers) 56 | data_loader_val = DataLoader(dataset_val, config.batch_size, 57 | sampler=sampler_val, drop_last=False, num_workers=config.num_workers) 58 | 59 | if os.path.exists(config.checkpoint): 60 | print("Loading Checkpoint...") 61 | checkpoint = torch.load(config.checkpoint, map_location='cpu') 62 | model.load_state_dict(checkpoint['model']) 63 | optimizer.load_state_dict(checkpoint['optimizer']) 64 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 65 | config.start_epoch = checkpoint['epoch'] + 1 66 | 67 | print("Start Training..") 68 | for epoch in range(config.start_epoch, config.epochs): 69 | print(f"Epoch: {epoch}") 70 | epoch_loss = train_one_epoch( 71 | model, criterion, data_loader_train, optimizer, device, epoch, config.clip_max_norm) 72 | lr_scheduler.step() 73 | print(f"Training Loss: {epoch_loss}") 74 | 75 | torch.save({ 76 | 'model': model.state_dict(), 77 | 'optimizer': optimizer.state_dict(), 78 | 'lr_scheduler': lr_scheduler.state_dict(), 79 | 'epoch': epoch, 80 | }, config.checkpoint) 81 | 82 | validation_loss = evaluate(model, criterion, data_loader_val, device) 83 | print(f"Validation Loss: {validation_loss}") 84 | 85 | print() 86 | 87 | 88 | if __name__ == "__main__": 89 | config = Config() 90 | main(config) 91 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saahiluppal/catr/fac82f9b4004b1dd39ccf89760b758ad19a2dbee/models/__init__.py -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import torchvision 7 | from torch import nn 8 | from torchvision.models._utils import IntermediateLayerGetter 9 | from typing import Dict, List 10 | 11 | from .utils import NestedTensor, is_main_process 12 | 13 | from .position_encoding import build_position_encoding 14 | 15 | 16 | class FrozenBatchNorm2d(torch.nn.Module): 17 | """ 18 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 19 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 20 | without which any other models than torchvision.models.resnet[18,34,50,101] 21 | produce nans. 22 | """ 23 | 24 | def __init__(self, n): 25 | super(FrozenBatchNorm2d, self).__init__() 26 | self.register_buffer("weight", torch.ones(n)) 27 | self.register_buffer("bias", torch.zeros(n)) 28 | self.register_buffer("running_mean", torch.zeros(n)) 29 | self.register_buffer("running_var", torch.ones(n)) 30 | 31 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 32 | missing_keys, unexpected_keys, error_msgs): 33 | num_batches_tracked_key = prefix + 'num_batches_tracked' 34 | if num_batches_tracked_key in state_dict: 35 | del state_dict[num_batches_tracked_key] 36 | 37 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 38 | state_dict, prefix, local_metadata, strict, 39 | missing_keys, unexpected_keys, error_msgs) 40 | 41 | def forward(self, x): 42 | # move reshapes to the beginning 43 | # to make it fuser-friendly 44 | w = self.weight.reshape(1, -1, 1, 1) 45 | b = self.bias.reshape(1, -1, 1, 1) 46 | rv = self.running_var.reshape(1, -1, 1, 1) 47 | rm = self.running_mean.reshape(1, -1, 1, 1) 48 | eps = 1e-5 49 | scale = w * (rv + eps).rsqrt() 50 | bias = b - rm * scale 51 | return x * scale + bias 52 | 53 | 54 | class BackboneBase(nn.Module): 55 | 56 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 57 | super().__init__() 58 | for name, parameter in backbone.named_parameters(): 59 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 60 | parameter.requires_grad_(False) 61 | if return_interm_layers: 62 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 63 | else: 64 | return_layers = {'layer4': "0"} 65 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 66 | self.num_channels = num_channels 67 | 68 | def forward(self, tensor_list: NestedTensor): 69 | xs = self.body(tensor_list.tensors) 70 | out: Dict[str, NestedTensor] = {} 71 | for name, x in xs.items(): 72 | m = tensor_list.mask 73 | assert m is not None 74 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 75 | out[name] = NestedTensor(x, mask) 76 | return out 77 | 78 | 79 | class Backbone(BackboneBase): 80 | """ResNet backbone with frozen BatchNorm.""" 81 | def __init__(self, name: str, 82 | train_backbone: bool, 83 | return_interm_layers: bool, 84 | dilation: bool): 85 | backbone = getattr(torchvision.models, name)( 86 | replace_stride_with_dilation=[False, False, dilation], 87 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) 88 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 89 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 90 | 91 | 92 | class Joiner(nn.Sequential): 93 | def __init__(self, backbone, position_embedding): 94 | super().__init__(backbone, 
position_embedding) 95 | 96 | def forward(self, tensor_list: NestedTensor): 97 | xs = self[0](tensor_list) 98 | out: List[NestedTensor] = [] 99 | pos = [] 100 | for name, x in xs.items(): 101 | out.append(x) 102 | # position encoding 103 | pos.append(self[1](x).to(x.tensors.dtype)) 104 | 105 | return out, pos 106 | 107 | 108 | def build_backbone(config): 109 | position_embedding = build_position_encoding(config) 110 | train_backbone = config.lr_backbone > 0 111 | return_interm_layers = False 112 | backbone = Backbone(config.backbone, train_backbone, return_interm_layers, config.dilation) 113 | model = Joiner(backbone, position_embedding) 114 | model.num_channels = backbone.num_channels 115 | return model -------------------------------------------------------------------------------- /models/caption.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from .utils import NestedTensor, nested_tensor_from_tensor_list 6 | from .backbone import build_backbone 7 | from .transformer import build_transformer 8 | 9 | 10 | class Caption(nn.Module): 11 | def __init__(self, backbone, transformer, hidden_dim, vocab_size): 12 | super().__init__() 13 | self.backbone = backbone 14 | self.input_proj = nn.Conv2d( 15 | backbone.num_channels, hidden_dim, kernel_size=1) 16 | self.transformer = transformer 17 | self.mlp = MLP(hidden_dim, 512, vocab_size, 3) 18 | 19 | def forward(self, samples, target, target_mask): 20 | if not isinstance(samples, NestedTensor): 21 | samples = nested_tensor_from_tensor_list(samples) 22 | 23 | features, pos = self.backbone(samples) 24 | src, mask = features[-1].decompose() 25 | 26 | assert mask is not None 27 | 28 | hs = self.transformer(self.input_proj(src), mask, 29 | pos[-1], target, target_mask) 30 | out = self.mlp(hs.permute(1, 0, 2)) 31 | return out 32 | 33 | 34 | class MLP(nn.Module): 35 | """ Very simple multi-layer perceptron (also called FFN)""" 36 | 37 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 38 | super().__init__() 39 | self.num_layers = num_layers 40 | h = [hidden_dim] * (num_layers - 1) 41 | self.layers = nn.ModuleList(nn.Linear(n, k) 42 | for n, k in zip([input_dim] + h, h + [output_dim])) 43 | 44 | def forward(self, x): 45 | for i, layer in enumerate(self.layers): 46 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 47 | return x 48 | 49 | 50 | def build_model(config): 51 | backbone = build_backbone(config) 52 | transformer = build_transformer(config) 53 | 54 | model = Caption(backbone, transformer, config.hidden_dim, config.vocab_size) 55 | criterion = torch.nn.CrossEntropyLoss() 56 | 57 | return model, criterion -------------------------------------------------------------------------------- /models/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import math 3 | import torch 4 | from torch import nn 5 | 6 | from .utils import NestedTensor 7 | 8 | 9 | class PositionEmbeddingSine(nn.Module): 10 | """ 11 | This is a more standard version of the position embedding, very similar to the one 12 | used by the Attention is all you need paper, generalized to work on images. 
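Concretely, the per-pixel y/x coordinates are cumulative counts over the unmasked region (optionally normalized to [0, 2*pi]), and each coordinate p is expanded into interleaved sin(p / temperature^(2i/d)) and cos(p / temperature^(2i/d)) features (with d = num_pos_feats); the y and x encodings are then concatenated along the channel dimension.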
13 | """ 14 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 15 | super().__init__() 16 | self.num_pos_feats = num_pos_feats 17 | self.temperature = temperature 18 | self.normalize = normalize 19 | if scale is not None and normalize is False: 20 | raise ValueError("normalize should be True if scale is passed") 21 | if scale is None: 22 | scale = 2 * math.pi 23 | self.scale = scale 24 | 25 | def forward(self, tensor_list: NestedTensor): 26 | x = tensor_list.tensors 27 | mask = tensor_list.mask 28 | assert mask is not None 29 | not_mask = ~mask 30 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 31 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 32 | if self.normalize: 33 | eps = 1e-6 34 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 35 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 36 | 37 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 38 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 39 | 40 | pos_x = x_embed[:, :, :, None] / dim_t 41 | pos_y = y_embed[:, :, :, None] / dim_t 42 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 43 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 44 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 45 | return pos 46 | 47 | 48 | class PositionEmbeddingLearned(nn.Module): 49 | """ 50 | Absolute pos embedding, learned. 51 | """ 52 | def __init__(self, num_pos_feats=256): 53 | super().__init__() 54 | self.row_embed = nn.Embedding(50, num_pos_feats) 55 | self.col_embed = nn.Embedding(50, num_pos_feats) 56 | self.reset_parameters() 57 | 58 | def reset_parameters(self): 59 | nn.init.uniform_(self.row_embed.weight) 60 | nn.init.uniform_(self.col_embed.weight) 61 | 62 | def forward(self, tensor_list: NestedTensor): 63 | x = tensor_list.tensors 64 | h, w = x.shape[-2:] 65 | i = torch.arange(w, device=x.device) 66 | j = torch.arange(h, device=x.device) 67 | x_emb = self.col_embed(i) 68 | y_emb = self.row_embed(j) 69 | pos = torch.cat([ 70 | x_emb.unsqueeze(0).repeat(h, 1, 1), 71 | y_emb.unsqueeze(1).repeat(1, w, 1), 72 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 73 | return pos 74 | 75 | 76 | def build_position_encoding(config): 77 | N_steps = config.hidden_dim // 2 78 | if config.position_embedding in ('v2', 'sine'): 79 | # TODO find a better way of exposing other arguments 80 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 81 | elif config.position_embedding in ('v3', 'learned'): 82 | position_embedding = PositionEmbeddingLearned(N_steps) 83 | else: 84 | raise ValueError(f"not supported {config.position_embedding}") 85 | 86 | return position_embedding -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import copy 3 | from typing import Optional, List 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn, Tensor 8 | 9 | 10 | class Transformer(nn.Module): 11 | 12 | def __init__(self, config, d_model=512, nhead=8, num_encoder_layers=6, 13 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 14 | activation="relu", normalize_before=False, 15 | return_intermediate_dec=False): 16 | super().__init__() 17 | 18 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 19 | dropout, activation, normalize_before) 20 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 21 | self.encoder = TransformerEncoder( 22 | encoder_layer, num_encoder_layers, encoder_norm) 23 | 24 | self.embeddings = DecoderEmbeddings(config) 25 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 26 | dropout, activation, normalize_before) 27 | decoder_norm = nn.LayerNorm(d_model) 28 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 29 | return_intermediate=return_intermediate_dec) 30 | 31 | self._reset_parameters() 32 | 33 | self.d_model = d_model 34 | self.nhead = nhead 35 | 36 | def _reset_parameters(self): 37 | for p in self.parameters(): 38 | if p.dim() > 1: 39 | nn.init.xavier_uniform_(p) 40 | 41 | def forward(self, src, mask, pos_embed, tgt, tgt_mask): 42 | # flatten NxCxHxW to HWxNxC 43 | bs, c, h, w = src.shape 44 | src = src.flatten(2).permute(2, 0, 1) 45 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 46 | mask = mask.flatten(1) 47 | 48 | tgt = self.embeddings(tgt).permute(1, 0, 2) 49 | query_embed = self.embeddings.position_embeddings.weight.unsqueeze(1) 50 | query_embed = query_embed.repeat(1, bs, 1) 51 | 52 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 53 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, tgt_key_padding_mask=tgt_mask, 54 | pos=pos_embed, query_pos=query_embed, 55 | tgt_mask=generate_square_subsequent_mask(len(tgt)).to(tgt.device)) 56 | 57 | return hs 58 | 59 | 60 | class TransformerEncoder(nn.Module): 61 | 62 | def __init__(self, encoder_layer, num_layers, norm=None): 63 | super().__init__() 64 | self.layers = _get_clones(encoder_layer, num_layers) 65 | self.num_layers = num_layers 66 | self.norm = norm 67 | 68 | def forward(self, src, 69 | mask: Optional[Tensor] = None, 70 | src_key_padding_mask: Optional[Tensor] = None, 71 | pos: Optional[Tensor] = None): 72 | output = src 73 | 74 | for layer in self.layers: 75 | output = layer(output, src_mask=mask, 76 | src_key_padding_mask=src_key_padding_mask, pos=pos) 77 | 78 | if self.norm is not None: 79 | output = self.norm(output) 80 | 81 | return output 82 | 83 | 84 | class TransformerDecoder(nn.Module): 85 | 86 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 87 | super().__init__() 88 | self.layers = _get_clones(decoder_layer, num_layers) 89 | self.num_layers = num_layers 90 | self.norm = norm 91 | self.return_intermediate = return_intermediate 92 | 93 | def forward(self, tgt, memory, 94 | tgt_mask: Optional[Tensor] = None, 95 | memory_mask: Optional[Tensor] = None, 96 | tgt_key_padding_mask: Optional[Tensor] = None, 97 | memory_key_padding_mask: Optional[Tensor] = None, 98 | pos: Optional[Tensor] = None, 99 | query_pos: Optional[Tensor] = None): 100 | output = tgt 101 | 102 | intermediate = [] 103 | 104 | for layer in self.layers: 105 | output = layer(output, memory, tgt_mask=tgt_mask, 106 | memory_mask=memory_mask, 107 | 
tgt_key_padding_mask=tgt_key_padding_mask, 108 | memory_key_padding_mask=memory_key_padding_mask, 109 | pos=pos, query_pos=query_pos) 110 | if self.return_intermediate: 111 | intermediate.append(self.norm(output)) 112 | 113 | if self.norm is not None: 114 | output = self.norm(output) 115 | if self.return_intermediate: 116 | intermediate.pop() 117 | intermediate.append(output) 118 | 119 | if self.return_intermediate: 120 | return torch.stack(intermediate) 121 | 122 | return output 123 | 124 | 125 | class TransformerEncoderLayer(nn.Module): 126 | 127 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 128 | activation="relu", normalize_before=False): 129 | super().__init__() 130 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 131 | # Implementation of Feedforward model 132 | self.linear1 = nn.Linear(d_model, dim_feedforward) 133 | self.dropout = nn.Dropout(dropout) 134 | self.linear2 = nn.Linear(dim_feedforward, d_model) 135 | 136 | self.norm1 = nn.LayerNorm(d_model) 137 | self.norm2 = nn.LayerNorm(d_model) 138 | self.dropout1 = nn.Dropout(dropout) 139 | self.dropout2 = nn.Dropout(dropout) 140 | 141 | self.activation = _get_activation_fn(activation) 142 | self.normalize_before = normalize_before 143 | 144 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 145 | return tensor if pos is None else tensor + pos 146 | 147 | def forward_post(self, 148 | src, 149 | src_mask: Optional[Tensor] = None, 150 | src_key_padding_mask: Optional[Tensor] = None, 151 | pos: Optional[Tensor] = None): 152 | q = k = self.with_pos_embed(src, pos) 153 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 154 | key_padding_mask=src_key_padding_mask)[0] 155 | src = src + self.dropout1(src2) 156 | src = self.norm1(src) 157 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 158 | src = src + self.dropout2(src2) 159 | src = self.norm2(src) 160 | return src 161 | 162 | def forward_pre(self, src, 163 | src_mask: Optional[Tensor] = None, 164 | src_key_padding_mask: Optional[Tensor] = None, 165 | pos: Optional[Tensor] = None): 166 | src2 = self.norm1(src) 167 | q = k = self.with_pos_embed(src2, pos) 168 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 169 | key_padding_mask=src_key_padding_mask)[0] 170 | src = src + self.dropout1(src2) 171 | src2 = self.norm2(src) 172 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 173 | src = src + self.dropout2(src2) 174 | return src 175 | 176 | def forward(self, src, 177 | src_mask: Optional[Tensor] = None, 178 | src_key_padding_mask: Optional[Tensor] = None, 179 | pos: Optional[Tensor] = None): 180 | if self.normalize_before: 181 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 182 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 183 | 184 | 185 | class TransformerDecoderLayer(nn.Module): 186 | 187 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 188 | activation="relu", normalize_before=False): 189 | super().__init__() 190 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 191 | self.multihead_attn = nn.MultiheadAttention( 192 | d_model, nhead, dropout=dropout) 193 | # Implementation of Feedforward model 194 | self.linear1 = nn.Linear(d_model, dim_feedforward) 195 | self.dropout = nn.Dropout(dropout) 196 | self.linear2 = nn.Linear(dim_feedforward, d_model) 197 | 198 | self.norm1 = nn.LayerNorm(d_model) 199 | self.norm2 = nn.LayerNorm(d_model) 200 | self.norm3 = nn.LayerNorm(d_model) 
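# norm1-3 above and dropout1-3 below pair up with the decoder's three sub-blocks:
# self-attention over the target tokens, cross-attention over the encoder memory,
# and the position-wise feed-forward network (see forward_post / forward_pre below).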
201 | self.dropout1 = nn.Dropout(dropout) 202 | self.dropout2 = nn.Dropout(dropout) 203 | self.dropout3 = nn.Dropout(dropout) 204 | 205 | self.activation = _get_activation_fn(activation) 206 | self.normalize_before = normalize_before 207 | 208 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 209 | return tensor if pos is None else tensor + pos 210 | 211 | def forward_post(self, tgt, memory, 212 | tgt_mask: Optional[Tensor] = None, 213 | memory_mask: Optional[Tensor] = None, 214 | tgt_key_padding_mask: Optional[Tensor] = None, 215 | memory_key_padding_mask: Optional[Tensor] = None, 216 | pos: Optional[Tensor] = None, 217 | query_pos: Optional[Tensor] = None): 218 | q = k = self.with_pos_embed(tgt, query_pos) 219 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 220 | key_padding_mask=tgt_key_padding_mask)[0] 221 | tgt = tgt + self.dropout1(tgt2) 222 | tgt = self.norm1(tgt) 223 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 224 | key=self.with_pos_embed(memory, pos), 225 | value=memory, attn_mask=memory_mask, 226 | key_padding_mask=memory_key_padding_mask)[0] 227 | tgt = tgt + self.dropout2(tgt2) 228 | tgt = self.norm2(tgt) 229 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 230 | tgt = tgt + self.dropout3(tgt2) 231 | tgt = self.norm3(tgt) 232 | return tgt 233 | 234 | def forward_pre(self, tgt, memory, 235 | tgt_mask: Optional[Tensor] = None, 236 | memory_mask: Optional[Tensor] = None, 237 | tgt_key_padding_mask: Optional[Tensor] = None, 238 | memory_key_padding_mask: Optional[Tensor] = None, 239 | pos: Optional[Tensor] = None, 240 | query_pos: Optional[Tensor] = None): 241 | tgt2 = self.norm1(tgt) 242 | q = k = self.with_pos_embed(tgt2, query_pos) 243 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 244 | key_padding_mask=tgt_key_padding_mask)[0] 245 | tgt = tgt + self.dropout1(tgt2) 246 | tgt2 = self.norm2(tgt) 247 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 248 | key=self.with_pos_embed(memory, pos), 249 | value=memory, attn_mask=memory_mask, 250 | key_padding_mask=memory_key_padding_mask)[0] 251 | tgt = tgt + self.dropout2(tgt2) 252 | tgt2 = self.norm3(tgt) 253 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 254 | tgt = tgt + self.dropout3(tgt2) 255 | return tgt 256 | 257 | def forward(self, tgt, memory, 258 | tgt_mask: Optional[Tensor] = None, 259 | memory_mask: Optional[Tensor] = None, 260 | tgt_key_padding_mask: Optional[Tensor] = None, 261 | memory_key_padding_mask: Optional[Tensor] = None, 262 | pos: Optional[Tensor] = None, 263 | query_pos: Optional[Tensor] = None): 264 | if self.normalize_before: 265 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 266 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 267 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 268 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 269 | 270 | 271 | class DecoderEmbeddings(nn.Module): 272 | def __init__(self, config): 273 | super().__init__() 274 | self.word_embeddings = nn.Embedding( 275 | config.vocab_size, config.hidden_dim, padding_idx=config.pad_token_id) 276 | self.position_embeddings = nn.Embedding( 277 | config.max_position_embeddings, config.hidden_dim 278 | ) 279 | 280 | self.LayerNorm = torch.nn.LayerNorm( 281 | config.hidden_dim, eps=config.layer_norm_eps) 282 | self.dropout = nn.Dropout(config.dropout) 283 | 284 | def forward(self, x): 285 | input_shape = x.size() 286 | seq_length = input_shape[1] 287 | 
device = x.device 288 | 289 | position_ids = torch.arange( 290 | seq_length, dtype=torch.long, device=device) 291 | position_ids = position_ids.unsqueeze(0).expand(input_shape) 292 | 293 | input_embeds = self.word_embeddings(x) 294 | position_embeds = self.position_embeddings(position_ids) 295 | 296 | embeddings = input_embeds + position_embeds 297 | embeddings = self.LayerNorm(embeddings) 298 | embeddings = self.dropout(embeddings) 299 | 300 | return embeddings 301 | 302 | 303 | def _get_clones(module, N): 304 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 305 | 306 | 307 | def _get_activation_fn(activation): 308 | """Return an activation function given a string""" 309 | if activation == "relu": 310 | return F.relu 311 | if activation == "gelu": 312 | return F.gelu 313 | if activation == "glu": 314 | return F.glu 315 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 316 | 317 | 318 | def generate_square_subsequent_mask(sz): 319 | r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). 320 | Unmasked positions are filled with float(0.0). 321 | """ 322 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 323 | mask = mask.float().masked_fill(mask == 0, float( 324 | '-inf')).masked_fill(mask == 1, float(0.0)) 325 | return mask 326 | 327 | 328 | def build_transformer(config): 329 | return Transformer( 330 | config, 331 | d_model=config.hidden_dim, 332 | dropout=config.dropout, 333 | nhead=config.nheads, 334 | dim_feedforward=config.dim_feedforward, 335 | num_encoder_layers=config.enc_layers, 336 | num_decoder_layers=config.dec_layers, 337 | normalize_before=config.pre_norm, 338 | return_intermediate_dec=False, 339 | ) 340 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from typing import List, Optional 3 | 4 | import torch 5 | import torch.distributed as dist 6 | from torch import Tensor 7 | 8 | 9 | def _max_by_axis(the_list): 10 | # type: (List[List[int]]) -> List[int] 11 | maxes = the_list[0] 12 | for sublist in the_list[1:]: 13 | for index, item in enumerate(sublist): 14 | maxes[index] = max(maxes[index], item) 15 | return maxes 16 | 17 | 18 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 19 | # TODO make this more general 20 | if tensor_list[0].ndim == 3: 21 | # TODO make it support different-sized images 22 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 23 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 24 | batch_shape = [len(tensor_list)] + max_size 25 | b, c, h, w = batch_shape 26 | dtype = tensor_list[0].dtype 27 | device = tensor_list[0].device 28 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 29 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 30 | for img, pad_img, m in zip(tensor_list, tensor, mask): 31 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 32 | m[: img.shape[1], :img.shape[2]] = False 33 | else: 34 | raise ValueError('not supported') 35 | return NestedTensor(tensor, mask) 36 | 37 | 38 | class NestedTensor(object): 39 | def __init__(self, tensors, mask: Optional[Tensor]): 40 | self.tensors = tensors 41 | self.mask = mask 42 | 43 | def to(self, device): 44 | # type: (Device) -> NestedTensor # noqa 45 | cast_tensor = self.tensors.to(device) 46 | mask = self.mask 47 | if mask is not None: 48 | assert mask is not None 49 | cast_mask = mask.to(device) 50 | else: 51 | cast_mask = None 52 | return NestedTensor(cast_tensor, cast_mask) 53 | 54 | def decompose(self): 55 | return self.tensors, self.mask 56 | 57 | def __repr__(self): 58 | return str(self.tensors) 59 | 60 | 61 | def is_dist_avail_and_initialized(): 62 | if not dist.is_available(): 63 | return False 64 | if not dist.is_initialized(): 65 | return False 66 | return True 67 | 68 | 69 | def get_rank(): 70 | if not is_dist_avail_and_initialized(): 71 | return 0 72 | return dist.get_rank() 73 | 74 | 75 | def is_main_process(): 76 | return get_rank() == 0 77 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers import BertTokenizer 4 | from PIL import Image 5 | import argparse 6 | 7 | from models import caption 8 | from datasets import coco, utils 9 | from configuration import Config 10 | import os 11 | 12 | parser = argparse.ArgumentParser(description='Image Captioning') 13 | parser.add_argument('--path', type=str, help='path to image', required=True) 14 | parser.add_argument('--v', type=str, help='version', default='v3') 15 | parser.add_argument('--checkpoint', type=str, help='checkpoint path', default=None) 16 | args = parser.parse_args() 17 | image_path = args.path 18 | version = args.v 19 | checkpoint_path = args.checkpoint 20 | 21 | config = Config() 22 | 23 | if version == 'v1': 24 | model = torch.hub.load('saahiluppal/catr', 'v1', pretrained=True) 25 | elif version == 'v2': 26 | model = torch.hub.load('saahiluppal/catr', 'v2', pretrained=True) 27 | elif version == 'v3': 28 | model = torch.hub.load('saahiluppal/catr', 'v3', pretrained=True) 29 | else: 30 | print("Checking for checkpoint.") 31 | if checkpoint_path is None: 32 | raise NotImplementedError('No model to chose 
from!') 33 | else: 34 | if not os.path.exists(checkpoint_path): 35 | raise NotImplementedError('Give valid checkpoint path') 36 | print("Found checkpoint! Loading!") 37 | model,_ = caption.build_model(config) 38 | print("Loading Checkpoint...") 39 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 40 | model.load_state_dict(checkpoint['model']) 41 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 42 | 43 | start_token = tokenizer.convert_tokens_to_ids(tokenizer._cls_token) 44 | end_token = tokenizer.convert_tokens_to_ids(tokenizer._sep_token) 45 | 46 | image = Image.open(image_path) 47 | image = coco.val_transform(image) 48 | image = image.unsqueeze(0) 49 | 50 | 51 | def create_caption_and_mask(start_token, max_length): 52 | caption_template = torch.zeros((1, max_length), dtype=torch.long) 53 | mask_template = torch.ones((1, max_length), dtype=torch.bool) 54 | 55 | caption_template[:, 0] = start_token 56 | mask_template[:, 0] = False 57 | 58 | return caption_template, mask_template 59 | 60 | 61 | caption, cap_mask = create_caption_and_mask( 62 | start_token, config.max_position_embeddings) 63 | 64 | 65 | @torch.no_grad() 66 | def evaluate(): 67 | model.eval() 68 | for i in range(config.max_position_embeddings - 1): 69 | predictions = model(image, caption, cap_mask) 70 | predictions = predictions[:, i, :] 71 | predicted_id = torch.argmax(predictions, axis=-1) 72 | 73 | if predicted_id[0] == 102: 74 | return caption 75 | 76 | caption[:, i+1] = predicted_id[0] 77 | cap_mask[:, i+1] = False 78 | 79 | return caption 80 | 81 | 82 | output = evaluate() 83 | result = tokenizer.decode(output[0].tolist(), skip_special_tokens=True) 84 | #result = tokenizer.decode(output[0], skip_special_tokens=True) 85 | print(result.capitalize()) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | numpy 4 | transformers 5 | tqdm --------------------------------------------------------------------------------
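For quick reference, below is a minimal end-to-end usage sketch assembled from hubconf.py and predict.py above. It is not a file in the repository: the image path 'image.jpg' and the choice of the v3 weights are placeholders, it assumes network access for torch.hub and the pretrained checkpoint, and it should be run from the repository root so that configuration and datasets import.

import torch
from PIL import Image
from transformers import BertTokenizer

from configuration import Config
from datasets import coco

config = Config()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Load the pretrained captioning model (same hub entry point predict.py uses).
model = torch.hub.load('saahiluppal/catr', 'v3', pretrained=True)
model.eval()

# Preprocess a single image with the COCO validation transform.
image = coco.val_transform(Image.open('image.jpg')).unsqueeze(0)

# Greedy decoding, mirroring predict.py: start from [CLS], stop at [SEP] (id 102).
caption = torch.zeros((1, config.max_position_embeddings), dtype=torch.long)
cap_mask = torch.ones((1, config.max_position_embeddings), dtype=torch.bool)
caption[:, 0] = tokenizer.convert_tokens_to_ids(tokenizer._cls_token)
cap_mask[:, 0] = False

with torch.no_grad():
    for i in range(config.max_position_embeddings - 1):
        predicted_id = model(image, caption, cap_mask)[:, i, :].argmax(-1)
        if predicted_id[0] == 102:  # [SEP] ends the caption
            break
        caption[:, i + 1] = predicted_id[0]
        cap_mask[:, i + 1] = False

print(tokenizer.decode(caption[0].tolist(), skip_special_tokens=True).capitalize())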