├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── coca.png ├── coca_pytorch ├── __init__.py └── coca_pytorch.py └── setup.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## CoCa - Pytorch 4 | 5 | Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch. They were able to elegantly fit in contrastive learning to a conventional encoder / decoder (image to text) transformer, achieving SOTA 91.0% top-1 accuracy on ImageNet with a finetuned encoder. 
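At a high level, training optimizes a weighted sum of an autoregressive captioning loss and a CLIP-style contrastive loss between pooled text and image embeddings. Below is a rough, self-contained sketch of that objective; the function and tensor names are illustrative only, not this library's API - see `CoCa.forward` in `coca_pytorch/coca_pytorch.py` for the actual implementation.

```python
import torch
import torch.nn.functional as F

def coca_style_loss(caption_logits, labels, text_embeds, image_embeds,
                    temperature = 1., caption_weight = 1., contrastive_weight = 1.):
    # captioning: next-token cross entropy over decoder logits of shape (batch, seq, vocab)
    caption_loss = F.cross_entropy(caption_logits.transpose(1, 2), labels)

    # contrastive: symmetric InfoNCE between L2-normalized text / image embeddings
    text_latents = F.normalize(text_embeds, dim = -1)
    image_latents = F.normalize(image_embeds, dim = -1)
    sim = text_latents @ image_latents.t() * temperature
    targets = torch.arange(sim.shape[0], device = sim.device)
    contrastive_loss = (F.cross_entropy(sim, targets) + F.cross_entropy(sim.t(), targets)) * 0.5

    return caption_weight * caption_loss + contrastive_weight * contrastive_loss
```

In the actual model, the text embedding comes from a learned CLS token appended to the caption and the image embedding from attention pooling over the image encoder output.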
6 | 
7 | This repository also chooses to adopt the specific transformer architecture from PaLM, for both the unimodal and multimodal transformers as well as the cross attention blocks (parallel SwiGLU feedforwards)
8 | 
9 | Update: CoCa has been trained by the good folks over at OpenCLIP
10 | 
11 | ## Install
12 | 
13 | ```bash
14 | $ pip install coca-pytorch
15 | ```
16 | 
17 | ## Usage
18 | 
19 | First install `vit-pytorch` for the image encoder, which needs to be pretrained
20 | 
21 | ```bash
22 | $ pip install 'vit-pytorch>=0.40.2'
23 | ```
24 | 
25 | Then
26 | 
27 | ```python
28 | import torch
29 | 
30 | # import vision transformer
31 | 
32 | from vit_pytorch.simple_vit_with_patch_dropout import SimpleViT
33 | from vit_pytorch.extractor import Extractor
34 | 
35 | vit = SimpleViT(
36 |     image_size = 256,
37 |     patch_size = 32,
38 |     num_classes = 1000,
39 |     dim = 1024,
40 |     depth = 6,
41 |     heads = 16,
42 |     mlp_dim = 2048,
43 |     patch_dropout = 0.5  # https://arxiv.org/abs/2212.00794
44 | )
45 | 
46 | vit = Extractor(vit, return_embeddings_only = True, detach = False)
47 | 
48 | # the extractor makes the vision transformer return its embeddings
49 | 
50 | # import CoCa and instantiate it
51 | 
52 | from coca_pytorch.coca_pytorch import CoCa
53 | 
54 | coca = CoCa(
55 |     dim = 512,                     # model dimension
56 |     img_encoder = vit,             # vision transformer - image encoder, returning image embeddings as (batch, seq, dim)
57 |     image_dim = 1024,              # image embedding dimension, if not the same as model dimensions
58 |     num_tokens = 20000,            # number of text tokens
59 |     unimodal_depth = 6,            # depth of the unimodal transformer
60 |     multimodal_depth = 6,          # depth of the multimodal transformer
61 |     dim_head = 64,                 # dimension per attention head
62 |     heads = 8,                     # number of attention heads
63 |     caption_loss_weight = 1.,      # weight on the autoregressive caption loss
64 |     contrastive_loss_weight = 1.,  # weight on the contrastive loss between image and text CLS embeddings
65 | ).cuda()
66 | 
67 | # mock text and images
68 | 
69 | text = torch.randint(0, 20000, (4, 512)).cuda()
70 | images = torch.randn(4, 3, 256, 256).cuda()
71 | 
72 | # train by giving CoCa your text and images with `return_loss = True`
73 | 
74 | loss = coca(
75 |     text = text,
76 |     images = images,
77 |     return_loss = True  # set this to True to get the full caption + contrastive loss
78 | )
79 | 
80 | loss.backward()
81 | 
82 | # do the above for as many text-image pairs as needed...
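# a minimal training-loop sketch around the call above (hypothetical - the optimizer
# choice and the random stand-in batches below are illustrative, not part of this repo;
# in practice you would iterate over a real paired text / image dataloader)

optimizer = torch.optim.AdamW(coca.parameters(), lr = 3e-4)

for _ in range(100):
    text = torch.randint(0, 20000, (4, 512)).cuda()
    images = torch.randn(4, 3, 256, 256).cuda()

    loss = coca(text = text, images = images, return_loss = True)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()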
83 | # then you can get the caption logits as so 84 | 85 | logits = coca( 86 | text = text, 87 | images = images 88 | ) # (4, 512, 20000) 89 | 90 | # and the CLIP-like text and image embeddings as 91 | 92 | text_embeds, image_embeds = coca( 93 | text = text, 94 | images = images, 95 | return_embeddings = True 96 | ) # (4, 512), (4, 512) 97 | ``` 98 | 99 | ## Citations 100 | 101 | ```bibtex 102 | @inproceedings{Yu2022CoCaCC, 103 | title = {CoCa: Contrastive Captioners are Image-Text Foundation Models}, 104 | author = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu}, 105 | year = {2022} 106 | } 107 | ``` 108 | 109 | ```bibtex 110 | @inproceedings{Chowdhery2022PaLMSL, 111 | title = {PaLM: Scaling Language Modeling with Pathways}, 112 | author = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. 
Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel}, 113 | year = {2022} 114 | } 115 | ``` 116 | -------------------------------------------------------------------------------- /coca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/CoCa-pytorch/edee92c74e311ccfa4a0024412fd991c98aff5fd/coca.png -------------------------------------------------------------------------------- /coca_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from coca_pytorch.coca_pytorch import CoCa 2 | -------------------------------------------------------------------------------- /coca_pytorch/coca_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import einsum, nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Function 5 | import torch.distributed as dist 6 | 7 | from einops import rearrange, repeat 8 | 9 | # helper functions 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | def default(val, d): 15 | return val if exists(val) else d 16 | 17 | # distributed 18 | 19 | def pad_dim_to(t, length, dim = 0): 20 | pad_length = length - t.shape[dim] 21 | zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1) 22 | return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length)) 23 | 24 | def all_gather_variable_batch(t): 25 | device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size() 26 | 27 | size = torch.tensor(t.shape[0], device = device, dtype = torch.long) 28 | sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)] 29 | dist.all_gather(sizes, size) 30 | 31 | sizes = torch.stack(sizes) 32 | max_size = sizes.amax().item() 33 | 34 | padded_t = pad_dim_to(t, max_size, dim = 0) 35 | gathered_tensors = [torch.empty_like(padded_t, device = device, dtype = padded_t.dtype) for i in range(world_size)] 36 | dist.all_gather(gathered_tensors, padded_t) 37 | 38 | gathered_tensor = torch.cat(gathered_tensors) 39 | seq = torch.arange(max_size, device = device) 40 | 41 | mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1') 42 | mask = rearrange(mask, 'i j -> (i j)') 43 | 44 | gathered_tensor = gathered_tensor[mask] 45 | sizes = sizes.tolist() 46 | 47 | return gathered_tensor, sizes 48 | 49 | class AllGather(Function): 50 | @staticmethod 51 | def forward(ctx, x): 52 | assert dist.is_initialized() and dist.get_world_size() > 1 53 | x, batch_sizes = all_gather_variable_batch(x) 54 | ctx.batch_sizes = batch_sizes 55 | return x 56 | 57 | @staticmethod 58 | def backward(ctx, grads): 59 | batch_sizes, rank = ctx.batch_sizes, dist.get_rank() 60 | grads_by_rank = grads.split(batch_sizes, dim = 0) 61 | return grads_by_rank[rank] 62 | 63 | all_gather = AllGather.apply 64 | 65 | 66 | # normalization 67 | # they use layernorm without bias, something that pytorch does not offer 68 | 69 | 70 | class LayerNorm(nn.Module): 71 | def __init__(self, dim): 72 | super().__init__() 73 | self.gamma = nn.Parameter(torch.ones(dim)) 74 | self.register_buffer("beta", torch.zeros(dim)) 75 | 76 | def forward(self, x): 77 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) 78 | 79 | # residual 80 | 81 | 82 | class Residual(nn.Module): 83 | def __init__(self, fn): 84 | super().__init__() 85 | self.fn = fn 86 | 87 | def forward(self, x, *args, **kwargs): 88 | return self.fn(x, *args, **kwargs) + x 89 | 90 | # to latents 91 
| 92 | 93 | class EmbedToLatents(nn.Module): 94 | def __init__(self, dim, dim_latents): 95 | super().__init__() 96 | self.to_latents = nn.Linear(dim, dim_latents, bias=False) 97 | 98 | def forward(self, x): 99 | latents = self.to_latents(x) 100 | return F.normalize(latents, dim=-1) 101 | 102 | # rotary positional embedding 103 | # https://arxiv.org/abs/2104.09864 104 | 105 | 106 | class RotaryEmbedding(nn.Module): 107 | def __init__(self, dim): 108 | super().__init__() 109 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 110 | self.register_buffer("inv_freq", inv_freq) 111 | 112 | def forward(self, max_seq_len, *, device): 113 | seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype) 114 | freqs = einsum("i , j -> i j", seq, self.inv_freq) 115 | return torch.cat((freqs, freqs), dim=-1) 116 | 117 | 118 | def rotate_half(x): 119 | x = rearrange(x, "... (j d) -> ... j d", j=2) 120 | x1, x2 = x.unbind(dim=-2) 121 | return torch.cat((-x2, x1), dim=-1) 122 | 123 | 124 | def apply_rotary_pos_emb(pos, t): 125 | return (t * pos.cos()) + (rotate_half(t) * pos.sin()) 126 | 127 | 128 | # classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward 129 | # https://arxiv.org/abs/2002.05202 130 | 131 | 132 | class SwiGLU(nn.Module): 133 | def forward(self, x): 134 | x, gate = x.chunk(2, dim=-1) 135 | return F.silu(gate) * x 136 | 137 | 138 | # parallel attention and feedforward with residual 139 | # discovered by Wang et al + EleutherAI from GPT-J fame 140 | 141 | 142 | class ParallelTransformerBlock(nn.Module): 143 | def __init__(self, dim, dim_head=64, heads=8, ff_mult=4): 144 | super().__init__() 145 | self.norm = LayerNorm(dim) 146 | 147 | attn_inner_dim = dim_head * heads 148 | ff_inner_dim = dim * ff_mult 149 | self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2)) 150 | 151 | self.heads = heads 152 | self.scale = dim_head**-0.5 153 | self.rotary_emb = RotaryEmbedding(dim_head) 154 | 155 | self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False) 156 | self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False) 157 | 158 | self.ff_out = nn.Sequential( 159 | SwiGLU(), 160 | nn.Linear(ff_inner_dim, dim, bias=False) 161 | ) 162 | 163 | # for caching causal mask and rotary embeddings 164 | 165 | self.mask = None 166 | self.pos_emb = None 167 | 168 | def get_mask(self, n, device): 169 | if self.mask is not None and self.mask.shape[-1] >= n: 170 | return self.mask[:n, :n].to(device) 171 | 172 | mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) 173 | self.mask = mask 174 | return mask 175 | 176 | def get_rotary_embedding(self, n, device): 177 | if self.pos_emb is not None and self.pos_emb.shape[-2] >= n: 178 | return self.pos_emb[:n].to(device) 179 | 180 | pos_emb = self.rotary_emb(n, device=device) 181 | self.pos_emb = pos_emb 182 | return pos_emb 183 | 184 | def forward(self, x, attn_mask=None): 185 | """ 186 | einstein notation 187 | b - batch 188 | h - heads 189 | n, i, j - sequence length (base sequence length, source, target) 190 | d - feature dimension 191 | """ 192 | 193 | n, device, h = x.shape[1], x.device, self.heads 194 | 195 | # pre layernorm 196 | 197 | x = self.norm(x) 198 | 199 | # attention queries, keys, values, and feedforward inner 200 | 201 | q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1) 202 | 203 | # split heads 204 | # they use multi-query single-key-value attention, yet another Noam Shazeer paper 205 | # they 
found no performance loss past a certain scale, and more efficient decoding obviously 206 | # https://arxiv.org/abs/1911.02150 207 | 208 | q = rearrange(q, "b n (h d) -> b h n d", h=h) 209 | 210 | # rotary embeddings 211 | 212 | positions = self.get_rotary_embedding(n, device) 213 | q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k)) 214 | 215 | # scale 216 | 217 | q = q * self.scale 218 | 219 | # similarity 220 | 221 | sim = einsum("b h i d, b j d -> b h i j", q, k) 222 | 223 | # causal mask 224 | 225 | causal_mask = self.get_mask(n, device) 226 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) 227 | 228 | # extra attention mask - for masking out attention from text CLS token to padding 229 | 230 | if exists(attn_mask): 231 | attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j') 232 | sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max) 233 | 234 | # attention 235 | 236 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 237 | attn = sim.softmax(dim=-1) 238 | 239 | # aggregate values 240 | 241 | out = einsum("b h i j, b j d -> b h i d", attn, v) 242 | 243 | # merge heads 244 | 245 | out = rearrange(out, "b h n d -> b n (h d)") 246 | return self.attn_out(out) + self.ff_out(ff) 247 | 248 | # cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward 249 | 250 | class CrossAttention(nn.Module): 251 | def __init__( 252 | self, 253 | dim, 254 | *, 255 | context_dim=None, 256 | dim_head=64, 257 | heads=8, 258 | parallel_ff=False, 259 | ff_mult=4, 260 | norm_context=False 261 | ): 262 | super().__init__() 263 | self.heads = heads 264 | self.scale = dim_head ** -0.5 265 | inner_dim = heads * dim_head 266 | context_dim = default(context_dim, dim) 267 | 268 | self.norm = LayerNorm(dim) 269 | self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity() 270 | 271 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 272 | self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False) 273 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 274 | 275 | # whether to have parallel feedforward 276 | 277 | ff_inner_dim = ff_mult * dim 278 | 279 | self.ff = nn.Sequential( 280 | nn.Linear(dim, ff_inner_dim * 2, bias=False), 281 | SwiGLU(), 282 | nn.Linear(ff_inner_dim, dim, bias=False) 283 | ) if parallel_ff else None 284 | 285 | def forward(self, x, context): 286 | """ 287 | einstein notation 288 | b - batch 289 | h - heads 290 | n, i, j - sequence length (base sequence length, source, target) 291 | d - feature dimension 292 | """ 293 | 294 | # pre-layernorm, for queries and context 295 | 296 | x = self.norm(x) 297 | context = self.context_norm(context) 298 | 299 | # get queries 300 | 301 | q = self.to_q(x) 302 | q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) 303 | 304 | # scale 305 | 306 | q = q * self.scale 307 | 308 | # get key / values 309 | 310 | k, v = self.to_kv(context).chunk(2, dim=-1) 311 | 312 | # query / key similarity 313 | 314 | sim = einsum('b h i d, b j d -> b h i j', q, k) 315 | 316 | # attention 317 | 318 | sim = sim - sim.amax(dim=-1, keepdim=True) 319 | attn = sim.softmax(dim=-1) 320 | 321 | # aggregate 322 | 323 | out = einsum('b h i j, b j d -> b h i d', attn, v) 324 | 325 | # merge and combine heads 326 | 327 | out = rearrange(out, 'b h n d -> b n (h d)') 328 | out = self.to_out(out) 329 | 330 | # add parallel feedforward (for multimodal layers) 331 | 332 | if exists(self.ff): 333 | out = out + self.ff(x) 334 | 335 | return out 336 | 337 | # transformer 338 | 339 | 340 | 
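# an illustrative aside (not part of the original source): the attention blocks above use
# multi-query attention - queries get `heads` heads while a single key / value head is
# shared across all of them, so the similarity einsum broadcasts over the head dimension:
#
#   q    : (batch, heads, seq, dim_head)
#   k, v : (batch, seq, dim_head)                       # no head dimension
#   sim = einsum('b h i d, b j d -> b h i j', q, k)     # k is broadcast across heads
#
# compared to standard multi-head attention this shrinks the key / value projections and
# the decoding-time cache by roughly a factor of `heads` (https://arxiv.org/abs/1911.02150)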
class CoCa(nn.Module): 341 | def __init__( 342 | self, 343 | *, 344 | dim, 345 | num_tokens, 346 | unimodal_depth, 347 | multimodal_depth, 348 | dim_latents = None, 349 | image_dim = None, 350 | num_img_queries=256, 351 | dim_head=64, 352 | heads=8, 353 | ff_mult=4, 354 | img_encoder=None, 355 | caption_loss_weight=1., 356 | contrastive_loss_weight=1., 357 | pad_id=0 358 | ): 359 | super().__init__() 360 | self.dim = dim 361 | 362 | self.pad_id = pad_id 363 | self.caption_loss_weight = caption_loss_weight 364 | self.contrastive_loss_weight = contrastive_loss_weight 365 | 366 | # token embeddings 367 | 368 | self.token_emb = nn.Embedding(num_tokens, dim) 369 | self.text_cls_token = nn.Parameter(torch.randn(dim)) 370 | 371 | # image encoder 372 | 373 | self.img_encoder = img_encoder 374 | 375 | # attention pooling for image tokens 376 | 377 | self.img_queries = nn.Parameter(torch.randn(num_img_queries + 1, dim)) # num image queries for multimodal, but 1 extra CLS for contrastive learning 378 | self.img_attn_pool = CrossAttention(dim=dim, context_dim=image_dim, dim_head=dim_head, heads=heads, norm_context=True) 379 | 380 | self.img_attn_pool_norm = LayerNorm(dim) 381 | self.text_cls_norm = LayerNorm(dim) 382 | 383 | # to latents 384 | 385 | dim_latents = default(dim_latents, dim) 386 | self.img_to_latents = EmbedToLatents(dim, dim_latents) 387 | self.text_to_latents = EmbedToLatents(dim, dim_latents) 388 | 389 | # contrastive learning temperature 390 | 391 | self.temperature = nn.Parameter(torch.Tensor([1.])) 392 | 393 | # unimodal layers 394 | 395 | self.unimodal_layers = nn.ModuleList([]) 396 | for ind in range(unimodal_depth): 397 | self.unimodal_layers.append( 398 | Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)), 399 | ) 400 | 401 | # multimodal layers 402 | 403 | self.multimodal_layers = nn.ModuleList([]) 404 | for ind in range(multimodal_depth): 405 | self.multimodal_layers.append(nn.ModuleList([ 406 | Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)), 407 | Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)) 408 | ])) 409 | 410 | # to logits 411 | 412 | self.to_logits = nn.Sequential( 413 | LayerNorm(dim), 414 | nn.Linear(dim, num_tokens, bias=False) 415 | ) 416 | 417 | # they used embedding weight tied projection out to logits, not common, but works 418 | self.to_logits[-1].weight = self.token_emb.weight 419 | nn.init.normal_(self.token_emb.weight, std=0.02) 420 | 421 | # whether in data parallel setting 422 | self.is_distributed = dist.is_initialized() and dist.get_world_size() > 1 423 | 424 | def embed_text(self, text): 425 | batch, device = text.shape[0], text.device 426 | 427 | seq = text.shape[1] 428 | 429 | text_tokens = self.token_emb(text) 430 | 431 | # append text cls tokens 432 | 433 | text_cls_tokens = repeat(self.text_cls_token, 'd -> b 1 d', b=batch) 434 | text_tokens = torch.cat((text_tokens, text_cls_tokens), dim=-2) 435 | 436 | # create specific mask for text cls token at the end 437 | # to prevent it from attending to padding 438 | 439 | cls_mask = rearrange(text!=self.pad_id, 'b j -> b 1 j') 440 | attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True) 441 | 442 | # go through unimodal layers 443 | 444 | for attn_ff in self.unimodal_layers: 445 | text_tokens = attn_ff(text_tokens, attn_mask=attn_mask) 446 | 447 | # get text cls token 448 | 449 | text_tokens, text_cls_tokens = text_tokens[:, :-1], text_tokens[:, -1] 450 | 
text_embeds = self.text_cls_norm(text_cls_tokens) 451 | return text_embeds, text_tokens 452 | 453 | def embed_image(self, images=None, image_tokens=None): 454 | # encode images into embeddings 455 | # with the img_encoder passed in at init 456 | # it can also accept precomputed image tokens 457 | 458 | assert not (exists(images) and exists(image_tokens)) 459 | 460 | if exists(images): 461 | assert exists(self.img_encoder), 'img_encoder must be passed in for automatic image encoding' 462 | image_tokens = self.img_encoder(images) 463 | 464 | # attention pool image tokens 465 | 466 | img_queries = repeat(self.img_queries, 'n d -> b n d', b=image_tokens.shape[0]) 467 | img_queries = self.img_attn_pool(img_queries, image_tokens) 468 | img_queries = self.img_attn_pool_norm(img_queries) 469 | 470 | return img_queries[:, 0], img_queries[:, 1:] 471 | 472 | def forward( 473 | self, 474 | text, 475 | images=None, 476 | image_tokens=None, 477 | labels=None, 478 | return_loss=False, 479 | return_embeddings=False 480 | ): 481 | batch, device = text.shape[0], text.device 482 | 483 | if return_loss and not exists(labels): 484 | text, labels = text[:, :-1], text[:, 1:] 485 | 486 | text_embeds, text_tokens = self.embed_text(text) 487 | 488 | image_embeds, image_tokens = self.embed_image(images=images, image_tokens=image_tokens) 489 | 490 | # return embeddings if that is what the researcher wants 491 | 492 | if return_embeddings: 493 | return text_embeds, image_embeds 494 | 495 | # go through multimodal layers 496 | 497 | for attn_ff, cross_attn in self.multimodal_layers: 498 | text_tokens = attn_ff(text_tokens) 499 | text_tokens = cross_attn(text_tokens, image_tokens) 500 | 501 | logits = self.to_logits(text_tokens) 502 | 503 | if not return_loss: 504 | return logits 505 | 506 | # shorthand 507 | 508 | ce = F.cross_entropy 509 | 510 | # calculate caption loss (cross entropy loss) 511 | 512 | logits = rearrange(logits, 'b n c -> b c n') 513 | caption_loss = ce(logits, labels, ignore_index=self.pad_id) 514 | caption_loss = caption_loss * self.caption_loss_weight 515 | 516 | # embedding to latents 517 | 518 | text_latents = self.text_to_latents(text_embeds) 519 | image_latents = self.img_to_latents(image_embeds) 520 | 521 | # maybe distributed all gather 522 | 523 | if self.is_distributed: 524 | latents = torch.stack((text_latents, image_latents), dim = 1) 525 | latents = all_gather(latents) 526 | text_latents, image_latents = latents.unbind(dim = 1) 527 | 528 | # calculate contrastive loss 529 | 530 | sim = einsum('i d, j d -> i j', text_latents, image_latents) 531 | sim = sim * self.temperature.exp() 532 | contrastive_labels = torch.arange(batch, device=device) 533 | 534 | contrastive_loss = (ce(sim, contrastive_labels) + ce(sim.t(), contrastive_labels)) * 0.5 535 | contrastive_loss = contrastive_loss * self.contrastive_loss_weight 536 | 537 | return caption_loss + contrastive_loss 538 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'CoCa-pytorch', 5 | packages = find_packages(exclude=[]), 6 | version = '0.1.0', 7 | license='MIT', 8 | description = 'CoCa, Contrastive Captioners are Image-Text Foundation Models - Pytorch', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | long_description_content_type = 'text/markdown', 12 | url = 'https://github.com/lucidrains/CoCa-pytorch', 13 | 
keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'transformers', 17 | 'attention mechanism', 18 | 'contrastive learning', 19 | 'multimodal' 20 | ], 21 | install_requires=[ 22 | 'einops>=0.4', 23 | 'torch>=1.6', 24 | ], 25 | classifiers=[ 26 | 'Development Status :: 4 - Beta', 27 | 'Intended Audience :: Developers', 28 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 29 | 'License :: OSI Approved :: MIT License', 30 | 'Programming Language :: Python :: 3.6', 31 | ], 32 | ) 33 | --------------------------------------------------------------------------------