├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── coca.png
├── coca_pytorch
│   ├── __init__.py
│   └── coca_pytorch.py
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## CoCa - Pytorch
4 |
5 | Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch. The authors elegantly fit contrastive learning into a conventional encoder / decoder (image-to-text) transformer, achieving SOTA 91.0% top-1 accuracy on ImageNet with a finetuned encoder.
6 |
7 | This repository also adopts the specific transformer architecture from PaLM for both the unimodal and multimodal transformers, as well as for the cross attention blocks (parallel SwiGLU feedforwards).
8 |
9 | Update: CoCa has been trained by the good folks over at OpenCLIP.
10 |
11 | ## Install
12 |
13 | ```bash
14 | $ pip install coca-pytorch
15 | ```
16 |
17 | ## Usage
18 |
19 | First install `vit-pytorch` for the image encoder, which needs to be pretrained:
20 |
21 | ```bash
22 | $ pip install 'vit-pytorch>=0.40.2'
23 | ```
24 |
25 | Then
26 |
27 | ```python
28 | import torch
29 |
30 | # import vision transformer
31 |
32 | from vit_pytorch.simple_vit_with_patch_dropout import SimpleViT
33 | from vit_pytorch.extractor import Extractor
34 |
35 | vit = SimpleViT(
36 | image_size = 256,
37 | patch_size = 32,
38 | num_classes = 1000,
39 | dim = 1024,
40 | depth = 6,
41 | heads = 16,
42 | mlp_dim = 2048,
43 | patch_dropout = 0.5 # https://arxiv.org/abs/2212.00794
44 | )
45 |
46 | vit = Extractor(vit, return_embeddings_only = True, detach = False)
47 |
48 | # the extractor wraps the vision transformer so it returns its embeddings
49 |
50 | # import CoCa and instantiate it
51 |
52 | from coca_pytorch.coca_pytorch import CoCa
53 |
54 | coca = CoCa(
55 | dim = 512, # model dimension
56 | img_encoder = vit, # vision transformer - image encoder, returning image embeddings as (batch, seq, dim)
57 | image_dim = 1024, # image embedding dimension, if not the same as model dimensions
58 | num_tokens = 20000, # number of text tokens
59 | unimodal_depth = 6, # depth of the unimodal transformer
60 | multimodal_depth = 6, # depth of the multimodal transformer
61 | dim_head = 64, # dimension per attention head
62 | heads = 8, # number of attention heads
63 | caption_loss_weight = 1., # weight on the autoregressive caption loss
64 | contrastive_loss_weight = 1., # weight on the contrastive loss between image and text CLS embeddings
65 | ).cuda()
66 |
67 | # mock text and images
68 |
69 | text = torch.randint(0, 20000, (4, 512)).cuda()
70 | images = torch.randn(4, 3, 256, 256).cuda()
71 |
72 | # train by giving CoCa your text and images with `return_loss = True`
73 |
74 | loss = coca(
75 | text = text,
76 | images = images,
77 | return_loss = True # set this to True to get the full caption + contrastive loss
78 | )
79 |
80 | loss.backward()
81 |
82 | # repeat the above for as many text / image pairs as needed
83 | # then you can get the caption logits like so
84 |
85 | logits = coca(
86 | text = text,
87 | images = images
88 | ) # (4, 512, 20000)
89 |
90 | # and the CLIP-like text and image embeddings as
91 |
92 | text_embeds, image_embeds = coca(
93 | text = text,
94 | images = images,
95 | return_embeddings = True
96 | ) # (4, 512), (4, 512)
97 | ```
98 |
99 | ## Citations
100 |
101 | ```bibtex
102 | @inproceedings{Yu2022CoCaCC,
103 | title = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
104 | author = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
105 | year = {2022}
106 | }
107 | ```
108 |
109 | ```bibtex
110 | @inproceedings{Chowdhery2022PaLMSL,
111 | title = {PaLM: Scaling Language Modeling with Pathways},
112 | author = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
113 | year = {2022}
114 | }
115 | ```
116 |
--------------------------------------------------------------------------------
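Beyond the README example above, CoCa can also consume precomputed image tokens directly (see `embed_image` in `coca_pytorch/coca_pytorch.py`), which is convenient when the vision encoder is frozen and its outputs are cached. A minimal sketch, assuming tokens of shape `(batch, seq, image_dim)` from any pretrained encoder; the shapes and hyperparameters below are illustrative, not prescribed by the repository:

```python
import torch
from coca_pytorch.coca_pytorch import CoCa

coca = CoCa(
    dim = 512,
    img_encoder = None,        # no encoder attached; tokens are supplied directly
    image_dim = 1024,          # must match the last dimension of the precomputed tokens
    num_tokens = 20000,
    unimodal_depth = 6,
    multimodal_depth = 6
)

text = torch.randint(0, 20000, (4, 512))
image_tokens = torch.randn(4, 64, 1024)   # e.g. cached outputs of a frozen, pretrained ViT

# image tokens go straight through attention pooling, bypassing img_encoder
loss = coca(
    text = text,
    image_tokens = image_tokens,
    return_loss = True
)

loss.backward()
```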
/coca.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/CoCa-pytorch/edee92c74e311ccfa4a0024412fd991c98aff5fd/coca.png
--------------------------------------------------------------------------------
/coca_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from coca_pytorch.coca_pytorch import CoCa
2 |
--------------------------------------------------------------------------------
/coca_pytorch/coca_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import einsum, nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Function
5 | import torch.distributed as dist
6 |
7 | from einops import rearrange, repeat
8 |
9 | # helper functions
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 | def default(val, d):
15 | return val if exists(val) else d
16 |
17 | # distributed
18 |
19 | def pad_dim_to(t, length, dim = 0):
20 | pad_length = length - t.shape[dim]
21 | zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
22 | return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))
23 |
24 | def all_gather_variable_batch(t):
25 | device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()
26 |
27 | size = torch.tensor(t.shape[0], device = device, dtype = torch.long)
28 | sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
29 | dist.all_gather(sizes, size)
30 |
31 | sizes = torch.stack(sizes)
32 | max_size = sizes.amax().item()
33 |
34 | padded_t = pad_dim_to(t, max_size, dim = 0)
35 | gathered_tensors = [torch.empty_like(padded_t, device = device, dtype = padded_t.dtype) for i in range(world_size)]
36 | dist.all_gather(gathered_tensors, padded_t)
37 |
38 | gathered_tensor = torch.cat(gathered_tensors)
39 | seq = torch.arange(max_size, device = device)
40 |
41 | mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
42 | mask = rearrange(mask, 'i j -> (i j)')
43 |
44 | gathered_tensor = gathered_tensor[mask]
45 | sizes = sizes.tolist()
46 |
47 | return gathered_tensor, sizes
48 |
49 | class AllGather(Function):
50 | @staticmethod
51 | def forward(ctx, x):
52 | assert dist.is_initialized() and dist.get_world_size() > 1
53 | x, batch_sizes = all_gather_variable_batch(x)
54 | ctx.batch_sizes = batch_sizes
55 | return x
56 |
57 | @staticmethod
58 | def backward(ctx, grads):
59 | batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
60 | grads_by_rank = grads.split(batch_sizes, dim = 0)
61 | return grads_by_rank[rank]
62 |
63 | all_gather = AllGather.apply
64 |
65 |
66 | # normalization
67 | # they use layernorm without bias, something that pytorch does not offer
68 |
69 |
70 | class LayerNorm(nn.Module):
71 | def __init__(self, dim):
72 | super().__init__()
73 | self.gamma = nn.Parameter(torch.ones(dim))
74 | self.register_buffer("beta", torch.zeros(dim))
75 |
76 | def forward(self, x):
77 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
78 |
79 | # residual
80 |
81 |
82 | class Residual(nn.Module):
83 | def __init__(self, fn):
84 | super().__init__()
85 | self.fn = fn
86 |
87 | def forward(self, x, *args, **kwargs):
88 | return self.fn(x, *args, **kwargs) + x
89 |
90 | # to latents
91 |
92 |
93 | class EmbedToLatents(nn.Module):
94 | def __init__(self, dim, dim_latents):
95 | super().__init__()
96 | self.to_latents = nn.Linear(dim, dim_latents, bias=False)
97 |
98 | def forward(self, x):
99 | latents = self.to_latents(x)
100 | return F.normalize(latents, dim=-1)
101 |
102 | # rotary positional embedding
103 | # https://arxiv.org/abs/2104.09864
104 |
105 |
106 | class RotaryEmbedding(nn.Module):
107 | def __init__(self, dim):
108 | super().__init__()
109 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
110 | self.register_buffer("inv_freq", inv_freq)
111 |
112 | def forward(self, max_seq_len, *, device):
113 | seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
114 | freqs = einsum("i , j -> i j", seq, self.inv_freq)
115 | return torch.cat((freqs, freqs), dim=-1)
116 |
117 |
118 | def rotate_half(x):
119 | x = rearrange(x, "... (j d) -> ... j d", j=2)
120 | x1, x2 = x.unbind(dim=-2)
121 | return torch.cat((-x2, x1), dim=-1)
122 |
123 |
124 | def apply_rotary_pos_emb(pos, t):
125 | return (t * pos.cos()) + (rotate_half(t) * pos.sin())
126 |
127 |
128 | # classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
129 | # https://arxiv.org/abs/2002.05202
130 |
131 |
132 | class SwiGLU(nn.Module):
133 | def forward(self, x):
134 | x, gate = x.chunk(2, dim=-1)
135 | return F.silu(gate) * x
136 |
137 |
138 | # parallel attention and feedforward with residual
139 | # discovered by Wang et al + EleutherAI from GPT-J fame
140 |
141 |
142 | class ParallelTransformerBlock(nn.Module):
143 | def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
144 | super().__init__()
145 | self.norm = LayerNorm(dim)
146 |
147 | attn_inner_dim = dim_head * heads
148 | ff_inner_dim = dim * ff_mult
149 | self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
150 |
151 | self.heads = heads
152 | self.scale = dim_head**-0.5
153 | self.rotary_emb = RotaryEmbedding(dim_head)
154 |
155 | self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
156 | self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
157 |
158 | self.ff_out = nn.Sequential(
159 | SwiGLU(),
160 | nn.Linear(ff_inner_dim, dim, bias=False)
161 | )
162 |
163 | # for caching causal mask and rotary embeddings
164 |
165 | self.mask = None
166 | self.pos_emb = None
167 |
168 | def get_mask(self, n, device):
169 | if self.mask is not None and self.mask.shape[-1] >= n:
170 | return self.mask[:n, :n].to(device)
171 |
172 | mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
173 | self.mask = mask
174 | return mask
175 |
176 | def get_rotary_embedding(self, n, device):
177 | if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
178 | return self.pos_emb[:n].to(device)
179 |
180 | pos_emb = self.rotary_emb(n, device=device)
181 | self.pos_emb = pos_emb
182 | return pos_emb
183 |
184 | def forward(self, x, attn_mask=None):
185 | """
186 | einstein notation
187 | b - batch
188 | h - heads
189 | n, i, j - sequence length (base sequence length, source, target)
190 | d - feature dimension
191 | """
192 |
193 | n, device, h = x.shape[1], x.device, self.heads
194 |
195 | # pre layernorm
196 |
197 | x = self.norm(x)
198 |
199 | # attention queries, keys, values, and feedforward inner
200 |
201 | q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
202 |
203 | # split heads
204 | # they use multi-query single-key-value attention, yet another Noam Shazeer paper
205 | # they found no performance loss past a certain scale, and more efficient decoding obviously
206 | # https://arxiv.org/abs/1911.02150
207 |
208 | q = rearrange(q, "b n (h d) -> b h n d", h=h)
209 |
210 | # rotary embeddings
211 |
212 | positions = self.get_rotary_embedding(n, device)
213 | q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
214 |
215 | # scale
216 |
217 | q = q * self.scale
218 |
219 | # similarity
220 |
221 | sim = einsum("b h i d, b j d -> b h i j", q, k)
222 |
223 | # causal mask
224 |
225 | causal_mask = self.get_mask(n, device)
226 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
227 |
228 | # extra attention mask - for masking out attention from text CLS token to padding
229 |
230 | if exists(attn_mask):
231 | attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
232 | sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
233 |
234 | # attention
235 |
236 | sim = sim - sim.amax(dim=-1, keepdim=True).detach()
237 | attn = sim.softmax(dim=-1)
238 |
239 | # aggregate values
240 |
241 | out = einsum("b h i j, b j d -> b h i d", attn, v)
242 |
243 | # merge heads
244 |
245 | out = rearrange(out, "b h n d -> b n (h d)")
246 | return self.attn_out(out) + self.ff_out(ff)
247 |
248 | # cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
249 |
250 | class CrossAttention(nn.Module):
251 | def __init__(
252 | self,
253 | dim,
254 | *,
255 | context_dim=None,
256 | dim_head=64,
257 | heads=8,
258 | parallel_ff=False,
259 | ff_mult=4,
260 | norm_context=False
261 | ):
262 | super().__init__()
263 | self.heads = heads
264 | self.scale = dim_head ** -0.5
265 | inner_dim = heads * dim_head
266 | context_dim = default(context_dim, dim)
267 |
268 | self.norm = LayerNorm(dim)
269 | self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()
270 |
271 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
272 | self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
273 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
274 |
275 | # whether to have parallel feedforward
276 |
277 | ff_inner_dim = ff_mult * dim
278 |
279 | self.ff = nn.Sequential(
280 | nn.Linear(dim, ff_inner_dim * 2, bias=False),
281 | SwiGLU(),
282 | nn.Linear(ff_inner_dim, dim, bias=False)
283 | ) if parallel_ff else None
284 |
285 | def forward(self, x, context):
286 | """
287 | einstein notation
288 | b - batch
289 | h - heads
290 | n, i, j - sequence length (base sequence length, source, target)
291 | d - feature dimension
292 | """
293 |
294 | # pre-layernorm, for queries and context
295 |
296 | x = self.norm(x)
297 | context = self.context_norm(context)
298 |
299 | # get queries
300 |
301 | q = self.to_q(x)
302 | q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
303 |
304 | # scale
305 |
306 | q = q * self.scale
307 |
308 | # get key / values
309 |
310 | k, v = self.to_kv(context).chunk(2, dim=-1)
311 |
312 | # query / key similarity
313 |
314 | sim = einsum('b h i d, b j d -> b h i j', q, k)
315 |
316 | # attention
317 |
318 | sim = sim - sim.amax(dim=-1, keepdim=True)
319 | attn = sim.softmax(dim=-1)
320 |
321 | # aggregate
322 |
323 | out = einsum('b h i j, b j d -> b h i d', attn, v)
324 |
325 | # merge and combine heads
326 |
327 | out = rearrange(out, 'b h n d -> b n (h d)')
328 | out = self.to_out(out)
329 |
330 | # add parallel feedforward (for multimodal layers)
331 |
332 | if exists(self.ff):
333 | out = out + self.ff(x)
334 |
335 | return out
336 |
337 | # transformer
338 |
339 |
340 | class CoCa(nn.Module):
341 | def __init__(
342 | self,
343 | *,
344 | dim,
345 | num_tokens,
346 | unimodal_depth,
347 | multimodal_depth,
348 | dim_latents = None,
349 | image_dim = None,
350 | num_img_queries=256,
351 | dim_head=64,
352 | heads=8,
353 | ff_mult=4,
354 | img_encoder=None,
355 | caption_loss_weight=1.,
356 | contrastive_loss_weight=1.,
357 | pad_id=0
358 | ):
359 | super().__init__()
360 | self.dim = dim
361 |
362 | self.pad_id = pad_id
363 | self.caption_loss_weight = caption_loss_weight
364 | self.contrastive_loss_weight = contrastive_loss_weight
365 |
366 | # token embeddings
367 |
368 | self.token_emb = nn.Embedding(num_tokens, dim)
369 | self.text_cls_token = nn.Parameter(torch.randn(dim))
370 |
371 | # image encoder
372 |
373 | self.img_encoder = img_encoder
374 |
375 | # attention pooling for image tokens
376 |
377 | self.img_queries = nn.Parameter(torch.randn(num_img_queries + 1, dim)) # num image queries for multimodal, but 1 extra CLS for contrastive learning
378 | self.img_attn_pool = CrossAttention(dim=dim, context_dim=image_dim, dim_head=dim_head, heads=heads, norm_context=True)
379 |
380 | self.img_attn_pool_norm = LayerNorm(dim)
381 | self.text_cls_norm = LayerNorm(dim)
382 |
383 | # to latents
384 |
385 | dim_latents = default(dim_latents, dim)
386 | self.img_to_latents = EmbedToLatents(dim, dim_latents)
387 | self.text_to_latents = EmbedToLatents(dim, dim_latents)
388 |
389 | # contrastive learning temperature
390 |
391 | self.temperature = nn.Parameter(torch.Tensor([1.]))
392 |
393 | # unimodal layers
394 |
395 | self.unimodal_layers = nn.ModuleList([])
396 | for ind in range(unimodal_depth):
397 | self.unimodal_layers.append(
398 | Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
399 | )
400 |
401 | # multimodal layers
402 |
403 | self.multimodal_layers = nn.ModuleList([])
404 | for ind in range(multimodal_depth):
405 | self.multimodal_layers.append(nn.ModuleList([
406 | Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
407 | Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult))
408 | ]))
409 |
410 | # to logits
411 |
412 | self.to_logits = nn.Sequential(
413 | LayerNorm(dim),
414 | nn.Linear(dim, num_tokens, bias=False)
415 | )
416 |
417 | # they used embedding weight tied projection out to logits, not common, but works
418 | self.to_logits[-1].weight = self.token_emb.weight
419 | nn.init.normal_(self.token_emb.weight, std=0.02)
420 |
421 | # whether in data parallel setting
422 | self.is_distributed = dist.is_initialized() and dist.get_world_size() > 1
423 |
424 | def embed_text(self, text):
425 | batch, device = text.shape[0], text.device
426 |
427 | seq = text.shape[1]
428 |
429 | text_tokens = self.token_emb(text)
430 |
431 | # append text cls tokens
432 |
433 | text_cls_tokens = repeat(self.text_cls_token, 'd -> b 1 d', b=batch)
434 | text_tokens = torch.cat((text_tokens, text_cls_tokens), dim=-2)
435 |
436 | # create specific mask for text cls token at the end
437 | # to prevent it from attending to padding
438 |
439 | cls_mask = rearrange(text!=self.pad_id, 'b j -> b 1 j')
440 | attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)
441 |
442 | # go through unimodal layers
443 |
444 | for attn_ff in self.unimodal_layers:
445 | text_tokens = attn_ff(text_tokens, attn_mask=attn_mask)
446 |
447 | # get text cls token
448 |
449 | text_tokens, text_cls_tokens = text_tokens[:, :-1], text_tokens[:, -1]
450 | text_embeds = self.text_cls_norm(text_cls_tokens)
451 | return text_embeds, text_tokens
452 |
453 | def embed_image(self, images=None, image_tokens=None):
454 | # encode images into embeddings
455 | # with the img_encoder passed in at init
456 | # it can also accept precomputed image tokens
457 |
458 | assert not (exists(images) and exists(image_tokens))
459 |
460 | if exists(images):
461 | assert exists(self.img_encoder), 'img_encoder must be passed in for automatic image encoding'
462 | image_tokens = self.img_encoder(images)
463 |
464 | # attention pool image tokens
465 |
466 | img_queries = repeat(self.img_queries, 'n d -> b n d', b=image_tokens.shape[0])
467 | img_queries = self.img_attn_pool(img_queries, image_tokens)
468 | img_queries = self.img_attn_pool_norm(img_queries)
469 |
470 | return img_queries[:, 0], img_queries[:, 1:]
471 |
472 | def forward(
473 | self,
474 | text,
475 | images=None,
476 | image_tokens=None,
477 | labels=None,
478 | return_loss=False,
479 | return_embeddings=False
480 | ):
481 | batch, device = text.shape[0], text.device
482 |
483 | if return_loss and not exists(labels):
484 | text, labels = text[:, :-1], text[:, 1:]
485 |
486 | text_embeds, text_tokens = self.embed_text(text)
487 |
488 | image_embeds, image_tokens = self.embed_image(images=images, image_tokens=image_tokens)
489 |
490 | # return embeddings if that is what the researcher wants
491 |
492 | if return_embeddings:
493 | return text_embeds, image_embeds
494 |
495 | # go through multimodal layers
496 |
497 | for attn_ff, cross_attn in self.multimodal_layers:
498 | text_tokens = attn_ff(text_tokens)
499 | text_tokens = cross_attn(text_tokens, image_tokens)
500 |
501 | logits = self.to_logits(text_tokens)
502 |
503 | if not return_loss:
504 | return logits
505 |
506 | # shorthand
507 |
508 | ce = F.cross_entropy
509 |
510 | # calculate caption loss (cross entropy loss)
511 |
512 | logits = rearrange(logits, 'b n c -> b c n')
513 | caption_loss = ce(logits, labels, ignore_index=self.pad_id)
514 | caption_loss = caption_loss * self.caption_loss_weight
515 |
516 | # embedding to latents
517 |
518 | text_latents = self.text_to_latents(text_embeds)
519 | image_latents = self.img_to_latents(image_embeds)
520 |
521 | # maybe distributed all gather
522 |
523 | if self.is_distributed:
524 | latents = torch.stack((text_latents, image_latents), dim = 1)
525 | latents = all_gather(latents)
526 | text_latents, image_latents = latents.unbind(dim = 1)
527 |
528 | # calculate contrastive loss
529 |
530 | sim = einsum('i d, j d -> i j', text_latents, image_latents)
531 | sim = sim * self.temperature.exp()
532 | contrastive_labels = torch.arange(text_latents.shape[0], device=device) # match the (possibly all-gathered) batch size
533 |
534 | contrastive_loss = (ce(sim, contrastive_labels) + ce(sim.t(), contrastive_labels)) * 0.5
535 | contrastive_loss = contrastive_loss * self.contrastive_loss_weight
536 |
537 | return caption_loss + contrastive_loss
538 |
--------------------------------------------------------------------------------
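The `AllGather` autograd function above gathers the text and image latents across ranks before the contrastive similarity matrix is formed, so the in-batch negatives span the global batch. A minimal sketch of wiring this up with DistributedDataParallel, assuming a `torchrun --nproc_per_node=N train.py` style launch; the batch size, token shapes, and hyperparameters are placeholders:

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from coca_pytorch.coca_pytorch import CoCa

dist.init_process_group(backend = 'nccl')   # torchrun supplies the rank / world size env vars
rank = dist.get_rank()
torch.cuda.set_device(rank)

coca = CoCa(
    dim = 512,
    image_dim = 1024,
    num_tokens = 20000,
    unimodal_depth = 6,
    multimodal_depth = 6
).cuda(rank)

coca = DDP(coca, device_ids = [rank])

# each rank sees its own shard of the data
text = torch.randint(0, 20000, (4, 512)).cuda(rank)
image_tokens = torch.randn(4, 64, 1024).cuda(rank)

# the caption loss stays local; the contrastive branch all-gathers the latents across ranks
loss = coca(text = text, image_tokens = image_tokens, return_loss = True)
loss.backward()

dist.destroy_process_group()
```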
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'CoCa-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.1.0',
7 | license='MIT',
8 | description = 'CoCa, Contrastive Captioners are Image-Text Foundation Models - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | long_description_content_type = 'text/markdown',
12 | url = 'https://github.com/lucidrains/CoCa-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'transformers',
17 | 'attention mechanism',
18 | 'contrastive learning',
19 | 'multimodal'
20 | ],
21 | install_requires=[
22 | 'einops>=0.4',
23 | 'torch>=1.6',
24 | ],
25 | classifiers=[
26 | 'Development Status :: 4 - Beta',
27 | 'Intended Audience :: Developers',
28 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
29 | 'License :: OSI Approved :: MIT License',
30 | 'Programming Language :: Python :: 3.6',
31 | ],
32 | )
33 |
--------------------------------------------------------------------------------