├── .github
│   ├── airplane.png
│   ├── cake.png
│   ├── cruise.png
│   ├── girl.png
│   ├── horse.png
│   └── office.png
├── .gitignore
├── LICENSE
├── README.md
├── catr_demo.ipynb
├── configuration.py
├── datasets
│   ├── __init__.py
│   ├── coco.py
│   └── utils.py
├── engine.py
├── finetune.py
├── hubconf.py
├── main.py
├── models
│   ├── __init__.py
│   ├── backbone.py
│   ├── caption.py
│   ├── position_encoding.py
│   ├── transformer.py
│   └── utils.py
├── predict.py
└── requirements.txt
/.github/airplane.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/saahiluppal/catr/fac82f9b4004b1dd39ccf89760b758ad19a2dbee/.github/airplane.png
--------------------------------------------------------------------------------
/.github/cake.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/saahiluppal/catr/fac82f9b4004b1dd39ccf89760b758ad19a2dbee/.github/cake.png
--------------------------------------------------------------------------------
/.github/cruise.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/saahiluppal/catr/fac82f9b4004b1dd39ccf89760b758ad19a2dbee/.github/cruise.png
--------------------------------------------------------------------------------
/.github/girl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/saahiluppal/catr/fac82f9b4004b1dd39ccf89760b758ad19a2dbee/.github/girl.png
--------------------------------------------------------------------------------
/.github/horse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/saahiluppal/catr/fac82f9b4004b1dd39ccf89760b758ad19a2dbee/.github/horse.png
--------------------------------------------------------------------------------
/.github/office.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/saahiluppal/catr/fac82f9b4004b1dd39ccf89760b758ad19a2dbee/.github/office.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.toptal.com/developers/gitignore/api/python
3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python
4 |
5 | ### Python ###
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 |
122 | # Rope project settings
123 | .ropeproject
124 |
125 | # mkdocs documentation
126 | /site
127 |
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 |
133 | # Pyre type checker
134 | .pyre/
135 |
136 | # pytype static type analyzer
137 | .pytype/
138 |
139 | # End of https://www.toptal.com/developers/gitignore/api/python
140 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | **CA⫶TR**: Image Captioning with Transformers
2 | ========
3 | PyTorch training code and pretrained models for **CATR** (**CA**ption **TR**ansformer).
4 |
5 | The models are also available via torch hub;
6 | to load a model with pretrained weights, simply do:
7 | ```python
8 | model = torch.hub.load('saahiluppal/catr', 'v3', pretrained=True) # you can choose between v1, v2 and v3
9 | ```
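The hub model takes the same inputs that `engine.py` feeds it during training: an image wrapped in a `NestedTensor`, the partially decoded caption, and a boolean padding mask. The greedy-decoding sketch below is illustrative only; the image path is a placeholder, and the loop is an assumption built on that call signature rather than a documented API:
```python
import torch
from PIL import Image
from transformers import BertTokenizer

from datasets.coco import val_transform                  # validation-time preprocessing
from datasets.utils import nested_tensor_from_tensor_list
from models import utils

model = torch.hub.load('saahiluppal/catr', 'v3', pretrained=True).eval()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Pad the image to 299x299 and build its pixel mask, as the data loader does.
image = val_transform(Image.open('path/to/image.jpg'))   # placeholder path
nt = nested_tensor_from_tensor_list(image.unsqueeze(0))
sample = utils.NestedTensor(nt.tensors, nt.mask)          # same wrapper engine.py uses

max_len = 128                                             # config.max_position_embeddings
caption = torch.zeros((1, max_len), dtype=torch.long)     # 0 == config.pad_token_id
cap_mask = torch.ones((1, max_len), dtype=torch.bool)     # True marks padding positions
caption[0, 0] = tokenizer.cls_token_id                    # [CLS] as the start token
cap_mask[0, 0] = False

with torch.no_grad():
    for i in range(max_len - 1):
        logits = model(sample, caption, cap_mask)         # (batch, seq_len, vocab)
        next_id = logits[0, i].argmax().item()
        if next_id == tokenizer.sep_token_id:              # stop at [SEP]
            break
        caption[0, i + 1] = next_id
        cap_mask[0, i + 1] = False

print(tokenizer.decode(caption[0], skip_special_tokens=True))
```
For a ready-made command-line route, use `predict.py` as shown below.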
10 | ### Samples:
11 |
12 | *(Sample images: airplane, cake, cruise, girl, horse and office — see the `.github` folder.)*
20 | All of these images have been annotated by CATR.
21 |
22 | Test it with your own images:
23 | ````bash
24 | $ python predict.py --path /path/to/image --v v2  # choose between v1, v2 and v3 [default is v3]
25 | ````
26 | Or try it out in the Colab [notebook](catr_demo.ipynb).
27 |
28 | # Usage
29 | There are no extra compiled components in CATR and its package dependencies are minimal,
30 | so the code is very simple to use. We provide instructions on how to install the dependencies.
31 | First, clone the repository locally:
32 | ```
33 | $ git clone https://github.com/saahiluppal/catr.git
34 | ```
35 | Then, install PyTorch 1.5+ and torchvision 0.6+ along with the remaining dependencies:
36 | ```
37 | $ pip install -r requirements.txt
38 | ```
39 | That's it; you should now be able to train and test the captioning models.
40 |
41 | ## Data preparation
42 |
43 | Download and extract COCO 2017 train and val images with annotations from
44 | [http://cocodataset.org](http://cocodataset.org/#download).
45 | We expect the directory structure to be the following:
46 | ```
47 | path/to/coco/
48 | annotations/ # annotation json files
49 | train2017/ # train images
50 | val2017/ # val images
51 | ```
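If you prefer to script the download, here is a sketch using torchvision's archive helper; the URLs are the standard COCO 2017 download links, and `path/to/coco` is a placeholder that should match `config.dir` in `configuration.py`:
```python
from torchvision.datasets.utils import download_and_extract_archive

root = 'path/to/coco'  # placeholder; should match config.dir
urls = [
    'http://images.cocodataset.org/zips/train2017.zip',
    'http://images.cocodataset.org/zips/val2017.zip',
    'http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
]
for url in urls:
    # Downloads each archive into `root` and extracts it in place, producing
    # the annotations/, train2017/ and val2017/ folders shown above.
    download_and_extract_archive(url, download_root=root)
```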
52 |
53 | ## Training
54 | Tweak the hyperparameters in `configuration.py` (see the snippet at the end of this section).
55 |
56 | To train the baseline CATR on a single GPU for 30 epochs, run:
57 | ```
58 | $ python main.py
59 | ```
60 | We train CATR with AdamW, setting the learning rate to 1e-4 in the transformer and 1e-5 in the backbone.
61 | Horizontal flips, scales and crops are used for augmentation.
62 | Images are rescaled so that their longest side is at most 299 pixels.
63 | The transformer is trained with a dropout of 0.1, and the whole model is trained with gradient clipping at 0.1.
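
All of these defaults live in `configuration.py`. As a sketch, you can also tweak a `Config` instance from your own script before handing it to `main` (the attribute names come from `configuration.py`; the values here are only examples):
```python
from configuration import Config
from main import main

config = Config()
config.batch_size = 16   # smaller batches for smaller GPUs
config.epochs = 10       # shorter run than the 30-epoch default
config.dir = '../coco'   # path prepared in the Data preparation step

main(config)
```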
64 |
65 | ## Testing
66 | To test CATR with your own images:
67 | ```
68 | $ python predict.py --path /path/to/image --v v2  # choose between v1, v2 and v3 [default is v3]
69 | ```
70 |
71 | # License
72 | CATR is released under the Apache 2.0 license. Please see the [LICENSE](LICENSE) file for more information.
73 |
--------------------------------------------------------------------------------
/catr_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "catr-demo.ipynb",
7 | "provenance": [],
8 | "mount_file_id": "1LVtj77AoMmhJmTDCsXmcQjqxUFb7YEJ2",
9 | "authorship_tag": "ABX9TyNQVSIJZ6Hoj2qvhWWnmSyV",
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | },
16 | "accelerator": "GPU"
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 |         ""
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "metadata": {
32 | "id": "N8xyiUNoWmL3",
33 | "colab_type": "code",
34 | "colab": {
35 | "base_uri": "https://localhost:8080/",
36 | "height": 121
37 | },
38 | "outputId": "6c54de9e-339e-47a6-ff49-50dfe361d627"
39 | },
40 | "source": [
41 | "!git clone https://github.com/saahiluppal/catr.git"
42 | ],
43 | "execution_count": 2,
44 | "outputs": [
45 | {
46 | "output_type": "stream",
47 | "text": [
48 | "Cloning into 'catr'...\n",
49 | "remote: Enumerating objects: 79, done.\u001b[K\n",
50 | "remote: Counting objects: 1% (1/79)\u001b[K\rremote: Counting objects: 2% (2/79)\u001b[K\rremote: Counting objects: 3% (3/79)\u001b[K\rremote: Counting objects: 5% (4/79)\u001b[K\rremote: Counting objects: 6% (5/79)\u001b[K\rremote: Counting objects: 7% (6/79)\u001b[K\rremote: Counting objects: 8% (7/79)\u001b[K\rremote: Counting objects: 10% (8/79)\u001b[K\rremote: Counting objects: 11% (9/79)\u001b[K\rremote: Counting objects: 12% (10/79)\u001b[K\rremote: Counting objects: 13% (11/79)\u001b[K\rremote: Counting objects: 15% (12/79)\u001b[K\rremote: Counting objects: 16% (13/79)\u001b[K\rremote: Counting objects: 17% (14/79)\u001b[K\rremote: Counting objects: 18% (15/79)\u001b[K\rremote: Counting objects: 20% (16/79)\u001b[K\rremote: Counting objects: 21% (17/79)\u001b[K\rremote: Counting objects: 22% (18/79)\u001b[K\rremote: Counting objects: 24% (19/79)\u001b[K\rremote: Counting objects: 25% (20/79)\u001b[K\rremote: Counting objects: 26% (21/79)\u001b[K\rremote: Counting objects: 27% (22/79)\u001b[K\rremote: Counting objects: 29% (23/79)\u001b[K\rremote: Counting objects: 30% (24/79)\u001b[K\rremote: Counting objects: 31% (25/79)\u001b[K\rremote: Counting objects: 32% (26/79)\u001b[K\rremote: Counting objects: 34% (27/79)\u001b[K\rremote: Counting objects: 35% (28/79)\u001b[K\rremote: Counting objects: 36% (29/79)\u001b[K\rremote: Counting objects: 37% (30/79)\u001b[K\rremote: Counting objects: 39% (31/79)\u001b[K\rremote: Counting objects: 40% (32/79)\u001b[K\rremote: Counting objects: 41% (33/79)\u001b[K\rremote: Counting objects: 43% (34/79)\u001b[K\rremote: Counting objects: 44% (35/79)\u001b[K\rremote: Counting objects: 45% (36/79)\u001b[K\rremote: Counting objects: 46% (37/79)\u001b[K\rremote: Counting objects: 48% (38/79)\u001b[K\rremote: Counting objects: 49% (39/79)\u001b[K\rremote: Counting objects: 50% (40/79)\u001b[K\rremote: Counting objects: 51% (41/79)\u001b[K\rremote: Counting objects: 53% (42/79)\u001b[K\rremote: Counting objects: 54% (43/79)\u001b[K\rremote: Counting objects: 55% (44/79)\u001b[K\rremote: Counting objects: 56% (45/79)\u001b[K\rremote: Counting objects: 58% (46/79)\u001b[K\rremote: Counting objects: 59% (47/79)\u001b[K\rremote: Counting objects: 60% (48/79)\u001b[K\rremote: Counting objects: 62% (49/79)\u001b[K\rremote: Counting objects: 63% (50/79)\u001b[K\rremote: Counting objects: 64% (51/79)\u001b[K\rremote: Counting objects: 65% (52/79)\u001b[K\rremote: Counting objects: 67% (53/79)\u001b[K\rremote: Counting objects: 68% (54/79)\u001b[K\rremote: Counting objects: 69% (55/79)\u001b[K\rremote: Counting objects: 70% (56/79)\u001b[K\rremote: Counting objects: 72% (57/79)\u001b[K\rremote: Counting objects: 73% (58/79)\u001b[K\rremote: Counting objects: 74% (59/79)\u001b[K\rremote: Counting objects: 75% (60/79)\u001b[K\rremote: Counting objects: 77% (61/79)\u001b[K\rremote: Counting objects: 78% (62/79)\u001b[K\rremote: Counting objects: 79% (63/79)\u001b[K\rremote: Counting objects: 81% (64/79)\u001b[K\rremote: Counting objects: 82% (65/79)\u001b[K\rremote: Counting objects: 83% (66/79)\u001b[K\rremote: Counting objects: 84% (67/79)\u001b[K\rremote: Counting objects: 86% (68/79)\u001b[K\rremote: Counting objects: 87% (69/79)\u001b[K\rremote: Counting objects: 88% (70/79)\u001b[K\rremote: Counting objects: 89% (71/79)\u001b[K\rremote: Counting objects: 91% (72/79)\u001b[K\rremote: Counting objects: 92% (73/79)\u001b[K\rremote: Counting objects: 93% (74/79)\u001b[K\rremote: Counting objects: 94% (75/79)\u001b[K\rremote: Counting objects: 96% 
(76/79)\u001b[K\rremote: Counting objects: 97% (77/79)\u001b[K\rremote: Counting objects: 98% (78/79)\u001b[K\rremote: Counting objects: 100% (79/79)\u001b[K\rremote: Counting objects: 100% (79/79), done.\u001b[K\n",
51 | "remote: Compressing objects: 100% (58/58), done.\u001b[K\n",
52 | "remote: Total 79 (delta 29), reused 57 (delta 18), pack-reused 0\u001b[K\n",
53 | "Unpacking objects: 100% (79/79), done.\n"
54 | ],
55 | "name": "stdout"
56 | }
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "metadata": {
62 | "id": "JMzOLu6Tjhqz",
63 | "colab_type": "code",
64 | "colab": {
65 | "base_uri": "https://localhost:8080/",
66 | "height": 34
67 | },
68 | "outputId": "f2df8dc5-3bce-42c2-f03e-81bf10a0a8ab"
69 | },
70 | "source": [
71 | "%cd catr/"
72 | ],
73 | "execution_count": 4,
74 | "outputs": [
75 | {
76 | "output_type": "stream",
77 | "text": [
78 | "/content/catr\n"
79 | ],
80 | "name": "stdout"
81 | }
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "metadata": {
87 | "id": "ojT0Ln7GjayV",
88 | "colab_type": "code",
89 | "colab": {
90 | "base_uri": "https://localhost:8080/",
91 | "height": 697
92 | },
93 | "outputId": "618844dd-0459-4d3d-b71b-d5ed11ec8bad"
94 | },
95 | "source": [
96 | "!pip install -r requirements.txt"
97 | ],
98 | "execution_count": 5,
99 | "outputs": [
100 | {
101 | "output_type": "stream",
102 | "text": [
103 | "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 1)) (1.5.1+cu101)\n",
104 | "Requirement already satisfied: torchvision in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 2)) (0.6.1+cu101)\n",
105 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 3)) (1.18.5)\n",
106 | "Collecting transformers\n",
107 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/27/3c/91ed8f5c4e7ef3227b4119200fc0ed4b4fd965b1f0172021c25701087825/transformers-3.0.2-py3-none-any.whl (769kB)\n",
108 | "\r\u001b[K |▍ | 10kB 24.9MB/s eta 0:00:01\r\u001b[K |▉ | 20kB 3.3MB/s eta 0:00:01\r\u001b[K |█▎ | 30kB 4.4MB/s eta 0:00:01\r\u001b[K |█▊ | 40kB 4.6MB/s eta 0:00:01\r\u001b[K |██▏ | 51kB 3.8MB/s eta 0:00:01\r\u001b[K |██▋ | 61kB 4.3MB/s eta 0:00:01\r\u001b[K |███ | 71kB 4.7MB/s eta 0:00:01\r\u001b[K |███▍ | 81kB 4.9MB/s eta 0:00:01\r\u001b[K |███▉ | 92kB 5.3MB/s eta 0:00:01\r\u001b[K |████▎ | 102kB 5.3MB/s eta 0:00:01\r\u001b[K |████▊ | 112kB 5.3MB/s eta 0:00:01\r\u001b[K |█████▏ | 122kB 5.3MB/s eta 0:00:01\r\u001b[K |█████▌ | 133kB 5.3MB/s eta 0:00:01\r\u001b[K |██████ | 143kB 5.3MB/s eta 0:00:01\r\u001b[K |██████▍ | 153kB 5.3MB/s eta 0:00:01\r\u001b[K |██████▉ | 163kB 5.3MB/s eta 0:00:01\r\u001b[K |███████▎ | 174kB 5.3MB/s eta 0:00:01\r\u001b[K |███████▊ | 184kB 5.3MB/s eta 0:00:01\r\u001b[K |████████ | 194kB 5.3MB/s eta 0:00:01\r\u001b[K |████████▌ | 204kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████ | 215kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████▍ | 225kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████▉ | 235kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████▎ | 245kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████▋ | 256kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████ | 266kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████▌ | 276kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████ | 286kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████▍ | 296kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████▉ | 307kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████▏ | 317kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████▋ | 327kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████ | 337kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████▌ | 348kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████ | 358kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████▍ | 368kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████▊ | 378kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████▏ | 389kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████▋ | 399kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████████ | 409kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████████▌ | 419kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████ | 430kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████▎ | 440kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████▊ | 450kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████████▏ | 460kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████████▋ | 471kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████████ | 481kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████████▌ | 491kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████████▉ | 501kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████████████▎ | 512kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████████████▊ | 522kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████████▏ | 532kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████████▋ | 542kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████████████ | 552kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████████████▍ | 563kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████████████▉ | 573kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████████████▎ | 583kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████████████▊ | 593kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████████████████▏ | 604kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████████████████▋ | 614kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████████████ | 624kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████████████▍ | 634kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████████████▉ | 645kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████████████████▎ | 655kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████████████████▊ | 665kB 5.3MB/s eta 0:00:01\r\u001b[K 
|████████████████████████████▏ | 675kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████████████████▌ | 686kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████████████████████ | 696kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▍ | 706kB 5.3MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▉ | 716kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▎ | 727kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▊ | 737kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████████████████████ | 747kB 5.3MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▌| 757kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 768kB 5.3MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 778kB 5.3MB/s \n",
109 | "\u001b[?25hRequirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 5)) (4.41.1)\n",
110 | "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->-r requirements.txt (line 1)) (0.16.0)\n",
111 | "Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision->-r requirements.txt (line 2)) (7.0.0)\n",
112 | "Collecting tokenizers==0.8.1.rc1\n",
113 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/40/d0/30d5f8d221a0ed981a186c8eb986ce1c94e3a6e87f994eae9f4aa5250217/tokenizers-0.8.1rc1-cp36-cp36m-manylinux1_x86_64.whl (3.0MB)\n",
114 | "\u001b[K |████████████████████████████████| 3.0MB 24.6MB/s \n",
115 | "\u001b[?25hRequirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from transformers->-r requirements.txt (line 4)) (0.7)\n",
116 | "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers->-r requirements.txt (line 4)) (2.23.0)\n",
117 | "Collecting sacremoses\n",
118 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)\n",
119 | "\u001b[K |████████████████████████████████| 890kB 31.3MB/s \n",
120 | "\u001b[?25hCollecting sentencepiece!=0.1.92\n",
121 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)\n",
122 | "\u001b[K |████████████████████████████████| 1.1MB 49.5MB/s \n",
123 | "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers->-r requirements.txt (line 4)) (20.4)\n",
124 | "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers->-r requirements.txt (line 4)) (2019.12.20)\n",
125 | "Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers->-r requirements.txt (line 4)) (3.0.12)\n",
126 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers->-r requirements.txt (line 4)) (2.10)\n",
127 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers->-r requirements.txt (line 4)) (1.24.3)\n",
128 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers->-r requirements.txt (line 4)) (2020.6.20)\n",
129 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers->-r requirements.txt (line 4)) (3.0.4)\n",
130 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers->-r requirements.txt (line 4)) (1.15.0)\n",
131 | "Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers->-r requirements.txt (line 4)) (7.1.2)\n",
132 | "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers->-r requirements.txt (line 4)) (0.16.0)\n",
133 | "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->transformers->-r requirements.txt (line 4)) (2.4.7)\n",
134 | "Building wheels for collected packages: sacremoses\n",
135 | " Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
136 | " Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893260 sha256=54b093ea3f7db7b5b98464934131d04f8580c06b274673e0cdf946240bc124f0\n",
137 | " Stored in directory: /root/.cache/pip/wheels/29/3c/fd/7ce5c3f0666dab31a50123635e6fb5e19ceb42ce38d4e58f45\n",
138 | "Successfully built sacremoses\n",
139 | "Installing collected packages: tokenizers, sacremoses, sentencepiece, transformers\n",
140 | "Successfully installed sacremoses-0.0.43 sentencepiece-0.1.91 tokenizers-0.8.1rc1 transformers-3.0.2\n"
141 | ],
142 | "name": "stdout"
143 | }
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "metadata": {
149 | "id": "KesOqe40jdak",
150 | "colab_type": "code",
151 | "colab": {
152 | "base_uri": "https://localhost:8080/",
153 | "height": 176
154 | },
155 | "outputId": "54e430ea-f037-4541-d0d4-ac71ee457ee3"
156 | },
157 | "source": [
158 | "!python predict.py --path .github/cake.png"
159 | ],
160 | "execution_count": 6,
161 | "outputs": [
162 | {
163 | "output_type": "stream",
164 | "text": [
165 | "2020-07-30 11:50:09.260991: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1\n",
166 | "Downloading: \"https://github.com/saahiluppal/catr/archive/master.zip\" to /root/.cache/torch/hub/master.zip\n",
167 | "Downloading: \"https://download.pytorch.org/models/resnet101-5d3b4d8f.pth\" to /root/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth\n",
168 | "100% 170M/170M [00:00<00:00, 227MB/s]\n",
169 | "Downloading: \"https://github.com/saahiluppal/catr/releases/download/0.1/weights_9348032.pth\" to /root/.cache/torch/hub/checkpoints/weights_9348032.pth\n",
170 | "100% 322M/322M [00:07<00:00, 45.0MB/s]\n",
171 | "Downloading: 100% 232k/232k [00:00<00:00, 916kB/s]\n",
172 | "A person cutting a cake with a knife.\n"
173 | ],
174 | "name": "stdout"
175 | }
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "metadata": {
181 | "id": "8VAr2dFajs69",
182 | "colab_type": "code",
183 | "colab": {
184 | "base_uri": "https://localhost:8080/",
185 | "height": 89
186 | },
187 | "outputId": "7c985740-9d46-43a9-bcaf-822a6952b268"
188 | },
189 | "source": [
190 | "!python predict.py --path .github/girl.png"
191 | ],
192 | "execution_count": 7,
193 | "outputs": [
194 | {
195 | "output_type": "stream",
196 | "text": [
197 | "2020-07-30 11:50:48.364873: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1\n",
198 | "Using cache found in /root/.cache/torch/hub/saahiluppal_catr_master\n",
199 | "A woman sitting on a curb holding a pink umbrella.\n"
200 | ],
201 | "name": "stdout"
202 | }
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "metadata": {
208 | "id": "_evy1rElj3JK",
209 | "colab_type": "code",
210 | "colab": {
211 | "base_uri": "https://localhost:8080/",
212 | "height": 89
213 | },
214 | "outputId": "37eeaf43-3d7c-41db-bb59-287bd7632da9"
215 | },
216 | "source": [
217 | "!python predict.py --path .github/office.png"
218 | ],
219 | "execution_count": 8,
220 | "outputs": [
221 | {
222 | "output_type": "stream",
223 | "text": [
224 | "2020-07-30 11:51:07.946459: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1\n",
225 | "Using cache found in /root/.cache/torch/hub/saahiluppal_catr_master\n",
226 | "A group of people sitting at a table with laptops.\n"
227 | ],
228 | "name": "stdout"
229 | }
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "metadata": {
235 | "id": "5uqoo0NCj6h-",
236 | "colab_type": "code",
237 | "colab": {
238 | "base_uri": "https://localhost:8080/",
239 | "height": 89
240 | },
241 | "outputId": "89df63ee-77e6-4acc-e9c3-9e6e05a9f517"
242 | },
243 | "source": [
244 | "!python predict.py --path .github/horse.png"
245 | ],
246 | "execution_count": 9,
247 | "outputs": [
248 | {
249 | "output_type": "stream",
250 | "text": [
251 | "2020-07-30 11:51:26.572680: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1\n",
252 | "Using cache found in /root/.cache/torch/hub/saahiluppal_catr_master\n",
253 | "A man riding on the back of a brown horse.\n"
254 | ],
255 | "name": "stdout"
256 | }
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "metadata": {
262 | "id": "CCnUN_H4j7x9",
263 | "colab_type": "code",
264 | "colab": {
265 | "base_uri": "https://localhost:8080/",
266 | "height": 89
267 | },
268 | "outputId": "c095a09c-eebd-41ce-b7a2-df9af8100dc0"
269 | },
270 | "source": [
271 | "!python predict.py --path .github/airplane.png"
272 | ],
273 | "execution_count": 10,
274 | "outputs": [
275 | {
276 | "output_type": "stream",
277 | "text": [
278 | "2020-07-30 11:51:44.833782: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1\n",
279 | "Using cache found in /root/.cache/torch/hub/saahiluppal_catr_master\n",
280 | "A large jetliner sitting on top of an airport runway.\n"
281 | ],
282 | "name": "stdout"
283 | }
284 | ]
285 | },
286 | {
287 | "cell_type": "code",
288 | "metadata": {
289 | "id": "x8YuZRMokBqT",
290 | "colab_type": "code",
291 | "colab": {}
292 | },
293 | "source": [
294 | ""
295 | ],
296 | "execution_count": null,
297 | "outputs": []
298 | }
299 | ]
300 | }
--------------------------------------------------------------------------------
/configuration.py:
--------------------------------------------------------------------------------
1 | class Config(object):
2 | def __init__(self):
3 |
4 | # Learning Rates
5 | self.lr_backbone = 1e-5
6 | self.lr = 1e-4
7 |
8 | # Epochs
9 | self.epochs = 30
10 | self.lr_drop = 20
11 | self.start_epoch = 0
12 | self.weight_decay = 1e-4
13 |
14 | # Backbone
15 | self.backbone = 'resnet101'
16 | self.position_embedding = 'sine'
17 | self.dilation = True
18 |
19 | # Basic
20 | self.device = 'cuda'
21 | self.seed = 42
22 | self.batch_size = 32
23 | self.num_workers = 8
24 | self.checkpoint = './checkpoint.pth'
25 | self.clip_max_norm = 0.1
26 |
27 | # Transformer
28 | self.hidden_dim = 256
29 | self.pad_token_id = 0
30 | self.max_position_embeddings = 128
31 | self.layer_norm_eps = 1e-12
32 | self.dropout = 0.1
33 | self.vocab_size = 30522
34 |
35 | self.enc_layers = 6
36 | self.dec_layers = 6
37 | self.dim_feedforward = 2048
38 | self.nheads = 8
39 | self.pre_norm = True
40 |
41 | # Dataset
42 | self.dir = '../coco'
43 | self.limit = -1
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/saahiluppal/catr/fac82f9b4004b1dd39ccf89760b758ad19a2dbee/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/coco.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import torchvision.transforms.functional as TF
3 | import torchvision as tv
4 |
5 | from PIL import Image
6 | import numpy as np
7 | import random
8 | import os
9 |
10 | from transformers import BertTokenizer
11 |
12 | from .utils import nested_tensor_from_tensor_list, read_json
13 |
14 | MAX_DIM = 299
15 |
16 |
17 | def under_max(image):
18 | if image.mode != 'RGB':
19 | image = image.convert("RGB")
20 |
21 |     shape = np.array(image.size, dtype=float)  # np.float was removed in NumPy >= 1.20
22 | long_dim = max(shape)
23 | scale = MAX_DIM / long_dim
24 |
25 | new_shape = (shape * scale).astype(int)
26 | image = image.resize(new_shape)
27 |
28 | return image
29 |
30 |
31 | class RandomRotation:
32 | def __init__(self, angles=[0, 90, 180, 270]):
33 | self.angles = angles
34 |
35 | def __call__(self, x):
36 | angle = random.choice(self.angles)
37 | return TF.rotate(x, angle, expand=True)
38 |
39 |
40 | train_transform = tv.transforms.Compose([
41 | RandomRotation(),
42 | tv.transforms.Lambda(under_max),
43 | tv.transforms.ColorJitter(brightness=[0.5, 1.3], contrast=[
44 | 0.8, 1.5], saturation=[0.2, 1.5]),
45 | tv.transforms.RandomHorizontalFlip(),
46 | tv.transforms.ToTensor(),
47 | tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
48 | ])
49 |
50 | val_transform = tv.transforms.Compose([
51 | tv.transforms.Lambda(under_max),
52 | tv.transforms.ToTensor(),
53 | tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
54 | ])
55 |
56 |
57 | class CocoCaption(Dataset):
58 | def __init__(self, root, ann, max_length, limit, transform=train_transform, mode='training'):
59 | super().__init__()
60 |
61 | self.root = root
62 | self.transform = transform
63 | self.annot = [(self._process(val['image_id']), val['caption'])
64 | for val in ann['annotations']]
65 | if mode == 'validation':
66 | self.annot = self.annot
67 | if mode == 'training':
68 | self.annot = self.annot[: limit]
69 |
70 | self.tokenizer = BertTokenizer.from_pretrained(
71 |             'bert-base-uncased', do_lower_case=True)
72 | self.max_length = max_length + 1
73 |
74 | def _process(self, image_id):
75 | val = str(image_id).zfill(12)
76 | return val + '.jpg'
77 |
78 | def __len__(self):
79 | return len(self.annot)
80 |
81 | def __getitem__(self, idx):
82 | image_id, caption = self.annot[idx]
83 | image = Image.open(os.path.join(self.root, image_id))
84 |
85 | if self.transform:
86 | image = self.transform(image)
87 | image = nested_tensor_from_tensor_list(image.unsqueeze(0))
88 |
89 | caption_encoded = self.tokenizer.encode_plus(
90 | caption, max_length=self.max_length, pad_to_max_length=True, return_attention_mask=True, return_token_type_ids=False, truncation=True)
91 |
92 | caption = np.array(caption_encoded['input_ids'])
93 | cap_mask = (
94 | 1 - np.array(caption_encoded['attention_mask'])).astype(bool)
95 |
96 | return image.tensors.squeeze(0), image.mask.squeeze(0), caption, cap_mask
97 |
98 |
99 | def build_dataset(config, mode='training'):
100 | if mode == 'training':
101 | train_dir = os.path.join(config.dir, 'train2017')
102 | train_file = os.path.join(
103 | config.dir, 'annotations', 'captions_train2017.json')
104 | data = CocoCaption(train_dir, read_json(
105 | train_file), max_length=config.max_position_embeddings, limit=config.limit, transform=train_transform, mode='training')
106 | return data
107 |
108 | elif mode == 'validation':
109 | val_dir = os.path.join(config.dir, 'val2017')
110 | val_file = os.path.join(
111 | config.dir, 'annotations', 'captions_val2017.json')
112 | data = CocoCaption(val_dir, read_json(
113 | val_file), max_length=config.max_position_embeddings, limit=config.limit, transform=val_transform, mode='validation')
114 | return data
115 |
116 | else:
117 | raise NotImplementedError(f"{mode} not supported")
118 |
--------------------------------------------------------------------------------
/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional, List
3 | from torch import Tensor
4 |
5 | import json
6 | import os
7 |
8 | MAX_DIM = 299
9 |
10 | def read_json(file_name):
11 | with open(file_name) as handle:
12 | out = json.load(handle)
13 | return out
14 |
15 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
16 | # TODO make this more general
17 | if tensor_list[0].ndim == 3:
18 | # TODO make it support different-sized images
19 | max_size = [3, MAX_DIM, MAX_DIM]
20 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
21 | batch_shape = [len(tensor_list)] + max_size
22 | b, c, h, w = batch_shape
23 | dtype = tensor_list[0].dtype
24 | device = tensor_list[0].device
25 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
26 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
27 | for img, pad_img, m in zip(tensor_list, tensor, mask):
28 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
29 | m[: img.shape[1], :img.shape[2]] = False
30 | else:
31 | raise ValueError('not supported')
32 | return NestedTensor(tensor, mask)
33 |
34 |
35 | class NestedTensor(object):
36 | def __init__(self, tensors, mask: Optional[Tensor]):
37 | self.tensors = tensors
38 | self.mask = mask
39 |
40 | def to(self, device):
41 | # type: (Device) -> NestedTensor # noqa
42 | cast_tensor = self.tensors.to(device)
43 | mask = self.mask
44 | if mask is not None:
45 | assert mask is not None
46 | cast_mask = mask.to(device)
47 | else:
48 | cast_mask = None
49 | return NestedTensor(cast_tensor, cast_mask)
50 |
51 | def decompose(self):
52 | return self.tensors, self.mask
53 |
54 | def __repr__(self):
55 | return str(self.tensors)
56 |
--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import torch
3 |
4 | import math
5 | import sys
6 | import tqdm
7 |
8 | from models import utils
9 |
10 |
11 | def train_one_epoch(model, criterion, data_loader,
12 | optimizer, device, epoch, max_norm):
13 | model.train()
14 | criterion.train()
15 |
16 | epoch_loss = 0.0
17 | total = len(data_loader)
18 |
19 | with tqdm.tqdm(total=total) as pbar:
20 | for images, masks, caps, cap_masks in data_loader:
21 | samples = utils.NestedTensor(images, masks).to(device)
22 | caps = caps.to(device)
23 | cap_masks = cap_masks.to(device)
24 |
25 | outputs = model(samples, caps[:, :-1], cap_masks[:, :-1])
26 | loss = criterion(outputs.permute(0, 2, 1), caps[:, 1:])
27 | loss_value = loss.item()
28 | epoch_loss += loss_value
29 |
30 | if not math.isfinite(loss_value):
31 | print(f'Loss is {loss_value}, stopping training')
32 | sys.exit(1)
33 |
34 | optimizer.zero_grad()
35 | loss.backward()
36 | if max_norm > 0:
37 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
38 | optimizer.step()
39 |
40 | pbar.update(1)
41 |
42 | return epoch_loss / total
43 |
44 | @torch.no_grad()
45 | def evaluate(model, criterion, data_loader, device):
46 | model.eval()
47 | criterion.eval()
48 |
49 | validation_loss = 0.0
50 | total = len(data_loader)
51 |
52 | with tqdm.tqdm(total=total) as pbar:
53 | for images, masks, caps, cap_masks in data_loader:
54 | samples = utils.NestedTensor(images, masks).to(device)
55 | caps = caps.to(device)
56 | cap_masks = cap_masks.to(device)
57 |
58 | outputs = model(samples, caps[:, :-1], cap_masks[:, :-1])
59 | loss = criterion(outputs.permute(0, 2, 1), caps[:, 1:])
60 |
61 | validation_loss += loss.item()
62 |
63 | pbar.update(1)
64 |
65 | return validation_loss / total
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 |
4 | import numpy as np
5 | import time
6 | import sys
7 | import os
8 |
9 | from models import utils, caption
10 | from datasets import coco
11 | from configuration import Config
12 | from engine import train_one_epoch, evaluate
13 |
14 |
15 | def finetune(config):
16 | device = torch.device(config.device)
17 | print(f'Initializing Device: {device}')
18 |
19 | seed = config.seed + utils.get_rank()
20 | torch.manual_seed(seed)
21 | np.random.seed(seed)
22 |
23 | model, criterion = caption.build_model(config)
24 | checkpoint = torch.hub.load_state_dict_from_url(
25 | url="https://github.com/saahiluppal/catr/releases/download/0.2/weight493084032.pth",
26 | map_location=device
27 | )
28 | model.to(device)
29 | model.load_state_dict(checkpoint['model'])
30 |
31 | config.lr = 1e-5
32 | config.epochs = 10
33 | config.lr_drop = 8
34 |
35 | n_parameters = sum(p.numel()
36 | for p in model.parameters() if p.requires_grad)
37 | print(f"Number of params: {n_parameters}")
38 |
39 | param_dicts = [
40 | {"params": [p for n, p in model.named_parameters(
41 | ) if "backbone" not in n and p.requires_grad]},
42 | {
43 | "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
44 | "lr": config.lr_backbone,
45 | },
46 | ]
47 | optimizer = torch.optim.AdamW(
48 | param_dicts, lr=config.lr, weight_decay=config.weight_decay)
49 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, config.lr_drop)
50 |
51 | dataset_train = coco.build_dataset(config, mode='training')
52 | dataset_val = coco.build_dataset(config, mode='validation')
53 | print(f"Train: {len(dataset_train)}")
54 | print(f"Valid: {len(dataset_val)}")
55 |
56 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
57 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
58 |
59 | batch_sampler_train = torch.utils.data.BatchSampler(
60 | sampler_train, config.batch_size, drop_last=True
61 | )
62 |
63 | data_loader_train = DataLoader(
64 | dataset_train, batch_sampler=batch_sampler_train, num_workers=config.num_workers)
65 | data_loader_val = DataLoader(dataset_val, config.batch_size,
66 | sampler=sampler_val, drop_last=False, num_workers=config.num_workers)
67 |
68 | if os.path.exists(config.checkpoint):
69 | print("Loading Checkpoint...")
70 | checkpoint = torch.load(config.checkpoint, map_location='cpu')
71 | model.load_state_dict(checkpoint['model'])
72 | optimizer.load_state_dict(checkpoint['optimizer'])
73 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
74 | config.start_epoch = checkpoint['epoch'] + 1
75 |
76 | print("Start Training..")
77 | for epoch in range(config.start_epoch, config.epochs):
78 | print(f"Epoch: {epoch}")
79 | epoch_loss = train_one_epoch(
80 | model, criterion, data_loader_train, optimizer, device, epoch, config.clip_max_norm)
81 | lr_scheduler.step()
82 | print(f"Training Loss: {epoch_loss}")
83 |
84 | torch.save({
85 | 'model': model.state_dict(),
86 | 'optimizer': optimizer.state_dict(),
87 | 'lr_scheduler': lr_scheduler.state_dict(),
88 | 'epoch': epoch,
89 | }, config.checkpoint)
90 |
91 | validation_loss = evaluate(model, criterion, data_loader_val, device)
92 | print(f"Validation Loss: {validation_loss}")
93 |
94 | print()
95 |
96 |
97 | if __name__ == "__main__":
98 | config = Config()
99 | finetune(config)
100 |
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from models import caption
4 | from configuration import Config
5 |
6 | dependencies = ['torch', 'torchvision']
7 |
8 | def v1(pretrained=False):
9 | config = Config()
10 | model, _ = caption.build_model(config)
11 |
12 | if pretrained:
13 | checkpoint = torch.hub.load_state_dict_from_url(
14 | url='https://github.com/saahiluppal/catr/releases/download/0.1/weights_9348032.pth',
15 | map_location='cpu'
16 | )
17 | model.load_state_dict(checkpoint['model'])
18 |
19 | return model
20 |
21 | def v2(pretrained=False):
22 | config = Config()
23 | model, _ = caption.build_model(config)
24 |
25 | if pretrained:
26 | checkpoint = torch.hub.load_state_dict_from_url(
27 | url='https://github.com/saahiluppal/catr/releases/download/0.2/weight389123791.pth',
28 | map_location='cpu'
29 | )
30 | model.load_state_dict(checkpoint['model'])
31 |
32 | return model
33 |
34 | def v3(pretrained=False):
35 | config = Config()
36 | model, _ = caption.build_model(config)
37 |
38 | if pretrained:
39 | checkpoint = torch.hub.load_state_dict_from_url(
40 | url='https://github.com/saahiluppal/catr/releases/download/0.2/weight493084032.pth',
41 | map_location='cpu'
42 | )
43 | model.load_state_dict(checkpoint['model'])
44 |
45 | return model
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 |
4 | import numpy as np
5 | import time
6 | import sys
7 | import os
8 |
9 | from models import utils, caption
10 | from datasets import coco
11 | from configuration import Config
12 | from engine import train_one_epoch, evaluate
13 |
14 |
15 | def main(config):
16 | device = torch.device(config.device)
17 | print(f'Initializing Device: {device}')
18 |
19 | seed = config.seed + utils.get_rank()
20 | torch.manual_seed(seed)
21 | np.random.seed(seed)
22 |
23 | model, criterion = caption.build_model(config)
24 | model.to(device)
25 |
26 | n_parameters = sum(p.numel()
27 | for p in model.parameters() if p.requires_grad)
28 | print(f"Number of params: {n_parameters}")
29 |
30 | param_dicts = [
31 | {"params": [p for n, p in model.named_parameters(
32 | ) if "backbone" not in n and p.requires_grad]},
33 | {
34 | "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
35 | "lr": config.lr_backbone,
36 | },
37 | ]
38 | optimizer = torch.optim.AdamW(
39 | param_dicts, lr=config.lr, weight_decay=config.weight_decay)
40 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, config.lr_drop)
41 |
42 | dataset_train = coco.build_dataset(config, mode='training')
43 | dataset_val = coco.build_dataset(config, mode='validation')
44 | print(f"Train: {len(dataset_train)}")
45 | print(f"Valid: {len(dataset_val)}")
46 |
47 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
48 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
49 |
50 | batch_sampler_train = torch.utils.data.BatchSampler(
51 | sampler_train, config.batch_size, drop_last=True
52 | )
53 |
54 | data_loader_train = DataLoader(
55 | dataset_train, batch_sampler=batch_sampler_train, num_workers=config.num_workers)
56 | data_loader_val = DataLoader(dataset_val, config.batch_size,
57 | sampler=sampler_val, drop_last=False, num_workers=config.num_workers)
58 |
59 | if os.path.exists(config.checkpoint):
60 | print("Loading Checkpoint...")
61 | checkpoint = torch.load(config.checkpoint, map_location='cpu')
62 | model.load_state_dict(checkpoint['model'])
63 | optimizer.load_state_dict(checkpoint['optimizer'])
64 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
65 | config.start_epoch = checkpoint['epoch'] + 1
66 |
67 | print("Start Training..")
68 | for epoch in range(config.start_epoch, config.epochs):
69 | print(f"Epoch: {epoch}")
70 | epoch_loss = train_one_epoch(
71 | model, criterion, data_loader_train, optimizer, device, epoch, config.clip_max_norm)
72 | lr_scheduler.step()
73 | print(f"Training Loss: {epoch_loss}")
74 |
75 | torch.save({
76 | 'model': model.state_dict(),
77 | 'optimizer': optimizer.state_dict(),
78 | 'lr_scheduler': lr_scheduler.state_dict(),
79 | 'epoch': epoch,
80 | }, config.checkpoint)
81 |
82 | validation_loss = evaluate(model, criterion, data_loader_val, device)
83 | print(f"Validation Loss: {validation_loss}")
84 |
85 | print()
86 |
87 |
88 | if __name__ == "__main__":
89 | config = Config()
90 | main(config)
91 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/saahiluppal/catr/fac82f9b4004b1dd39ccf89760b758ad19a2dbee/models/__init__.py
--------------------------------------------------------------------------------
/models/backbone.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | import torchvision
7 | from torch import nn
8 | from torchvision.models._utils import IntermediateLayerGetter
9 | from typing import Dict, List
10 |
11 | from .utils import NestedTensor, is_main_process
12 |
13 | from .position_encoding import build_position_encoding
14 |
15 |
16 | class FrozenBatchNorm2d(torch.nn.Module):
17 | """
18 | BatchNorm2d where the batch statistics and the affine parameters are fixed.
19 |     Copy-paste from torchvision.misc.ops with added eps before rsqrt,
20 | without which any other models than torchvision.models.resnet[18,34,50,101]
21 | produce nans.
22 | """
23 |
24 | def __init__(self, n):
25 | super(FrozenBatchNorm2d, self).__init__()
26 | self.register_buffer("weight", torch.ones(n))
27 | self.register_buffer("bias", torch.zeros(n))
28 | self.register_buffer("running_mean", torch.zeros(n))
29 | self.register_buffer("running_var", torch.ones(n))
30 |
31 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
32 | missing_keys, unexpected_keys, error_msgs):
33 | num_batches_tracked_key = prefix + 'num_batches_tracked'
34 | if num_batches_tracked_key in state_dict:
35 | del state_dict[num_batches_tracked_key]
36 |
37 | super(FrozenBatchNorm2d, self)._load_from_state_dict(
38 | state_dict, prefix, local_metadata, strict,
39 | missing_keys, unexpected_keys, error_msgs)
40 |
41 | def forward(self, x):
42 | # move reshapes to the beginning
43 | # to make it fuser-friendly
44 | w = self.weight.reshape(1, -1, 1, 1)
45 | b = self.bias.reshape(1, -1, 1, 1)
46 | rv = self.running_var.reshape(1, -1, 1, 1)
47 | rm = self.running_mean.reshape(1, -1, 1, 1)
48 | eps = 1e-5
49 | scale = w * (rv + eps).rsqrt()
50 | bias = b - rm * scale
51 | return x * scale + bias
52 |
53 |
54 | class BackboneBase(nn.Module):
55 |
56 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
57 | super().__init__()
58 | for name, parameter in backbone.named_parameters():
59 |             if not train_backbone or ('layer2' not in name and 'layer3' not in name and 'layer4' not in name):
60 | parameter.requires_grad_(False)
61 | if return_interm_layers:
62 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
63 | else:
64 | return_layers = {'layer4': "0"}
65 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
66 | self.num_channels = num_channels
67 |
68 | def forward(self, tensor_list: NestedTensor):
69 | xs = self.body(tensor_list.tensors)
70 | out: Dict[str, NestedTensor] = {}
71 | for name, x in xs.items():
72 | m = tensor_list.mask
73 | assert m is not None
74 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
75 | out[name] = NestedTensor(x, mask)
76 | return out
77 |
78 |
79 | class Backbone(BackboneBase):
80 | """ResNet backbone with frozen BatchNorm."""
81 | def __init__(self, name: str,
82 | train_backbone: bool,
83 | return_interm_layers: bool,
84 | dilation: bool):
85 | backbone = getattr(torchvision.models, name)(
86 | replace_stride_with_dilation=[False, False, dilation],
87 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
88 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
89 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
90 |
91 |
92 | class Joiner(nn.Sequential):
93 | def __init__(self, backbone, position_embedding):
94 | super().__init__(backbone, position_embedding)
95 |
96 | def forward(self, tensor_list: NestedTensor):
97 | xs = self[0](tensor_list)
98 | out: List[NestedTensor] = []
99 | pos = []
100 | for name, x in xs.items():
101 | out.append(x)
102 | # position encoding
103 | pos.append(self[1](x).to(x.tensors.dtype))
104 |
105 | return out, pos
106 |
107 |
108 | def build_backbone(config):
109 | position_embedding = build_position_encoding(config)
110 | train_backbone = config.lr_backbone > 0
111 | return_interm_layers = False
112 | backbone = Backbone(config.backbone, train_backbone, return_interm_layers, config.dilation)
113 | model = Joiner(backbone, position_embedding)
114 | model.num_channels = backbone.num_channels
115 | return model
--------------------------------------------------------------------------------
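A minimal sketch of what build_backbone returns, using a hypothetical stub in place of configuration.Config with only the fields this module reads (backbone, lr_backbone, dilation, hidden_dim, position_embedding); the exact values are illustrative:

    import torch
    from types import SimpleNamespace
    from models.backbone import build_backbone
    from models.utils import nested_tensor_from_tensor_list

    config = SimpleNamespace(backbone='resnet101', lr_backbone=1e-5, dilation=True,
                             hidden_dim=256, position_embedding='sine')
    backbone = build_backbone(config)  # downloads ImageNet weights on first run

    images = nested_tensor_from_tensor_list([torch.rand(3, 224, 224), torch.rand(3, 256, 200)])
    features, pos = backbone(images)
    src, mask = features[-1].decompose()
    # src: (2, 2048, H', W') features, mask: (2, H', W') padding mask, pos[-1]: (2, 256, H', W')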
/models/caption.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | from .utils import NestedTensor, nested_tensor_from_tensor_list
6 | from .backbone import build_backbone
7 | from .transformer import build_transformer
8 |
9 |
10 | class Caption(nn.Module):
11 | def __init__(self, backbone, transformer, hidden_dim, vocab_size):
12 | super().__init__()
13 | self.backbone = backbone
14 | self.input_proj = nn.Conv2d(
15 | backbone.num_channels, hidden_dim, kernel_size=1)
16 | self.transformer = transformer
17 | self.mlp = MLP(hidden_dim, 512, vocab_size, 3)
18 |
19 | def forward(self, samples, target, target_mask):
20 | if not isinstance(samples, NestedTensor):
21 | samples = nested_tensor_from_tensor_list(samples)
22 |
23 | features, pos = self.backbone(samples)
24 | src, mask = features[-1].decompose()
25 |
26 | assert mask is not None
27 |
28 | hs = self.transformer(self.input_proj(src), mask,
29 | pos[-1], target, target_mask)
30 | out = self.mlp(hs.permute(1, 0, 2))
31 | return out
32 |
33 |
34 | class MLP(nn.Module):
35 | """ Very simple multi-layer perceptron (also called FFN)"""
36 |
37 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
38 | super().__init__()
39 | self.num_layers = num_layers
40 | h = [hidden_dim] * (num_layers - 1)
41 | self.layers = nn.ModuleList(nn.Linear(n, k)
42 | for n, k in zip([input_dim] + h, h + [output_dim]))
43 |
44 | def forward(self, x):
45 | for i, layer in enumerate(self.layers):
46 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
47 | return x
48 |
49 |
50 | def build_model(config):
51 | backbone = build_backbone(config)
52 | transformer = build_transformer(config)
53 |
54 | model = Caption(backbone, transformer, config.hidden_dim, config.vocab_size)
55 | criterion = torch.nn.CrossEntropyLoss()
56 |
57 | return model, criterion
--------------------------------------------------------------------------------
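A minimal sketch of the Caption forward contract, mirroring predict.py: padded images, a caption token tensor, and a boolean padding mask go in; per-position vocabulary logits come out (the image size here is arbitrary):

    import torch
    from configuration import Config
    from models import caption

    config = Config()
    model, criterion = caption.build_model(config)
    model.eval()

    images = torch.rand(1, 3, 299, 299)
    cap = torch.zeros((1, config.max_position_embeddings), dtype=torch.long)
    cap_mask = torch.ones((1, config.max_position_embeddings), dtype=torch.bool)
    cap[:, 0] = 101        # [CLS] id for bert-base-uncased, as in predict.py
    cap_mask[:, 0] = False

    with torch.no_grad():
        out = model(images, cap, cap_mask)  # shape: (1, max_position_embeddings, vocab_size)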
/models/position_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import math
3 | import torch
4 | from torch import nn
5 |
6 | from .utils import NestedTensor
7 |
8 |
9 | class PositionEmbeddingSine(nn.Module):
10 | """
11 | This is a more standard version of the position embedding, very similar to the one
12 | used by the Attention is all you need paper, generalized to work on images.
13 | """
14 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
15 | super().__init__()
16 | self.num_pos_feats = num_pos_feats
17 | self.temperature = temperature
18 | self.normalize = normalize
19 | if scale is not None and normalize is False:
20 | raise ValueError("normalize should be True if scale is passed")
21 | if scale is None:
22 | scale = 2 * math.pi
23 | self.scale = scale
24 |
25 | def forward(self, tensor_list: NestedTensor):
26 | x = tensor_list.tensors
27 | mask = tensor_list.mask
28 | assert mask is not None
29 | not_mask = ~mask
30 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
31 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
32 | if self.normalize:
33 | eps = 1e-6
34 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
35 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
36 |
37 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
38 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
39 |
40 | pos_x = x_embed[:, :, :, None] / dim_t
41 | pos_y = y_embed[:, :, :, None] / dim_t
42 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
43 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
44 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
45 | return pos
46 |
47 |
48 | class PositionEmbeddingLearned(nn.Module):
49 | """
50 | Absolute pos embedding, learned.
51 | """
52 | def __init__(self, num_pos_feats=256):
53 | super().__init__()
54 | self.row_embed = nn.Embedding(50, num_pos_feats)
55 | self.col_embed = nn.Embedding(50, num_pos_feats)
56 | self.reset_parameters()
57 |
58 | def reset_parameters(self):
59 | nn.init.uniform_(self.row_embed.weight)
60 | nn.init.uniform_(self.col_embed.weight)
61 |
62 | def forward(self, tensor_list: NestedTensor):
63 | x = tensor_list.tensors
64 | h, w = x.shape[-2:]
65 | i = torch.arange(w, device=x.device)
66 | j = torch.arange(h, device=x.device)
67 | x_emb = self.col_embed(i)
68 | y_emb = self.row_embed(j)
69 | pos = torch.cat([
70 | x_emb.unsqueeze(0).repeat(h, 1, 1),
71 | y_emb.unsqueeze(1).repeat(1, w, 1),
72 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
73 | return pos
74 |
75 |
76 | def build_position_encoding(config):
77 | N_steps = config.hidden_dim // 2
78 | if config.position_embedding in ('v2', 'sine'):
79 | # TODO find a better way of exposing other arguments
80 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
81 | elif config.position_embedding in ('v3', 'learned'):
82 | position_embedding = PositionEmbeddingLearned(N_steps)
83 | else:
84 | raise ValueError(f"not supported {config.position_embedding}")
85 |
86 | return position_embedding
--------------------------------------------------------------------------------
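A minimal sketch of the sine encoding in isolation: the output channel count is 2 * num_pos_feats, so it matches hidden_dim when built with N_steps = hidden_dim // 2 (the shapes below are illustrative):

    import torch
    from models.position_encoding import PositionEmbeddingSine
    from models.utils import NestedTensor

    pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
    x = torch.rand(2, 2048, 16, 14)                  # backbone feature map
    mask = torch.zeros(2, 16, 14, dtype=torch.bool)  # no padded pixels
    pos = pe(NestedTensor(x, mask))
    print(pos.shape)                                 # torch.Size([2, 256, 16, 14])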
/models/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import copy
3 | from typing import Optional, List
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn, Tensor
8 |
9 |
10 | class Transformer(nn.Module):
11 |
12 | def __init__(self, config, d_model=512, nhead=8, num_encoder_layers=6,
13 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
14 | activation="relu", normalize_before=False,
15 | return_intermediate_dec=False):
16 | super().__init__()
17 |
18 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
19 | dropout, activation, normalize_before)
20 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
21 | self.encoder = TransformerEncoder(
22 | encoder_layer, num_encoder_layers, encoder_norm)
23 |
24 | self.embeddings = DecoderEmbeddings(config)
25 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
26 | dropout, activation, normalize_before)
27 | decoder_norm = nn.LayerNorm(d_model)
28 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
29 | return_intermediate=return_intermediate_dec)
30 |
31 | self._reset_parameters()
32 |
33 | self.d_model = d_model
34 | self.nhead = nhead
35 |
36 | def _reset_parameters(self):
37 | for p in self.parameters():
38 | if p.dim() > 1:
39 | nn.init.xavier_uniform_(p)
40 |
41 | def forward(self, src, mask, pos_embed, tgt, tgt_mask):
42 | # flatten NxCxHxW to HWxNxC
43 | bs, c, h, w = src.shape
44 | src = src.flatten(2).permute(2, 0, 1)
45 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
46 | mask = mask.flatten(1)
47 |
48 | tgt = self.embeddings(tgt).permute(1, 0, 2)
49 | query_embed = self.embeddings.position_embeddings.weight.unsqueeze(1)
50 | query_embed = query_embed.repeat(1, bs, 1)
51 |
52 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
53 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, tgt_key_padding_mask=tgt_mask,
54 | pos=pos_embed, query_pos=query_embed,
55 | tgt_mask=generate_square_subsequent_mask(len(tgt)).to(tgt.device))
56 |
57 | return hs
58 |
59 |
60 | class TransformerEncoder(nn.Module):
61 |
62 | def __init__(self, encoder_layer, num_layers, norm=None):
63 | super().__init__()
64 | self.layers = _get_clones(encoder_layer, num_layers)
65 | self.num_layers = num_layers
66 | self.norm = norm
67 |
68 | def forward(self, src,
69 | mask: Optional[Tensor] = None,
70 | src_key_padding_mask: Optional[Tensor] = None,
71 | pos: Optional[Tensor] = None):
72 | output = src
73 |
74 | for layer in self.layers:
75 | output = layer(output, src_mask=mask,
76 | src_key_padding_mask=src_key_padding_mask, pos=pos)
77 |
78 | if self.norm is not None:
79 | output = self.norm(output)
80 |
81 | return output
82 |
83 |
84 | class TransformerDecoder(nn.Module):
85 |
86 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
87 | super().__init__()
88 | self.layers = _get_clones(decoder_layer, num_layers)
89 | self.num_layers = num_layers
90 | self.norm = norm
91 | self.return_intermediate = return_intermediate
92 |
93 | def forward(self, tgt, memory,
94 | tgt_mask: Optional[Tensor] = None,
95 | memory_mask: Optional[Tensor] = None,
96 | tgt_key_padding_mask: Optional[Tensor] = None,
97 | memory_key_padding_mask: Optional[Tensor] = None,
98 | pos: Optional[Tensor] = None,
99 | query_pos: Optional[Tensor] = None):
100 | output = tgt
101 |
102 | intermediate = []
103 |
104 | for layer in self.layers:
105 | output = layer(output, memory, tgt_mask=tgt_mask,
106 | memory_mask=memory_mask,
107 | tgt_key_padding_mask=tgt_key_padding_mask,
108 | memory_key_padding_mask=memory_key_padding_mask,
109 | pos=pos, query_pos=query_pos)
110 | if self.return_intermediate:
111 | intermediate.append(self.norm(output))
112 |
113 | if self.norm is not None:
114 | output = self.norm(output)
115 | if self.return_intermediate:
116 | intermediate.pop()
117 | intermediate.append(output)
118 |
119 | if self.return_intermediate:
120 | return torch.stack(intermediate)
121 |
122 | return output
123 |
124 |
125 | class TransformerEncoderLayer(nn.Module):
126 |
127 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
128 | activation="relu", normalize_before=False):
129 | super().__init__()
130 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
131 | # Implementation of Feedforward model
132 | self.linear1 = nn.Linear(d_model, dim_feedforward)
133 | self.dropout = nn.Dropout(dropout)
134 | self.linear2 = nn.Linear(dim_feedforward, d_model)
135 |
136 | self.norm1 = nn.LayerNorm(d_model)
137 | self.norm2 = nn.LayerNorm(d_model)
138 | self.dropout1 = nn.Dropout(dropout)
139 | self.dropout2 = nn.Dropout(dropout)
140 |
141 | self.activation = _get_activation_fn(activation)
142 | self.normalize_before = normalize_before
143 |
144 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
145 | return tensor if pos is None else tensor + pos
146 |
147 | def forward_post(self,
148 | src,
149 | src_mask: Optional[Tensor] = None,
150 | src_key_padding_mask: Optional[Tensor] = None,
151 | pos: Optional[Tensor] = None):
152 | q = k = self.with_pos_embed(src, pos)
153 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
154 | key_padding_mask=src_key_padding_mask)[0]
155 | src = src + self.dropout1(src2)
156 | src = self.norm1(src)
157 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
158 | src = src + self.dropout2(src2)
159 | src = self.norm2(src)
160 | return src
161 |
162 | def forward_pre(self, src,
163 | src_mask: Optional[Tensor] = None,
164 | src_key_padding_mask: Optional[Tensor] = None,
165 | pos: Optional[Tensor] = None):
166 | src2 = self.norm1(src)
167 | q = k = self.with_pos_embed(src2, pos)
168 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
169 | key_padding_mask=src_key_padding_mask)[0]
170 | src = src + self.dropout1(src2)
171 | src2 = self.norm2(src)
172 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
173 | src = src + self.dropout2(src2)
174 | return src
175 |
176 | def forward(self, src,
177 | src_mask: Optional[Tensor] = None,
178 | src_key_padding_mask: Optional[Tensor] = None,
179 | pos: Optional[Tensor] = None):
180 | if self.normalize_before:
181 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
182 | return self.forward_post(src, src_mask, src_key_padding_mask, pos)
183 |
184 |
185 | class TransformerDecoderLayer(nn.Module):
186 |
187 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
188 | activation="relu", normalize_before=False):
189 | super().__init__()
190 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
191 | self.multihead_attn = nn.MultiheadAttention(
192 | d_model, nhead, dropout=dropout)
193 | # Implementation of Feedforward model
194 | self.linear1 = nn.Linear(d_model, dim_feedforward)
195 | self.dropout = nn.Dropout(dropout)
196 | self.linear2 = nn.Linear(dim_feedforward, d_model)
197 |
198 | self.norm1 = nn.LayerNorm(d_model)
199 | self.norm2 = nn.LayerNorm(d_model)
200 | self.norm3 = nn.LayerNorm(d_model)
201 | self.dropout1 = nn.Dropout(dropout)
202 | self.dropout2 = nn.Dropout(dropout)
203 | self.dropout3 = nn.Dropout(dropout)
204 |
205 | self.activation = _get_activation_fn(activation)
206 | self.normalize_before = normalize_before
207 |
208 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
209 | return tensor if pos is None else tensor + pos
210 |
211 | def forward_post(self, tgt, memory,
212 | tgt_mask: Optional[Tensor] = None,
213 | memory_mask: Optional[Tensor] = None,
214 | tgt_key_padding_mask: Optional[Tensor] = None,
215 | memory_key_padding_mask: Optional[Tensor] = None,
216 | pos: Optional[Tensor] = None,
217 | query_pos: Optional[Tensor] = None):
218 | q = k = self.with_pos_embed(tgt, query_pos)
219 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
220 | key_padding_mask=tgt_key_padding_mask)[0]
221 | tgt = tgt + self.dropout1(tgt2)
222 | tgt = self.norm1(tgt)
223 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
224 | key=self.with_pos_embed(memory, pos),
225 | value=memory, attn_mask=memory_mask,
226 | key_padding_mask=memory_key_padding_mask)[0]
227 | tgt = tgt + self.dropout2(tgt2)
228 | tgt = self.norm2(tgt)
229 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
230 | tgt = tgt + self.dropout3(tgt2)
231 | tgt = self.norm3(tgt)
232 | return tgt
233 |
234 | def forward_pre(self, tgt, memory,
235 | tgt_mask: Optional[Tensor] = None,
236 | memory_mask: Optional[Tensor] = None,
237 | tgt_key_padding_mask: Optional[Tensor] = None,
238 | memory_key_padding_mask: Optional[Tensor] = None,
239 | pos: Optional[Tensor] = None,
240 | query_pos: Optional[Tensor] = None):
241 | tgt2 = self.norm1(tgt)
242 | q = k = self.with_pos_embed(tgt2, query_pos)
243 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
244 | key_padding_mask=tgt_key_padding_mask)[0]
245 | tgt = tgt + self.dropout1(tgt2)
246 | tgt2 = self.norm2(tgt)
247 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
248 | key=self.with_pos_embed(memory, pos),
249 | value=memory, attn_mask=memory_mask,
250 | key_padding_mask=memory_key_padding_mask)[0]
251 | tgt = tgt + self.dropout2(tgt2)
252 | tgt2 = self.norm3(tgt)
253 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
254 | tgt = tgt + self.dropout3(tgt2)
255 | return tgt
256 |
257 | def forward(self, tgt, memory,
258 | tgt_mask: Optional[Tensor] = None,
259 | memory_mask: Optional[Tensor] = None,
260 | tgt_key_padding_mask: Optional[Tensor] = None,
261 | memory_key_padding_mask: Optional[Tensor] = None,
262 | pos: Optional[Tensor] = None,
263 | query_pos: Optional[Tensor] = None):
264 | if self.normalize_before:
265 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
266 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
267 | return self.forward_post(tgt, memory, tgt_mask, memory_mask,
268 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
269 |
270 |
271 | class DecoderEmbeddings(nn.Module):
272 | def __init__(self, config):
273 | super().__init__()
274 | self.word_embeddings = nn.Embedding(
275 | config.vocab_size, config.hidden_dim, padding_idx=config.pad_token_id)
276 | self.position_embeddings = nn.Embedding(
277 | config.max_position_embeddings, config.hidden_dim
278 | )
279 |
280 | self.LayerNorm = torch.nn.LayerNorm(
281 | config.hidden_dim, eps=config.layer_norm_eps)
282 | self.dropout = nn.Dropout(config.dropout)
283 |
284 | def forward(self, x):
285 | input_shape = x.size()
286 | seq_length = input_shape[1]
287 | device = x.device
288 |
289 | position_ids = torch.arange(
290 | seq_length, dtype=torch.long, device=device)
291 | position_ids = position_ids.unsqueeze(0).expand(input_shape)
292 |
293 | input_embeds = self.word_embeddings(x)
294 | position_embeds = self.position_embeddings(position_ids)
295 |
296 | embeddings = input_embeds + position_embeds
297 | embeddings = self.LayerNorm(embeddings)
298 | embeddings = self.dropout(embeddings)
299 |
300 | return embeddings
301 |
302 |
303 | def _get_clones(module, N):
304 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
305 |
306 |
307 | def _get_activation_fn(activation):
308 | """Return an activation function given a string"""
309 | if activation == "relu":
310 | return F.relu
311 | if activation == "gelu":
312 | return F.gelu
313 | if activation == "glu":
314 | return F.glu
315 |     raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
316 |
317 |
318 | def generate_square_subsequent_mask(sz):
319 | r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
320 | Unmasked positions are filled with float(0.0).
321 | """
322 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
323 | mask = mask.float().masked_fill(mask == 0, float(
324 | '-inf')).masked_fill(mask == 1, float(0.0))
325 | return mask
326 |
327 |
328 | def build_transformer(config):
329 | return Transformer(
330 | config,
331 | d_model=config.hidden_dim,
332 | dropout=config.dropout,
333 | nhead=config.nheads,
334 | dim_feedforward=config.dim_feedforward,
335 | num_encoder_layers=config.enc_layers,
336 | num_decoder_layers=config.dec_layers,
337 | normalize_before=config.pre_norm,
338 | return_intermediate_dec=False,
339 | )
340 |
--------------------------------------------------------------------------------
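The causal mask built by generate_square_subsequent_mask can be inspected directly; for length 4 it lets position i attend only to positions <= i:

    from models.transformer import generate_square_subsequent_mask

    print(generate_square_subsequent_mask(4))
    # tensor([[0., -inf, -inf, -inf],
    #         [0., 0., -inf, -inf],
    #         [0., 0., 0., -inf],
    #         [0., 0., 0., 0.]])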
/models/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | from typing import List, Optional
3 |
4 | import torch
5 | import torch.distributed as dist
6 | from torch import Tensor
7 |
8 |
9 | def _max_by_axis(the_list):
10 | # type: (List[List[int]]) -> List[int]
11 | maxes = the_list[0]
12 | for sublist in the_list[1:]:
13 | for index, item in enumerate(sublist):
14 | maxes[index] = max(maxes[index], item)
15 | return maxes
16 |
17 |
18 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
19 | # TODO make this more general
20 | if tensor_list[0].ndim == 3:
21 | # TODO make it support different-sized images
22 | max_size = _max_by_axis([list(img.shape) for img in tensor_list])
23 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
24 | batch_shape = [len(tensor_list)] + max_size
25 | b, c, h, w = batch_shape
26 | dtype = tensor_list[0].dtype
27 | device = tensor_list[0].device
28 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
29 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
30 | for img, pad_img, m in zip(tensor_list, tensor, mask):
31 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
32 | m[: img.shape[1], :img.shape[2]] = False
33 | else:
34 | raise ValueError('not supported')
35 | return NestedTensor(tensor, mask)
36 |
37 |
38 | class NestedTensor(object):
39 | def __init__(self, tensors, mask: Optional[Tensor]):
40 | self.tensors = tensors
41 | self.mask = mask
42 |
43 | def to(self, device):
44 | # type: (Device) -> NestedTensor # noqa
45 | cast_tensor = self.tensors.to(device)
46 | mask = self.mask
47 | if mask is not None:
48 | assert mask is not None
49 | cast_mask = mask.to(device)
50 | else:
51 | cast_mask = None
52 | return NestedTensor(cast_tensor, cast_mask)
53 |
54 | def decompose(self):
55 | return self.tensors, self.mask
56 |
57 | def __repr__(self):
58 | return str(self.tensors)
59 |
60 |
61 | def is_dist_avail_and_initialized():
62 | if not dist.is_available():
63 | return False
64 | if not dist.is_initialized():
65 | return False
66 | return True
67 |
68 |
69 | def get_rank():
70 | if not is_dist_avail_and_initialized():
71 | return 0
72 | return dist.get_rank()
73 |
74 |
75 | def is_main_process():
76 | return get_rank() == 0
77 |
--------------------------------------------------------------------------------
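A minimal sketch of how nested_tensor_from_tensor_list batches differently sized images: each image is copied into a zero tensor padded to the per-axis maximum, and the boolean mask is True over the padded region (the sizes below are illustrative):

    import torch
    from models.utils import nested_tensor_from_tensor_list

    imgs = [torch.rand(3, 200, 180), torch.rand(3, 224, 150)]
    tensors, mask = nested_tensor_from_tensor_list(imgs).decompose()
    print(tensors.shape)  # torch.Size([2, 3, 224, 180])
    print(mask.shape)     # torch.Size([2, 224, 180]), True where a pixel is padding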
/predict.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from transformers import BertTokenizer
4 | from PIL import Image
5 | import argparse
6 |
7 | from models import caption
8 | from datasets import coco, utils
9 | from configuration import Config
10 | import os
11 |
12 | parser = argparse.ArgumentParser(description='Image Captioning')
13 | parser.add_argument('--path', type=str, help='path to image', required=True)
14 | parser.add_argument('--v', type=str, help='version', default='v3')
15 | parser.add_argument('--checkpoint', type=str, help='checkpoint path', default=None)
16 | args = parser.parse_args()
17 | image_path = args.path
18 | version = args.v
19 | checkpoint_path = args.checkpoint
20 |
21 | config = Config()
22 |
23 | if version == 'v1':
24 | model = torch.hub.load('saahiluppal/catr', 'v1', pretrained=True)
25 | elif version == 'v2':
26 | model = torch.hub.load('saahiluppal/catr', 'v2', pretrained=True)
27 | elif version == 'v3':
28 | model = torch.hub.load('saahiluppal/catr', 'v3', pretrained=True)
29 | else:
30 | print("Checking for checkpoint.")
31 | if checkpoint_path is None:
32 |         raise NotImplementedError('No model to choose from!')
33 | else:
34 | if not os.path.exists(checkpoint_path):
35 |             raise NotImplementedError('Provide a valid checkpoint path')
36 | print("Found checkpoint! Loading!")
37 | model,_ = caption.build_model(config)
38 | print("Loading Checkpoint...")
39 | checkpoint = torch.load(checkpoint_path, map_location='cpu')
40 | model.load_state_dict(checkpoint['model'])
41 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
42 |
43 | start_token = tokenizer.convert_tokens_to_ids(tokenizer._cls_token)
44 | end_token = tokenizer.convert_tokens_to_ids(tokenizer._sep_token)
45 |
46 | image = Image.open(image_path)
47 | image = coco.val_transform(image)
48 | image = image.unsqueeze(0)
49 |
50 |
51 | def create_caption_and_mask(start_token, max_length):
52 | caption_template = torch.zeros((1, max_length), dtype=torch.long)
53 | mask_template = torch.ones((1, max_length), dtype=torch.bool)
54 |
55 | caption_template[:, 0] = start_token
56 | mask_template[:, 0] = False
57 |
58 | return caption_template, mask_template
59 |
60 |
61 | caption, cap_mask = create_caption_and_mask(
62 | start_token, config.max_position_embeddings)
63 |
64 |
65 | @torch.no_grad()
66 | def evaluate():
67 | model.eval()
68 | for i in range(config.max_position_embeddings - 1):
69 | predictions = model(image, caption, cap_mask)
70 | predictions = predictions[:, i, :]
71 | predicted_id = torch.argmax(predictions, axis=-1)
72 |
73 |         if predicted_id[0] == end_token:  # end_token == 102 ([SEP]) for bert-base-uncased
74 | return caption
75 |
76 | caption[:, i+1] = predicted_id[0]
77 | cap_mask[:, i+1] = False
78 |
79 | return caption
80 |
81 |
82 | output = evaluate()
83 | result = tokenizer.decode(output[0].tolist(), skip_special_tokens=True)
84 | #result = tokenizer.decode(output[0], skip_special_tokens=True)
85 | print(result.capitalize())
--------------------------------------------------------------------------------
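predict.py is meant to be run from the command line; illustrative invocations (the image path is a placeholder), where the second form loads a local checkpoint by passing any --v value other than v1/v2/v3:

    python predict.py --path ./image.png --v v2
    python predict.py --path ./image.png --v custom --checkpoint ./checkpoint.pth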
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | numpy
4 | transformers
5 | tqdm
--------------------------------------------------------------------------------