├── .github
│   └── workflows
│       ├── autoblack.yml
│       └── sphinx.yml
├── LICENSE
├── README.md
├── clap
│   ├── __init__.py
│   ├── datasets.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── attentions
│   │   │   ├── __init__.py
│   │   │   ├── attention.py
│   │   │   ├── cvt_attention.py
│   │   │   └── talking_heads.py
│   │   ├── feedforwards
│   │   │   ├── __init__.py
│   │   │   ├── ff.py
│   │   │   └── leff.py
│   │   ├── normalizations
│   │   │   ├── __init__.py
│   │   │   └── layerscale.py
│   │   ├── position_embed.py
│   │   ├── regularization
│   │   │   ├── __init__.py
│   │   │   └── stochastic_depth.py
│   │   ├── squeeze_excite.py
│   │   └── stems
│   │       ├── __init__.py
│   │       ├── image_to_token.py
│   │       └── patch_embed.py
│   ├── models.py
│   └── trunks
│       ├── __init__.py
│       ├── cait.py
│       ├── create_trunk.py
│       ├── mlp_mixer.py
│       ├── tnt.py
│       ├── transformer.py
│       └── vit.py
├── configs
│   ├── model
│   │   ├── audio
│   │   │   ├── cait.yaml
│   │   │   ├── mixer.yaml
│   │   │   ├── tnt.yaml
│   │   │   └── vit.yaml
│   │   └── text
│   │       └── transformer.yaml
│   ├── optimizer
│   │   └── standard.yaml
│   ├── preprocessing
│   │   └── dataset
│   │       └── commonvoice.yaml
│   └── training
│       └── standard.yaml
├── preprocess.py
├── setup.py
└── train.py
/.github/workflows/autoblack.yml: -------------------------------------------------------------------------------- 1 | # GitHub Action that uses Black to reformat the Python code in an incoming pull request. 2 | # If all Python code in the pull request is compliant with Black then this Action does nothing. 3 | # Otherwise, Black is run and its changes are committed back to the incoming pull request. 4 | # https://github.com/cclauss/autoblack 5 | 6 | name: autoblack 7 | on: [pull_request] 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v1 13 | - name: Set up Python 3.7 14 | uses: actions/setup-python@v1 15 | with: 16 | python-version: 3.7 17 | - name: Install Black 18 | run: pip install black 19 | - name: Run black --check . 20 | run: black --check . 21 | - name: If needed, commit black changes to the pull request 22 | if: failure() 23 | run: | 24 | black . 25 | git config --global user.name 'autoblack' 26 | git config --global user.email 'cclauss@users.noreply.github.com' 27 | git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/$GITHUB_REPOSITORY 28 | git checkout $GITHUB_HEAD_REF 29 | git commit -am "fixup: Format Python code with Black" 30 | git push 31 | --------------------------------------------------------------------------------
/.github/workflows/sphinx.yml: -------------------------------------------------------------------------------- 1 | name: Pages 2 | on: [push] 3 | jobs: 4 | build: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - name: Checkout 8 | uses: actions/checkout@master 9 | with: 10 | fetch-depth: 0 # otherwise, you will fail to push refs to the dest repo 11 | - name: Build and Commit 12 | uses: sphinx-notes/pages@master 13 | - name: Push changes 14 | uses: ad-m/github-push-action@master 15 | with: 16 | github_token: ${{ secrets.GITHUB_TOKEN }} 17 | branch: gh-pages 18 | --------------------------------------------------------------------------------
/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Charles Foster 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2.
Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CLAP 2 | 3 | 4 | Contrastive Language-Audio Pretraining 5 | 6 | 7 | The folks over at LAION have picked up the mantle, continuing this line of work [in a new repo](https://github.com/LAION-AI/CLAP/). 8 | 9 | 10 | ## Citations 11 | 12 | [OpenAI blog post "CLIP: Connecting Text and Images"](https://openai.com/blog/clip/) 13 | 14 | ```bibtex 15 | @article{radford2021learning, 16 | title={Learning transferable visual models from natural language supervision}, 17 | author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, 18 | journal={arXiv preprint arXiv:2103.00020}, 19 | year={2021} 20 | } 21 | ``` 22 | 23 | ```bibtex 24 | @article{jia2021scaling, 25 | title={Scaling Up Visual and Vision-Language Representation Learning With Noisy Text Supervision}, 26 | author={Jia, Chao and Yang, Yinfei and Xia, Ye and Chen, Yi-Ting and Parekh, Zarana and Pham, Hieu and Le, Quoc V and Sung, Yunhsuan and Li, Zhen and Duerig, Tom}, 27 | journal={arXiv preprint arXiv:2102.05918}, 28 | year={2021} 29 | } 30 | ``` 31 | 32 | Much of the code behind the various transformer configurations has been adapted from Niccolò Zanichelli's [repository of Flax vision transformer modules](https://github.com/NZ99/self-attention-experiments-vision). 33 | 34 | Citation block courtesy of MicPie's awesome parallel project ["Contrastive Language-Aminoacid Sequence Pretraining"](https://github.com/MicPie/clasp). 
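## Usage

A minimal sketch of wiring the model together by hand. In practice the hyperparameters come from the YAML files under `configs/`; the config fields below mirror the attributes read in `clap/models.py` and the tensor shapes follow the defaults in `clap/datasets.py`, but the specific values here are illustrative assumptions rather than a trained configuration:

```python
from types import SimpleNamespace

import jax
import jax.numpy as jnp

from clap import CLAP

# Hypothetical configs exposing the attributes CLAP.setup() reads.
text_config = SimpleNamespace(
    kind="transformer", projection_dim=512, depth=6, heads=8,
    dim=512, rotary_qk=True, vocab=256,  # captions are stored as utf-8 bytes
)
audio_config = SimpleNamespace(
    kind="vit", projection_dim=512, depth=6, heads=8, dim=512,
    patch_shape=(2, 80),  # (time, mel) patches over the spectrogram
    rotary_qk=True,
)

model = CLAP(text_config=text_config, audio_config=audio_config)

text = jnp.zeros((4, 256), dtype=jnp.int32)  # padded byte-level token ids
audio = jnp.zeros((4, 2048, 80, 1))          # padded mel spectrograms

params = model.init(jax.random.PRNGKey(0), text, audio)
loss = model.apply(params, text, audio)  # symmetric contrastive (CLIP-style) loss
sim = model.apply(params, text, audio, return_loss=False)  # text-audio logits
```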
35 | -------------------------------------------------------------------------------- /clap/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import CLAP 2 | from .datasets import PairTextSpectrogramTFRecords 3 | from .trunks import Transformer, ViT, TNT, CaiT, MLPMixer 4 | -------------------------------------------------------------------------------- /clap/datasets.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import torch 3 | import tensorflow as tf 4 | from pathlib import Path 5 | from tqdm import tqdm 6 | from itertools import cycle, islice, chain 7 | from einops import rearrange, repeat 8 | 9 | import torch.nn.functional as F 10 | 11 | 12 | class PairTextSpectrogramTFRecords(object): 13 | def __init__( 14 | self, 15 | local_or_gcs_path, 16 | batch_size, 17 | prefetch_size=0, 18 | mel_bins=80, 19 | max_audio_len=2048, 20 | max_text_len=256, 21 | ): 22 | self.mel_bins = mel_bins 23 | self.max_audio_len = max_audio_len 24 | self.max_text_len = max_text_len 25 | self.path = local_or_gcs_path 26 | self.batch_size = batch_size 27 | self.prefetch_size = prefetch_size 28 | self.mel_bins = mel_bins 29 | self.max_audio_len = max_audio_len 30 | self.max_text_len = max_text_len 31 | 32 | def files(self): 33 | return tf.io.gfile.glob(self.path + "/*.tfrecord") 34 | 35 | def __iter__(self): 36 | files = tf.data.TFRecordDataset.list_files( 37 | self.path + "/*.tfrecord", shuffle=False 38 | ) 39 | dataset = tf.data.TFRecordDataset(files) 40 | dataset = dataset.map(self.deserialize_tf_record) 41 | dataset = dataset.padded_batch( 42 | self.batch_size, 43 | padded_shapes={ 44 | "audio": (self.max_audio_len, self.mel_bins), 45 | "text": (self.max_text_len), 46 | }, 47 | ) 48 | dataset = dataset.map(self.unsqueeze_trailing) 49 | dataset = dataset.prefetch(self.prefetch_size) 50 | dataset = dataset.as_numpy_iterator() 51 | 52 | return dataset 53 | 54 | def deserialize_tf_record(self, record): 55 | tfrecord_format = { 56 | "audio": tf.io.FixedLenSequenceFeature( 57 | (self.mel_bins,), dtype=tf.float32, allow_missing=True 58 | ), 59 | "text": tf.io.FixedLenSequenceFeature( 60 | [], dtype=tf.int64, allow_missing=True 61 | ), 62 | } 63 | 64 | features_tensor = tf.io.parse_single_example(record, tfrecord_format) 65 | return features_tensor 66 | 67 | def unsqueeze_trailing(self, record): 68 | record = { 69 | "audio": repeat(record["audio"], "... -> ... ()"), 70 | "text": record["text"], 71 | } 72 | return record 73 | 74 | @staticmethod 75 | def write(spectrograms, captions, fname="data.tfrecord"): 76 | tfrecord_writer = tf.io.TFRecordWriter(fname) 77 | for (spectrogram, caption) in tqdm(zip(spectrograms, captions)): 78 | example = tf.train.Example( 79 | features=tf.train.Features( 80 | feature={ 81 | "audio": tf.train.Feature( 82 | float_list=tf.train.FloatList(value=spectrogram.flatten()) 83 | ), 84 | "text": tf.train.Feature( 85 | int64_list=tf.train.Int64List( 86 | value=[*caption.encode("utf-8")] 87 | ) 88 | ), 89 | } 90 | ) 91 | ) 92 | tfrecord_writer.write(example.SerializeToString()) 93 | 94 | tfrecord_writer.close() 95 | 96 | 97 | def roundrobin(*iterables): 98 | "roundrobin('ABC', 'D', 'EF') --> A D E B F C" 99 | # Recipe credited to George Sakkis 100 | num_active = len(iterables) 101 | nexts = cycle(iter(it).__next__ for it in iterables) 102 | while num_active: 103 | try: 104 | for next in nexts: 105 | yield next() 106 | except StopIteration: 107 | # Remove the iterator we just exhausted from the cycle.
108 | num_active -= 1 109 | nexts = cycle(islice(nexts, num_active)) 110 | -------------------------------------------------------------------------------- /clap/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .feedforwards import FFBlock, LeFFBlock 2 | from .stems import Image2TokenBlock, PatchEmbedBlock 3 | from .squeeze_excite import SqueezeExciteBlock 4 | from .position_embed import ( 5 | AddAbsPosEmbed, 6 | RotaryPositionalEmbedding, 7 | FixedPositionalEmbedding, 8 | ) 9 | from .attentions import ( 10 | AttentionBlock, 11 | SelfAttentionBlock, 12 | CvTAttentionBlock, 13 | CvTSelfAttentionBlock, 14 | ) 15 | from .normalizations import LayerScaleBlock 16 | from .regularization import StochasticDepthBlock 17 | -------------------------------------------------------------------------------- /clap/layers/attentions/__init__.py: -------------------------------------------------------------------------------- 1 | from .talking_heads import TalkingHeadsBlock 2 | from .attention import AttentionBlock, SelfAttentionBlock 3 | from .cvt_attention import CvTAttentionBlock, CvTSelfAttentionBlock 4 | -------------------------------------------------------------------------------- /clap/layers/attentions/attention.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Optional 3 | 4 | from flax import linen as nn 5 | from jax import numpy as jnp 6 | 7 | from . import TalkingHeadsBlock 8 | from .. import RotaryPositionalEmbedding 9 | 10 | 11 | class AttentionBlock(nn.Module): 12 | num_heads: int 13 | head_ch: Optional[int] = None 14 | out_ch: Optional[int] = None 15 | talking_heads: bool = False 16 | rotary_qk: bool = False 17 | rotary_v: bool = False 18 | attn_dropout_rate: float = 0.0 19 | out_dropout_rate: float = 0.0 20 | use_bias: bool = False 21 | dtype: jnp.dtype = jnp.float32 22 | 23 | @nn.compact 24 | def __call__(self, inputs_q, inputs_kv, is_training: bool): 25 | assert inputs_q.ndim == inputs_kv.ndim == 3 26 | 27 | in_ch = inputs_q.shape[-1] 28 | assert in_ch % self.num_heads == 0 29 | head_ch = self.head_ch or int(in_ch / self.num_heads) 30 | out_ch = self.out_ch or in_ch 31 | 32 | dense = partial( 33 | nn.DenseGeneral, 34 | axis=-1, 35 | features=(self.num_heads, head_ch), 36 | use_bias=self.use_bias, 37 | dtype=self.dtype, 38 | ) 39 | 40 | query = dense(name="queries")(inputs_q) 41 | key = dense(name="keys")(inputs_kv) 42 | value = dense(name="values")(inputs_kv) 43 | 44 | if self.rotary_qk: 45 | query = RotaryPositionalEmbedding()(query) 46 | key = RotaryPositionalEmbedding()(key) 47 | if self.rotary_v: 48 | value = RotaryPositionalEmbedding()(value) 49 | 50 | query = query / jnp.sqrt(head_ch) 51 | 52 | attn_weights = jnp.einsum("... q h d, ... k h d -> ... h q k", query, key) 53 | 54 | if self.talking_heads: 55 | attn_weights = TalkingHeadsBlock(num_heads=self.num_heads)(attn_weights) 56 | 57 | attn_weights = nn.softmax(attn_weights) 58 | 59 | if self.talking_heads: 60 | attn_weights = TalkingHeadsBlock(num_heads=self.num_heads)(attn_weights) 61 | 62 | attn_weights = nn.Dropout(rate=self.attn_dropout_rate)( 63 | attn_weights, deterministic=not is_training 64 | ) 65 | 66 | attn_scores = jnp.einsum( 67 | "... h q k, ... k h d -> ... 
q h d", attn_weights, value 68 | ) 69 | 70 | output = nn.DenseGeneral( 71 | features=out_ch, axis=(-2, -1), use_bias=self.use_bias, dtype=self.dtype 72 | )(attn_scores) 73 | 74 | output = nn.Dropout(rate=self.out_dropout_rate)( 75 | output, deterministic=not is_training 76 | ) 77 | return output 78 | 79 | 80 | class SelfAttentionBlock(AttentionBlock): 81 | @nn.compact 82 | def __call__(self, inputs, is_training: bool): 83 | return super().__call__(inputs, inputs, is_training=is_training) 84 | -------------------------------------------------------------------------------- /clap/layers/attentions/cvt_attention.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Optional, Tuple 3 | 4 | from flax import linen as nn 5 | from jax import numpy as jnp 6 | 7 | from einops import rearrange 8 | 9 | from . import TalkingHeadsBlock 10 | 11 | 12 | class ConvProjectionBlock(nn.Module): 13 | out_ch: int 14 | kernel_size: int = 3 15 | strides: int = 1 16 | use_bias: bool = True 17 | bn_momentum: float = 0.9 18 | bn_epsilon: float = 1e-5 19 | dtype: jnp.dtype = jnp.float32 20 | 21 | @nn.compact 22 | def __call__(self, inputs, is_training: bool): 23 | in_ch = inputs.shape[-1] 24 | 25 | conv = partial(nn.Conv, dtype=self.dtype) 26 | 27 | x = conv( 28 | features=in_ch, 29 | kernel_size=(self.kernel_size, self.kernel_size), 30 | strides=(self.strides, self.strides), 31 | padding="SAME", 32 | feature_group_count=in_ch, 33 | use_bias=False, 34 | )(inputs) 35 | x = nn.BatchNorm( 36 | use_running_average=not is_training, 37 | momentum=self.bn_momentum, 38 | epsilon=self.bn_epsilon, 39 | dtype=self.dtype, 40 | )(x) 41 | output = conv(features=self.out_ch, kernel_size=(1, 1), use_bias=self.use_bias)( 42 | x 43 | ) 44 | return output 45 | 46 | 47 | class CvTAttentionBlock(nn.Module): 48 | num_heads: int 49 | head_ch: Optional[int] = None 50 | out_ch: Optional[int] = None 51 | talking_heads: bool = False 52 | attn_dropout_rate: float = 0.0 53 | out_dropout_rate: float = 0.0 54 | kernel_size: int = 3 55 | strides: Tuple[int] = (1, 2, 2) 56 | use_bias: bool = False 57 | bn_momentum: float = 0.9 58 | bn_epsilon: float = 1e-5 59 | dtype: jnp.dtype = jnp.float32 60 | 61 | @nn.compact 62 | def __call__(self, inputs_q, inputs_kv, is_training: bool): 63 | assert inputs_q.ndim == 4 64 | assert inputs_kv.ndim == 4 65 | assert len(self.strides) == 3 66 | q_strides, k_strides, v_strides = self.strides 67 | 68 | in_ch = inputs_q.shape[-1] 69 | assert in_ch % self.num_heads == 0 70 | head_ch = self.head_ch or int(in_ch / self.num_heads) 71 | out_ch = self.out_ch or in_ch 72 | 73 | conv_proj = partial( 74 | ConvProjectionBlock, 75 | out_ch=self.num_heads * head_ch, 76 | kernel_size=self.kernel_size, 77 | use_bias=self.use_bias, 78 | bn_momentum=self.bn_momentum, 79 | bn_epsilon=self.bn_epsilon, 80 | dtype=self.dtype, 81 | ) 82 | 83 | query = conv_proj(strides=q_strides)(inputs_q, is_training=is_training) 84 | key = conv_proj(strides=k_strides)(inputs_kv, is_training=is_training) 85 | value = conv_proj(strides=v_strides)(inputs_kv, is_training=is_training) 86 | 87 | query = rearrange(query, "b H W (h d) -> b (H W) h d", h=self.num_heads) 88 | key = rearrange(key, "b H W (h d) -> b (H W) h d", h=self.num_heads) 89 | value = rearrange(value, "b H W (h d) -> b (H W) h d", h=self.num_heads) 90 | 91 | query = query / jnp.sqrt(head_ch) 92 | 93 | attn_weights = jnp.einsum("... q h d, ... k h d -> ... 
h q k", query, key) 94 | 95 | if self.talking_heads: 96 | attn_weights = TalkingHeadsBlock(num_heads=self.num_heads)(attn_weights) 97 | 98 | attn_weights = nn.softmax(attn_weights) 99 | 100 | if self.talking_heads: 101 | attn_weights = TalkingHeadsBlock(num_heads=self.num_heads)(attn_weights) 102 | 103 | attn_weights = nn.Dropout(rate=self.attn_dropout_rate)( 104 | attn_weights, deterministic=not is_training 105 | ) 106 | 107 | attn_scores = jnp.einsum( 108 | "... h q k, ... k h d -> ... q h d", attn_weights, value 109 | ) 110 | 111 | output = nn.DenseGeneral( 112 | features=out_ch, axis=(-2, -1), use_bias=self.use_bias, dtype=self.dtype 113 | )(attn_scores) 114 | 115 | output = nn.Dropout(rate=self.out_dropout_rate)( 116 | output, deterministic=not is_training 117 | ) 118 | return output 119 | 120 | 121 | class CvTSelfAttentionBlock(CvTAttentionBlock): 122 | @nn.compact 123 | def __call__(self, inputs, is_training: bool): 124 | return super().__call__(inputs, inputs, is_training) 125 | -------------------------------------------------------------------------------- /clap/layers/attentions/talking_heads.py: -------------------------------------------------------------------------------- 1 | from flax import linen as nn 2 | from jax import numpy as jnp 3 | 4 | 5 | class TalkingHeadsBlock(nn.Module): 6 | num_heads: int 7 | 8 | @nn.compact 9 | def __call__(self, inputs): 10 | transform_shape = (self.num_heads, self.num_heads) 11 | transform = self.param( 12 | "talking_heads_transform", nn.initializers.orthogonal(), transform_shape 13 | ) 14 | output = jnp.einsum("h i, b h ... -> b i ...", transform, inputs) 15 | return output 16 | -------------------------------------------------------------------------------- /clap/layers/feedforwards/__init__.py: -------------------------------------------------------------------------------- 1 | from .ff import FFBlock 2 | from .leff import LeFFBlock 3 | -------------------------------------------------------------------------------- /clap/layers/feedforwards/ff.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Callable 3 | 4 | from flax import linen as nn 5 | from jax import numpy as jnp 6 | 7 | 8 | class FFBlock(nn.Module): 9 | expand_ratio: float = None 10 | hidden_ch: int = None 11 | dropout_rate: float = 0.0 12 | activation_fn: Callable = nn.activation.gelu 13 | dtype: jnp.dtype = jnp.float32 14 | 15 | @nn.compact 16 | def __call__(self, inputs, is_training: bool): 17 | in_ch = inputs.shape[-1] 18 | if self.expand_ratio is None: 19 | if self.hidden_ch is None: 20 | raise ValueError("Must provide one of expand_ratio or hidden_ch") 21 | hidden_ch = self.hidden_ch 22 | else: 23 | hidden_ch = max(1, int(self.expand_ratio * in_ch)) 24 | 25 | dense = partial(nn.Dense, use_bias=True, dtype=self.dtype) 26 | 27 | x = dense(features=hidden_ch)(inputs) 28 | x = self.activation_fn(x) 29 | x = nn.Dropout(rate=self.dropout_rate, deterministic=not is_training)(x) 30 | x = dense(features=in_ch)(x) 31 | output = nn.Dropout(rate=self.dropout_rate, deterministic=not is_training)(x) 32 | return output 33 | -------------------------------------------------------------------------------- /clap/layers/feedforwards/leff.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Callable 3 | 4 | from jax import numpy as jnp 5 | from flax import linen as nn 6 | from einops import rearrange 7 | 8 | 9 | 
class LeFFBlock(nn.Module): 10 | expand_ratio: float = None 11 | hidden_ch: int = None 12 | kernel_size: int = 5 13 | activation_fn: Callable = nn.activation.gelu 14 | bn_momentum: float = 0.9 15 | bn_epsilon: float = 1e-5 16 | dtype: jnp.dtype = jnp.float32 17 | 18 | @nn.compact 19 | def __call__(self, inputs, is_training: bool): 20 | 21 | cls, tokens = inputs[:, 0], inputs[:, 1:] 22 | _, l, in_ch = tokens.shape 23 | if self.expand_ratio is None: 24 | if self.hidden_ch is None: 25 | raise ValueError("Must provide one of expand_ratio or hidden_ch") 26 | hidden_ch = self.hidden_ch 27 | else: 28 | hidden_ch = max(1, int(self.expand_ratio * in_ch)) 29 | 30 | dense = partial(nn.Dense, use_bias=True, dtype=self.dtype) 31 | 32 | batch_norm = partial( 33 | nn.BatchNorm, 34 | use_running_average=not is_training, 35 | momentum=self.bn_momentum, 36 | epsilon=self.bn_epsilon, 37 | dtype=self.dtype, 38 | ) 39 | 40 | x = dense(features=hidden_ch)(tokens) 41 | x = batch_norm()(x) 42 | x = self.activation_fn(x) 43 | 44 | spatial_ch = int(jnp.sqrt(l)) 45 | x = rearrange(x, "b (h w) c -> b h w c", h=spatial_ch, w=spatial_ch) 46 | 47 | x = nn.Conv( 48 | features=hidden_ch, 49 | kernel_size=(self.kernel_size, self.kernel_size), 50 | padding="SAME", 51 | dtype=self.dtype, 52 | )(x) 53 | x = batch_norm()(x) 54 | x = self.activation_fn(x) 55 | 56 | x = rearrange(x, "b h w c -> b (h w) c") 57 | 58 | x = dense(features=in_ch)(x) 59 | x = batch_norm()(x) 60 | x = self.activation_fn(x) 61 | 62 | cls_token = jnp.expand_dims(cls, axis=1) 63 | output = jnp.concatenate([cls_token, x], axis=1) 64 | return output 65 | -------------------------------------------------------------------------------- /clap/layers/normalizations/__init__.py: -------------------------------------------------------------------------------- 1 | from .layerscale import LayerScaleBlock 2 | -------------------------------------------------------------------------------- /clap/layers/normalizations/layerscale.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | 4 | 5 | def full(eps: float, dtype: jnp.dtype = jnp.float32): 6 | def init(key, shape, dtype=dtype): 7 | return jnp.full(shape, eps, dtype=dtype) 8 | 9 | return init 10 | 11 | 12 | class LayerScaleBlock(nn.Module): 13 | eps: float 14 | dtype: jnp.dtype = jnp.float32 15 | 16 | @nn.compact 17 | def __call__(self, inputs, *unused_args, **unused_kwargs): 18 | in_ch = inputs.shape[-1] 19 | scale = self.param("layerscale", full(self.eps, dtype=self.dtype), (in_ch,)) 20 | scale = jnp.asarray(scale, self.dtype) 21 | return inputs * scale 22 | -------------------------------------------------------------------------------- /clap/layers/position_embed.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from flax import linen as nn 4 | import jax.numpy as jnp 5 | from einops import rearrange, repeat 6 | 7 | 8 | def rotate_every_two(x): 9 | x1 = x[:, :, :, ::2] 10 | x2 = x[:, :, :, 1::2] 11 | 12 | x = jnp.stack((-x2, x1), axis=-1) 13 | 14 | return rearrange(x, "... d j -> ... 
(d j)") 15 | 16 | 17 | def apply_rotary_pos_emb(x, sincos): 18 | sin, cos = map(lambda t: repeat(t, "n d -> n (d j)", j=2)[None, :, None, :], sincos) 19 | return (x * cos) + (rotate_every_two(x) * sin) 20 | 21 | 22 | class FixedPositionalEmbedding(nn.Module): 23 | @nn.compact 24 | def __call__(self, inputs, seq_dim=1): 25 | dim = inputs.shape[-1] 26 | # Inputs shaped as [batch, sequence, heads, dimensions] 27 | intervals = jnp.arange(start=0, stop=dim, step=2, dtype=self.dtype) 28 | inv_freq = 1.0 / (10e4 ** intervals / dim) 29 | t = jnp.arange(inputs.shape[seq_dim], dtype=self.dtype) 30 | 31 | freqs = jnp.einsum("i , j -> i j", t, inv_freq) 32 | return jnp.sin(freqs), jnp.cos(freqs) 33 | 34 | 35 | class RotaryPositionalEmbedding(FixedPositionalEmbedding): 36 | dtype: jnp.dtype = jnp.float32 37 | 38 | @nn.compact 39 | def __call__(self, inputs, seq_dim=1): 40 | sin, cos = super(RotaryPositionalEmbedding, self).__call__(inputs, seq_dim) 41 | emb = apply_rotary_pos_emb(inputs, (sin, cos)) 42 | return emb 43 | 44 | 45 | class AddAbsPosEmbed(nn.Module): 46 | embed_init: Callable = nn.initializers.normal(stddev=0.02) 47 | 48 | @nn.compact 49 | def __call__(self, inputs): 50 | assert inputs.ndim == 3 51 | pos_emb_shape = (1, inputs.shape[1], inputs.shape[2]) 52 | pos_emb = self.param("pos_embed", self.embed_init, pos_emb_shape) 53 | output = inputs + pos_emb 54 | return output 55 | -------------------------------------------------------------------------------- /clap/layers/regularization/__init__.py: -------------------------------------------------------------------------------- 1 | from .stochastic_depth import StochasticDepthBlock 2 | -------------------------------------------------------------------------------- /clap/layers/regularization/stochastic_depth.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax.random as random 3 | import flax.linen as nn 4 | 5 | 6 | class StochasticDepthBlock(nn.Module): 7 | drop_rate: float 8 | scale_by_keep: bool = True 9 | 10 | @nn.compact 11 | def __call__(self, inputs, is_training: bool): 12 | 13 | if not is_training or self.drop_rate == 0.0: 14 | return inputs 15 | 16 | keep_prob = 1.0 - self.drop_rate 17 | rng = self.make_rng("stochastic_depth") 18 | 19 | b = inputs.shape[0] 20 | shape = [b,] + ( 21 | [ 22 | 1, 23 | ] 24 | * (inputs.ndim - 1) 25 | ) 26 | random_tensor = random.uniform(rng, shape, dtype=inputs.dtype) 27 | binary_tensor = jnp.floor(keep_prob + random_tensor) 28 | 29 | if self.scale_by_keep: 30 | x = inputs / keep_prob 31 | 32 | output = x * binary_tensor 33 | return output 34 | -------------------------------------------------------------------------------- /clap/layers/squeeze_excite.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Any 2 | 3 | from flax import linen as nn 4 | 5 | from jax.lax import Precision 6 | from jax import numpy as jnp 7 | 8 | from functools import partial 9 | 10 | Dtype = Any 11 | 12 | 13 | class SqueezeExciteBlock(nn.Module): 14 | 15 | se_ratio: float = None 16 | hidden_ch: int = None 17 | activation_fn: Callable = nn.activation.gelu 18 | dtype: jnp.dtype = jnp.float32 19 | 20 | @nn.compact 21 | def __call__(self, inputs): 22 | in_ch = inputs.shape[-1] 23 | if self.se_ratio is None: 24 | if self.hidden_ch is None: 25 | raise ValueError("Must provide one of se_ratio or hidden_ch") 26 | hidden_ch = self.hidden_ch 27 | else: 28 | hidden_ch = max(1, int(in_ch * 
29 | 30 | dense = partial(nn.Dense, use_bias=True, dtype=self.dtype) 31 | 32 | # Squeeze: global average pool over the spatial dims (keepdims for broadcasting). 33 | x = jnp.mean(inputs, axis=(1, 2), dtype=self.dtype, keepdims=True) 34 | x = dense(features=hidden_ch)(x) 35 | x = self.activation_fn(x) 36 | x = dense(features=in_ch)(x) 37 | # Excite: per-channel gates in [0, 1] rescale the input features. 38 | output = nn.sigmoid(x) * inputs 39 | return output 40 | -------------------------------------------------------------------------------- /clap/layers/stems/__init__.py: -------------------------------------------------------------------------------- 1 | from .image_to_token import Image2TokenBlock 2 | from .patch_embed import PatchEmbedBlock 3 | -------------------------------------------------------------------------------- /clap/layers/stems/image_to_token.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from flax import linen as nn 4 | from jax import numpy as jnp 5 | from einops import rearrange 6 | 7 | 8 | class Image2TokenBlock(nn.Module): 9 | patch_shape: Tuple[int, int] 10 | num_ch: int 11 | conv_kernel_size: int 12 | conv_stride: int 13 | pool_window_size: int 14 | pool_stride: int 15 | embed_dim: int 16 | use_bias: bool = False 17 | bn_momentum: float = 0.9 18 | bn_epsilon: float = 1e-5 19 | dtype: jnp.dtype = jnp.float32 20 | 21 | @nn.compact 22 | def __call__(self, inputs, is_training: bool): 23 | x = nn.Conv( 24 | features=self.num_ch, 25 | use_bias=self.use_bias, 26 | kernel_size=(self.conv_kernel_size, self.conv_kernel_size), 27 | strides=(self.conv_stride, self.conv_stride), 28 | padding=[(self.patch_shape[0],) * 2, (self.patch_shape[1],) * 2], 29 | )(inputs) 30 | x = nn.BatchNorm( 31 | use_running_average=not is_training, 32 | momentum=self.bn_momentum, 33 | epsilon=self.bn_epsilon, 34 | dtype=self.dtype, 35 | )(x) 36 | x = nn.max_pool( 37 | inputs=x, 38 | window_shape=(self.pool_window_size,) * 2, 39 | strides=(self.pool_stride,) * 2, 40 | ) 41 | x = rearrange( 42 | x, 43 | "b (h ph) (w pw) c -> b (h w) (ph pw c)", 44 | ph=self.patch_shape[0], 45 | pw=self.patch_shape[1], 46 | ) 47 | 48 | output = nn.Dense( 49 | features=self.embed_dim, use_bias=self.use_bias, dtype=self.dtype 50 | )(x) 51 | return output 52 | -------------------------------------------------------------------------------- /clap/layers/stems/patch_embed.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from flax import linen as nn 4 | from jax import numpy as jnp 5 | from einops import rearrange 6 | 7 | 8 | class PatchEmbedBlock(nn.Module): 9 | 10 | patch_shape: Tuple[int] 11 | embed_dim: int 12 | use_bias: bool = False 13 | dtype: jnp.dtype = jnp.float32 14 | 15 | @nn.compact 16 | def __call__(self, inputs, *unused_args, **unused_kwargs): 17 | ph, pw = self.patch_shape 18 | 19 | x = rearrange(inputs, "b (h ph) (w pw) c -> b (h w) (ph pw c)", ph=ph, pw=pw) 20 | output = nn.Dense( 21 | features=self.embed_dim, use_bias=self.use_bias, dtype=self.dtype 22 | )(x) 23 | return output 24 | -------------------------------------------------------------------------------- /clap/models.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from typing import Any, Callable, Sequence, Optional 3 | from jax import lax, random, numpy as jnp, vmap, jit 4 | from jax.ops import index, index_update 5 | from .trunks import Transformer, ViT, TNT, CaiT, MLPMixer 6 | 7 | # einsum and einops 8 | 9 | from jax.numpy import einsum 10 | from einops import rearrange,
repeat 11 | 12 | # flax 13 | 14 | import flax 15 | from flax.core import freeze, unfreeze 16 | from flax import linen as nn 17 | 18 | # constants 19 | 20 | LARGE_NEG_VALUE = -1e10 21 | 22 | # config 23 | 24 | from jax.config import config 25 | 26 | config.enable_omnistaging() # Linen requires enabling omnistaging 27 | 28 | # helpers 29 | 30 | 31 | def cross_entropy(logits, targets, axis=-1): 32 | logprobs = nn.log_softmax(logits, axis=axis) 33 | nll = jnp.take_along_axis(logprobs, jnp.expand_dims(targets, axis=axis), axis=axis) 34 | ce = -jnp.mean(nll) 35 | return ce 36 | 37 | 38 | # main class 39 | 40 | 41 | class CLAP(nn.Module): 42 | text_config: Any 43 | 44 | audio_config: Any 45 | 46 | temp_init: Callable = nn.initializers.zeros 47 | 48 | def setup(self): 49 | if self.text_config.kind == "transformer": 50 | self.text_encoder = Transformer( 51 | output_dim=self.text_config.projection_dim, 52 | num_layers=self.text_config.depth, 53 | num_heads=self.text_config.heads, 54 | embed_dim=self.text_config.dim, 55 | rotary_qk=self.text_config.rotary_qk, 56 | dtype=jnp.float32, 57 | ) 58 | else: 59 | raise NotImplementedError( 60 | "Only plain transformer encoders are currently supported for the text trunk." 61 | ) 62 | 63 | if self.audio_config.kind == "vit": 64 | self.audio_encoder = ViT( 65 | output_dim=self.audio_config.projection_dim, 66 | num_layers=self.audio_config.depth, 67 | num_heads=self.audio_config.heads, 68 | embed_dim=self.audio_config.dim, 69 | patch_shape=tuple(self.audio_config.patch_shape), 70 | rotary_qk=self.audio_config.rotary_qk, 71 | ) 72 | elif self.audio_config.kind == "tnt": 73 | self.audio_encoder = TNT( 74 | output_dim=self.audio_config.projection_dim, 75 | num_layers=self.audio_config.depth, 76 | inner_num_heads=self.audio_config.inner.heads, 77 | outer_num_heads=self.audio_config.outer.heads, 78 | inner_embed_dim=self.audio_config.inner.dim, 79 | outer_embed_dim=self.audio_config.outer.dim, 80 | patch_shape=tuple(self.audio_config.outer.patch_shape), 81 | transformed_patch_shape=tuple(self.audio_config.inner.patch_shape), 82 | rotary_qk=self.audio_config.rotary_qk, 83 | ) 84 | elif self.audio_config.kind == "cait": 85 | self.audio_encoder = CaiT( 86 | output_dim=self.audio_config.projection_dim, 87 | num_layers=self.audio_config.depth, 88 | num_layers_token_only=self.audio_config.token_only_depth, 89 | num_heads=self.audio_config.heads, 90 | embed_dim=self.audio_config.dim, 91 | patch_shape=self.audio_config.patch_shape, 92 | stoch_depth_rate=self.audio_config.stochastic_depth_rate, 93 | layerscale_eps=self.audio_config.layerscale_eps, 94 | rotary_qk=self.audio_config.rotary_qk, 95 | ) 96 | elif self.audio_config.kind == "mixer": 97 | self.audio_encoder = MLPMixer( 98 | output_dim=self.audio_config.projection_dim, 99 | num_layers=self.audio_config.depth, 100 | embed_dim=self.audio_config.dim, 101 | patch_shape=self.audio_config.patch_shape, 102 | ) 103 | else: 104 | raise NotImplementedError( 105 | "Only ViT, TNT, CaiT, and MLPMixer are supported audio trunks." 
106 | ) 107 | 108 | self.text_tokenizer = nn.Embed( 109 | num_embeddings=self.text_config.vocab, features=self.text_config.dim 110 | ) 111 | 112 | self.temp = self.param("temperature", self.temp_init, tuple()) 113 | 114 | def encode_text(self, text, is_training): 115 | enc_text = self.text_encoder(text, is_training=is_training) 116 | return enc_text 117 | 118 | def encode_audio(self, audio, is_training): 119 | enc_audio = self.audio_encoder(audio, is_training=is_training) 120 | return enc_audio 121 | 122 | def __call__(self, text, audio, return_loss=True, is_training=False): 123 | b = text.shape[0] 124 | 125 | to_text_tokens = self.text_tokenizer 126 | 127 | text = to_text_tokens(text) 128 | 129 | enc_text = self.encode_text(text, is_training) 130 | enc_audio = self.encode_audio(audio, is_training) 131 | 132 | enc_text = enc_text / jnp.linalg.norm(enc_text, axis=-1, keepdims=True) 133 | enc_audio = enc_audio / jnp.linalg.norm(enc_audio, axis=-1, keepdims=True) 134 | 135 | sim = einsum("i d, j d -> i j", enc_text, enc_audio) * jnp.exp(self.temp) 136 | 137 | if not return_loss: 138 | return sim 139 | 140 | labels = jnp.arange(b) 141 | loss = ( 142 | cross_entropy(sim, labels, axis=0) + cross_entropy(sim, labels, axis=1) 143 | ) / 2 144 | return loss 145 | -------------------------------------------------------------------------------- /clap/trunks/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import Transformer 2 | from .vit import ViT 3 | from .cait import CaiT 4 | from .tnt import TNT 5 | from .mlp_mixer import MLPMixer 6 | -------------------------------------------------------------------------------- /clap/trunks/cait.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple 2 | 3 | import jax.numpy as jnp 4 | import flax.linen as nn 5 | 6 | from ..layers import PatchEmbedBlock, AddAbsPosEmbed, AttentionBlock, SelfAttentionBlock 7 | from ..layers import LayerScaleBlock, StochasticDepthBlock, FFBlock 8 | 9 | 10 | class ClassSelfAttentionBlock(AttentionBlock): 11 | @nn.compact 12 | def __call__(self, inputs, is_training: bool): 13 | inputs_q = jnp.expand_dims(inputs[:, 0, :], axis=1) 14 | return super().__call__(inputs_q, inputs, is_training=is_training) 15 | 16 | 17 | class EncoderBlock(nn.Module): 18 | num_heads: int 19 | stoch_depth_rate: float 20 | layerscale_eps: float 21 | expand_ratio: float = 4 22 | attn_dropout_rate: float = 0.0 23 | dropout_rate: float = 0.0 24 | activation_fn: Callable = nn.activation.gelu 25 | rotary_qk: bool = False 26 | rotary_v: bool = False 27 | dtype: jnp.dtype = jnp.float32 28 | 29 | @nn.compact 30 | def __call__(self, inputs, is_training: bool): 31 | x = nn.LayerNorm(dtype=self.dtype)(inputs) 32 | x = SelfAttentionBlock( 33 | num_heads=self.num_heads, 34 | talking_heads=True, 35 | attn_dropout_rate=self.attn_dropout_rate, 36 | out_dropout_rate=self.dropout_rate, 37 | rotary_qk=self.rotary_qk, 38 | rotary_v=self.rotary_v, 39 | dtype=self.dtype, 40 | )(x, is_training=is_training) 41 | x = LayerScaleBlock(eps=self.layerscale_eps, dtype=self.dtype)( 42 | x, is_training=is_training 43 | ) 44 | x = StochasticDepthBlock(drop_rate=self.stoch_depth_rate)( 45 | x, is_training=is_training 46 | ) 47 | x = x + inputs 48 | 49 | y = nn.LayerNorm(dtype=self.dtype)(x) 50 | y = FFBlock( 51 | expand_ratio=self.expand_ratio, 52 | dropout_rate=self.dropout_rate, 53 | activation_fn=self.activation_fn, 54 | dtype=self.dtype, 55 | )(y, 
is_training=is_training) 56 | y = LayerScaleBlock(eps=self.layerscale_eps, dtype=self.dtype)( 57 | y, is_training=is_training 58 | ) 59 | y = StochasticDepthBlock(drop_rate=self.stoch_depth_rate)( 60 | y, is_training=is_training 61 | ) 62 | 63 | output = x + y 64 | return output 65 | 66 | 67 | class Encoder(nn.Module): 68 | num_layers: int 69 | num_heads: int 70 | stoch_depth_rate: float 71 | layerscale_eps: float 72 | expand_ratio: float = 4 73 | attn_dropout_rate: float = 0.0 74 | dropout_rate: float = 0.0 75 | activation_fn: Callable = nn.activation.gelu 76 | rotary_qk: bool = False 77 | rotary_v: bool = False 78 | dtype: jnp.dtype = jnp.float32 79 | 80 | @nn.compact 81 | def __call__(self, inputs, is_training: bool): 82 | x = AddAbsPosEmbed()(inputs) 83 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not is_training) 84 | 85 | for _ in range(self.num_layers): 86 | x = EncoderBlock( 87 | num_heads=self.num_heads, 88 | expand_ratio=self.expand_ratio, 89 | attn_dropout_rate=self.attn_dropout_rate, 90 | dropout_rate=self.dropout_rate, 91 | stoch_depth_rate=self.stoch_depth_rate, 92 | layerscale_eps=self.layerscale_eps, 93 | activation_fn=self.activation_fn, 94 | rotary_qk=self.rotary_qk, 95 | rotary_v=self.rotary_v, 96 | dtype=self.dtype, 97 | )(x, is_training=is_training) 98 | 99 | output = x 100 | return output 101 | 102 | 103 | class CAEncoderBlock(nn.Module): 104 | num_heads: int 105 | stoch_depth_rate: float 106 | layerscale_eps: float 107 | expand_ratio: float = 4 108 | attn_dropout_rate: float = 0.0 109 | dropout_rate: float = 0.0 110 | activation_fn: Callable = nn.activation.gelu 111 | rotary_qk: bool = False 112 | rotary_v: bool = False 113 | dtype: jnp.dtype = jnp.float32 114 | 115 | @nn.compact 116 | def __call__(self, inputs, cls_token, is_training: bool): 117 | x = jnp.concatenate([cls_token, inputs], axis=1) 118 | x = nn.LayerNorm(dtype=self.dtype)(x) 119 | x = ClassSelfAttentionBlock( 120 | num_heads=self.num_heads, 121 | attn_dropout_rate=self.attn_dropout_rate, 122 | out_dropout_rate=self.dropout_rate, 123 | dtype=self.dtype, 124 | )(x, is_training=is_training) 125 | x = LayerScaleBlock(eps=self.layerscale_eps, dtype=self.dtype)( 126 | x, is_training=is_training 127 | ) 128 | x = StochasticDepthBlock(drop_rate=self.stoch_depth_rate)( 129 | x, is_training=is_training 130 | ) 131 | cls_token = cls_token + x 132 | 133 | y = nn.LayerNorm(dtype=self.dtype)(cls_token) 134 | y = FFBlock( 135 | expand_ratio=self.expand_ratio, 136 | dropout_rate=self.dropout_rate, 137 | activation_fn=self.activation_fn, 138 | dtype=self.dtype, 139 | )(y, is_training=is_training) 140 | y = LayerScaleBlock(eps=self.layerscale_eps, dtype=self.dtype)( 141 | y, is_training=is_training 142 | ) 143 | y = StochasticDepthBlock(drop_rate=self.stoch_depth_rate)( 144 | y, is_training=is_training 145 | ) 146 | 147 | output = cls_token + y 148 | return output 149 | 150 | 151 | class CaiT(nn.Module): 152 | output_dim: int 153 | num_layers: int 154 | num_layers_token_only: int 155 | num_heads: int 156 | embed_dim: int 157 | patch_shape: Tuple[int] 158 | stoch_depth_rate: float 159 | layerscale_eps: float 160 | expand_ratio: float = 4 161 | attn_dropout_rate: float = 0.0 162 | dropout_rate: float = 0.0 163 | activation_fn: Callable = nn.activation.gelu 164 | rotary_qk: bool = False 165 | rotary_v: bool = False 166 | dtype: jnp.dtype = jnp.float32 167 | 168 | @nn.compact 169 | def __call__(self, inputs, is_training: bool): 170 | 171 | x = PatchEmbedBlock( 172 | patch_shape=self.patch_shape, 
embed_dim=self.embed_dim, dtype=self.dtype 173 | )(inputs) 174 | 175 | x = Encoder( 176 | num_layers=self.num_layers, 177 | num_heads=self.num_heads, 178 | expand_ratio=self.expand_ratio, 179 | attn_dropout_rate=self.attn_dropout_rate, 180 | dropout_rate=self.dropout_rate, 181 | stoch_depth_rate=self.stoch_depth_rate, 182 | layerscale_eps=self.layerscale_eps, 183 | rotary_qk=self.rotary_qk, 184 | rotary_v=self.rotary_v, 185 | activation_fn=self.activation_fn, 186 | )(x, is_training=is_training) 187 | 188 | b = x.shape[0] 189 | cls_shape = (1, 1, self.embed_dim) 190 | cls_token = self.param("cls", nn.initializers.zeros, cls_shape) 191 | cls_token = jnp.tile(cls_token, [b, 1, 1]) 192 | 193 | for _ in range(self.num_layers_token_only): 194 | cls_token = CAEncoderBlock( 195 | num_heads=self.num_heads, 196 | expand_ratio=self.expand_ratio, 197 | attn_dropout_rate=self.attn_dropout_rate, 198 | dropout_rate=self.dropout_rate, 199 | stoch_depth_rate=self.stoch_depth_rate, 200 | layerscale_eps=self.layerscale_eps, 201 | activation_fn=self.activation_fn, 202 | rotary_qk=self.rotary_qk, 203 | rotary_v=self.rotary_v, 204 | dtype=self.dtype, 205 | )(x, cls_token, is_training=is_training) 206 | 207 | x = jnp.concatenate([cls_token, x], axis=1) 208 | x = nn.LayerNorm(dtype=self.dtype)(x) 209 | 210 | cls_token = x[:, 0] 211 | output = nn.Dense( 212 | features=self.output_dim, 213 | use_bias=True, 214 | dtype=self.dtype, 215 | kernel_init=nn.initializers.orthogonal(), 216 | )(cls_token) 217 | return output 218 | -------------------------------------------------------------------------------- /clap/trunks/create_trunk.py: -------------------------------------------------------------------------------- 1 | from jax import numpy as jnp 2 | 3 | from . import Transformer, ViT, TNT, CaiT, MLPMixer 4 | 5 | 6 | def create_trunk( 7 | model_name: str, output_dim: int = 1000, dtype: jnp.dtype = jnp.float32 8 | ): 9 | 10 | if model_name == "txt_b": 11 | return Transformer( 12 | output_dim=output_dim, 13 | num_layers=12, 14 | num_heads=12, 15 | embed_dim=768, 16 | rotary_qk=True, 17 | dtype=dtype, 18 | ) 19 | elif model_name == "spec_b": 20 | return TNT( 21 | output_dim=output_dim, 22 | num_layers=12, 23 | inner_num_heads=4, 24 | outer_num_heads=6, 25 | inner_embed_dim=24, 26 | outer_embed_dim=384, 27 | patch_shape=(2, 80), 28 | transformed_patch_shape=(1, 80), 29 | rotary_qk=True, 30 | ) 31 | elif model_name == "spec_c": 32 | return ViT( 33 | output_dim=output_dim, 34 | num_layers=12, 35 | num_heads=12, 36 | embed_dim=384, 37 | patch_shape=(2, 80), 38 | rotary_qk=True, 39 | ) 40 | elif model_name == "vit_b_patch32": 41 | return ViT( 42 | output_dim=output_dim, 43 | num_layers=12, 44 | num_heads=12, 45 | embed_dim=768, 46 | patch_shape=(32, 32), 47 | dtype=dtype, 48 | ) 49 | elif model_name == "vit_b_patch16": 50 | return ViT( 51 | output_dim=output_dim, 52 | num_layers=12, 53 | num_heads=12, 54 | embed_dim=768, 55 | patch_shape=(16, 16), 56 | dtype=dtype, 57 | ) 58 | elif model_name == "vit_l_patch32": 59 | return ViT( 60 | output_dim=output_dim, 61 | num_layers=24, 62 | num_heads=16, 63 | embed_dim=1024, 64 | patch_shape=(32, 32), 65 | dtype=dtype, 66 | ) 67 | elif model_name == "vit_l_patch16": 68 | return ViT( 69 | output_dim=output_dim, 70 | num_layers=24, 71 | num_heads=16, 72 | embed_dim=1024, 73 | patch_shape=(16, 16), 74 | dtype=dtype, 75 | ) 76 | elif model_name == "tnt_s_patch16": 77 | return TNT( 78 | output_dim=output_dim, 79 | num_layers=12, 80 | inner_num_heads=4, 81 | outer_num_heads=6, 82 |
inner_embed_dim=24, 83 | outer_embed_dim=384, 84 | patch_shape=(16, 16), 85 | transformed_patch_shape=(4, 4), 86 | ) 87 | elif model_name == "tnt_b_patch16": 88 | return TNT( 89 | output_dim=output_dim, 90 | num_layers=12, 91 | inner_num_heads=4, 92 | outer_num_heads=10, 93 | inner_embed_dim=40, 94 | outer_embed_dim=640, 95 | patch_shape=(16, 16), 96 | transformed_patch_shape=(4, 4), 97 | ) 98 | elif model_name == "cait_xxs_24": 99 | return CaiT( 100 | output_dim=output_dim, 101 | num_layers=24, 102 | num_layers_token_only=2, 103 | num_heads=4, 104 | embed_dim=192, 105 | patch_shape=(16, 16), 106 | stoch_depth_rate=0.05, 107 | layerscale_eps=1e-5, 108 | ) 109 | elif model_name == "cait_xxs_36": 110 | return CaiT( 111 | output_dim=output_dim, 112 | num_layers=36, 113 | num_layers_token_only=2, 114 | num_heads=4, 115 | embed_dim=192, 116 | patch_shape=(16, 16), 117 | stoch_depth_rate=0.1, 118 | layerscale_eps=1e-6, 119 | ) 120 | elif model_name == "cait_xs_24": 121 | return CaiT( 122 | output_dim=output_dim, 123 | num_layers=24, 124 | num_layers_token_only=2, 125 | num_heads=6, 126 | embed_dim=288, 127 | patch_shape=(16, 16), 128 | stoch_depth_rate=0.05, 129 | layerscale_eps=1e-5, 130 | ) 131 | elif model_name == "cait_xs_36": 132 | return CaiT( 133 | output_dim=output_dim, 134 | num_layers=36, 135 | num_layers_token_only=2, 136 | num_heads=6, 137 | embed_dim=288, 138 | patch_shape=(16, 16), 139 | stoch_depth_rate=0.1, 140 | layerscale_eps=1e-6, 141 | ) 142 | elif model_name == "cait_s_24": 143 | return CaiT( 144 | output_dim=output_dim, 145 | num_layers=24, 146 | num_layers_token_only=2, 147 | num_heads=8, 148 | embed_dim=384, 149 | patch_shape=(16, 16), 150 | stoch_depth_rate=0.1, 151 | layerscale_eps=1e-6, 152 | ) 153 | elif model_name == "cait_s_36": 154 | return CaiT( 155 | output_dim=output_dim, 156 | num_layers=36, 157 | num_layers_token_only=2, 158 | num_heads=8, 159 | embed_dim=384, 160 | patch_shape=(16, 16), 161 | stoch_depth_rate=0.2, 162 | layerscale_eps=1e-6, 163 | ) 164 | elif model_name == "cait_s_48": 165 | return CaiT( 166 | output_dim=output_dim, 167 | num_layers=48, 168 | num_layers_token_only=2, 169 | num_heads=8, 170 | embed_dim=384, 171 | patch_shape=(16, 16), 172 | stoch_depth_rate=0.3, 173 | layerscale_eps=1e-6, 174 | ) 175 | elif model_name == "cait_m_24": 176 | return CaiT( 177 | output_dim=output_dim, 178 | num_layers=24, 179 | num_layers_token_only=2, 180 | num_heads=16, 181 | embed_dim=768, 182 | patch_shape=(16, 16), 183 | stoch_depth_rate=0.2, 184 | layerscale_eps=1e-5, 185 | ) 186 | elif model_name == "cait_m_36": 187 | return CaiT( 188 | output_dim=output_dim, 189 | num_layers=36, 190 | num_layers_token_only=2, 191 | num_heads=16, 192 | embed_dim=768, 193 | patch_shape=(16, 16), 194 | stoch_depth_rate=0.3, 195 | layerscale_eps=1e-6, 196 | ) 197 | elif model_name == "cait_m_48": 198 | return CaiT( 199 | output_dim=output_dim, 200 | num_layers=48, 201 | num_layers_token_only=2, 202 | num_heads=16, 203 | embed_dim=768, 204 | patch_shape=(16, 16), 205 | stoch_depth_rate=0.4, 206 | layerscale_eps=1e-6, 207 | ) 208 | elif model_name == "mixer_s_patch32": 209 | return MLPMixer( 210 | output_dim=output_dim, num_layers=8, embed_dim=512, patch_shape=(32, 32) 211 | ) 212 | elif model_name == "mixer_s_patch16": 213 | return MLPMixer( 214 | output_dim=output_dim, num_layers=8, embed_dim=512, patch_shape=(16, 16) 215 | ) 216 | elif model_name == "mixer_b_patch32": 217 | return MLPMixer( 218 | output_dim=output_dim, num_layers=12, embed_dim=768, patch_shape=(32, 32) 219 | )
220 | elif model_name == "mixer_b_patch16": 221 | return MLPMixer( 222 | output_dim=output_dim, num_layers=12, embed_dim=768, patch_shape=(16, 16) 223 | ) 224 | elif model_name == "mixer_l_patch32": 225 | return MLPMixer( 226 | output_dim=output_dim, num_layers=24, embed_dim=1024, patch_shape=(32, 32) 227 | ) 228 | elif model_name == "mixer_l_patch16": 229 | return MLPMixer( 230 | output_dim=output_dim, num_layers=24, embed_dim=1024, patch_shape=(16, 16) 231 | ) 232 | else: 233 | raise RuntimeError("Model not found.") 234 | -------------------------------------------------------------------------------- /clap/trunks/mlp_mixer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Callable 2 | 3 | from jax import numpy as jnp 4 | from flax import linen as nn 5 | from einops import rearrange 6 | 7 | from ..layers import PatchEmbedBlock, FFBlock 8 | 9 | 10 | class MixerBlock(nn.Module): 11 | tokens_expand_ratio: float 12 | channels_expand_ratio: float 13 | activation_fn: Callable = nn.activation.gelu 14 | dtype: jnp.dtype = jnp.float32 15 | 16 | @nn.compact 17 | def __call__(self, inputs, is_training: bool): 18 | x = nn.LayerNorm(dtype=self.dtype)(inputs) 19 | x = rearrange(x, "... l d -> ... d l") 20 | x = FFBlock( 21 | expand_ratio=self.tokens_expand_ratio, 22 | activation_fn=self.activation_fn, 23 | dtype=self.dtype, 24 | )(x, is_training=is_training) 25 | x = rearrange(x, "... d l -> ... l d") 26 | x = x + inputs 27 | 28 | y = nn.LayerNorm(dtype=self.dtype)(x) 29 | y = FFBlock( 30 | expand_ratio=self.channels_expand_ratio, 31 | activation_fn=self.activation_fn, 32 | dtype=self.dtype, 33 | )(y, is_training=is_training) 34 | output = x + y 35 | return output 36 | 37 | 38 | class MLPMixer(nn.Module): 39 | output_dim: int 40 | num_layers: int 41 | embed_dim: int 42 | patch_shape: Tuple[int] 43 | tokens_expand_ratio: float = 0.5 44 | channels_expand_ratio: float = 4 45 | activation_fn: Callable = nn.activation.gelu 46 | dtype: jnp.dtype = jnp.float32 47 | 48 | @nn.compact 49 | def __call__(self, inputs, is_training: bool): 50 | x = PatchEmbedBlock( 51 | patch_shape=self.patch_shape, 52 | embed_dim=self.embed_dim, 53 | use_bias=True, 54 | dtype=self.dtype, 55 | )(inputs) 56 | 57 | for _ in range(self.num_layers): 58 | x = MixerBlock( 59 | tokens_expand_ratio=self.tokens_expand_ratio, 60 | channels_expand_ratio=self.channels_expand_ratio, 61 | activation_fn=self.activation_fn, 62 | dtype=self.dtype, 63 | )(x, is_training=is_training) 64 | 65 | x = nn.LayerNorm(dtype=self.dtype)(x) 66 | x = jnp.mean(x, axis=1) 67 | output = nn.Dense(features=self.output_dim, dtype=self.dtype)(x) 68 | return output 69 | -------------------------------------------------------------------------------- /clap/trunks/tnt.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional, Callable 2 | 3 | import flax.linen as nn 4 | import jax.numpy as jnp 5 | from einops import rearrange 6 | 7 | from ..layers import AddAbsPosEmbed, SelfAttentionBlock, FFBlock, PatchEmbedBlock 8 | 9 | 10 | class PixelEmbedBlock(nn.Module): 11 | patch_shape: Tuple[int] 12 | transformed_patch_shape: Tuple[int] 13 | embed_dim: int 14 | use_bias: bool = True 15 | dtype: jnp.dtype = jnp.float32 16 | 17 | @nn.compact 18 | def __call__(self, inputs): 19 | assert self.patch_shape[0] % self.transformed_patch_shape[0] == 0 20 | assert self.patch_shape[1] % self.transformed_patch_shape[1] == 0 21 | 22 | x = rearrange( 23 | inputs,
24 | "b (h p1) (w p2) c -> (b h w) p1 p2 c", 25 | p1=self.patch_shape[0], 26 | p2=self.patch_shape[1], 27 | ) 28 | x = rearrange( 29 | x, 30 | "n (p1 t1) (p2 t2) c -> n (p1 p2) (c t1 t2)", 31 | t1=self.transformed_patch_shape[0], 32 | t2=self.transformed_patch_shape[1], 33 | ) 34 | output = nn.Dense(self.embed_dim, use_bias=self.use_bias, dtype=self.dtype)(x) 35 | return output 36 | 37 | 38 | class Inner2OuterBlock(nn.Module): 39 | out_ch: Optional[int] = None 40 | dtype: jnp.dtype = jnp.float32 41 | 42 | @nn.compact 43 | def __call__(self, patch_inputs, pixel_inputs): 44 | b = patch_inputs.shape[0] 45 | out_ch = self.out_ch or patch_inputs.shape[-1] 46 | 47 | x = rearrange(pixel_inputs, "... n d -> ... (n d)") 48 | x = nn.Dense(features=out_ch, dtype=self.dtype)(x) 49 | x = rearrange(x, "(b l) d -> b l d", b=b) 50 | x = jnp.pad(x, ((0, 0), (1, 0), (0, 0))) 51 | output = x + patch_inputs 52 | return output 53 | 54 | 55 | class EncoderBlock(nn.Module): 56 | inner_num_heads: int 57 | outer_num_heads: int 58 | inner_expand_ratio: float = 4 59 | outer_expand_ratio: float = 4 60 | attn_dropout_rate: float = 0.0 61 | dropout_rate: float = 0.0 62 | activation_fn: Callable = nn.activation.gelu 63 | rotary_qk: bool = False 64 | rotary_v: bool = False 65 | dtype: jnp.dtype = jnp.float32 66 | 67 | @nn.compact 68 | def __call__(self, patch_inputs, pixel_inputs, is_training: bool): 69 | inner_x = nn.LayerNorm(dtype=self.dtype)(pixel_inputs) 70 | inner_x = SelfAttentionBlock( 71 | num_heads=self.inner_num_heads, 72 | attn_dropout_rate=self.attn_dropout_rate, 73 | out_dropout_rate=self.dropout_rate, 74 | rotary_qk=self.rotary_qk, 75 | rotary_v=self.rotary_v, 76 | dtype=self.dtype, 77 | )(inner_x, is_training=is_training) 78 | inner_x = inner_x + pixel_inputs 79 | inner_y = nn.LayerNorm(dtype=self.dtype)(inner_x) 80 | inner_y = FFBlock( 81 | expand_ratio=self.inner_expand_ratio, 82 | dropout_rate=self.dropout_rate, 83 | dtype=self.dtype, 84 | )(inner_y, is_training=is_training) 85 | inner_output = inner_x + inner_y 86 | 87 | outer_x = Inner2OuterBlock(dtype=self.dtype)(patch_inputs, inner_output) 88 | 89 | outer_x = nn.LayerNorm(dtype=self.dtype)(outer_x) 90 | outer_x = SelfAttentionBlock( 91 | num_heads=self.outer_num_heads, 92 | attn_dropout_rate=self.attn_dropout_rate, 93 | out_dropout_rate=self.dropout_rate, 94 | rotary_qk=self.rotary_qk, 95 | rotary_v=self.rotary_v, 96 | dtype=self.dtype, 97 | )(outer_x, is_training=is_training) 98 | outer_x = outer_x + patch_inputs 99 | outer_y = nn.LayerNorm(dtype=self.dtype)(outer_x) 100 | outer_y = FFBlock( 101 | expand_ratio=self.outer_expand_ratio, 102 | dropout_rate=self.dropout_rate, 103 | dtype=self.dtype, 104 | )(outer_y, is_training=is_training) 105 | outer_output = outer_x + outer_y 106 | 107 | return outer_output, inner_output 108 | 109 | 110 | class Encoder(nn.Module): 111 | num_layers: int 112 | inner_num_heads: int 113 | outer_num_heads: int 114 | inner_expand_ratio: float = 4 115 | outer_expand_ratio: float = 4 116 | attn_dropout_rate: float = 0.0 117 | dropout_rate: float = 0.0 118 | activation_fn: Callable = nn.activation.gelu 119 | rotary_qk: bool = False 120 | rotary_v: bool = False 121 | dtype: jnp.dtype = jnp.float32 122 | 123 | @nn.compact 124 | def __call__(self, patch_embeddings, pixel_embeddings, is_training: bool): 125 | for _ in range(self.num_layers): 126 | patch_embeddings, pixel_embeddings = EncoderBlock( 127 | inner_num_heads=self.inner_num_heads, 128 | outer_num_heads=self.outer_num_heads, 129 | 
attn_dropout_rate=self.attn_dropout_rate, 130 | dropout_rate=self.dropout_rate, 131 | activation_fn=self.activation_fn, 132 | rotary_qk=self.rotary_qk, 133 | rotary_v=self.rotary_v, 134 | dtype=self.dtype, 135 | )(patch_embeddings, pixel_embeddings, is_training=is_training) 136 | 137 | output = patch_embeddings 138 | return output 139 | 140 | 141 | class TNT(nn.Module): 142 | output_dim: int 143 | num_layers: int 144 | inner_num_heads: int 145 | outer_num_heads: int 146 | inner_embed_dim: int 147 | outer_embed_dim: int 148 | patch_shape: Tuple[int] 149 | transformed_patch_shape: Tuple[int] 150 | inner_expand_ratio: float = 4 151 | outer_expand_ratio: float = 4 152 | attn_dropout_rate: float = 0.0 153 | dropout_rate: float = 0.0 154 | activation_fn: Callable = nn.activation.gelu 155 | rotary_qk: bool = False 156 | rotary_v: bool = False 157 | dtype: jnp.dtype = jnp.float32 158 | 159 | @nn.compact 160 | def __call__(self, inputs, is_training: bool): 161 | pixel_embeddings = PixelEmbedBlock( 162 | patch_shape=self.patch_shape, 163 | transformed_patch_shape=self.transformed_patch_shape, 164 | embed_dim=self.inner_embed_dim, 165 | dtype=self.dtype, 166 | )(inputs) 167 | 168 | patch_embeddings = PatchEmbedBlock( 169 | patch_shape=self.patch_shape, 170 | embed_dim=self.outer_embed_dim, 171 | use_bias=True, 172 | dtype=self.dtype, 173 | )(inputs) 174 | 175 | b, l, _ = patch_embeddings.shape 176 | cls_shape = (1, 1, self.outer_embed_dim) 177 | cls_token = self.param("cls", nn.initializers.zeros, cls_shape) 178 | cls_token = jnp.tile(cls_token, [b, 1, 1]) 179 | patch_embeddings = jnp.concatenate([cls_token, patch_embeddings], axis=1) 180 | 181 | if not self.rotary_qk and not self.rotary_v: 182 | pixel_embeddings = AddAbsPosEmbed()(pixel_embeddings) 183 | patch_embeddings = AddAbsPosEmbed()(patch_embeddings) 184 | else: 185 | pass 186 | 187 | patch_embeddings = nn.Dropout(rate=self.dropout_rate)( 188 | patch_embeddings, deterministic=not is_training 189 | ) 190 | 191 | patch_embeddings = Encoder( 192 | num_layers=self.num_layers, 193 | inner_num_heads=self.inner_num_heads, 194 | outer_num_heads=self.outer_num_heads, 195 | attn_dropout_rate=self.attn_dropout_rate, 196 | dropout_rate=self.dropout_rate, 197 | activation_fn=self.activation_fn, 198 | rotary_qk=self.rotary_qk, 199 | rotary_v=self.rotary_v, 200 | dtype=self.dtype, 201 | )(patch_embeddings, pixel_embeddings, is_training=is_training) 202 | 203 | cls_token = patch_embeddings[:, 0] 204 | output = nn.Dense( 205 | features=self.output_dim, 206 | dtype=self.dtype, 207 | kernel_init=nn.initializers.orthogonal(), 208 | )(cls_token) 209 | return output 210 | -------------------------------------------------------------------------------- /clap/trunks/transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Callable 2 | 3 | from flax import linen as nn 4 | from jax import numpy as jnp 5 | 6 | from ..layers import SelfAttentionBlock, FFBlock, AddAbsPosEmbed 7 | 8 | 9 | class EncoderBlock(nn.Module): 10 | num_heads: int 11 | expand_ratio: float = 4 12 | attn_dropout_rate: float = 0.0 13 | dropout_rate: float = 0.0 14 | activation_fn: Callable = nn.activation.gelu 15 | rotary_qk: bool = False 16 | rotary_v: bool = False 17 | dtype: jnp.dtype = jnp.float32 18 | 19 | @nn.compact 20 | def __call__(self, inputs, is_training: bool): 21 | x = nn.LayerNorm(dtype=self.dtype)(inputs) 22 | x = SelfAttentionBlock( 23 | num_heads=self.num_heads, 24 | attn_dropout_rate=self.attn_dropout_rate, 25 | 
--------------------------------------------------------------------------------
/clap/trunks/transformer.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 | 
3 | from flax import linen as nn
4 | from jax import numpy as jnp
5 | 
6 | from ..layers import SelfAttentionBlock, FFBlock, AddAbsPosEmbed
7 | 
8 | 
9 | class EncoderBlock(nn.Module):
10 |     num_heads: int
11 |     expand_ratio: float = 4
12 |     attn_dropout_rate: float = 0.0
13 |     dropout_rate: float = 0.0
14 |     activation_fn: Callable = nn.activation.gelu
15 |     rotary_qk: bool = False
16 |     rotary_v: bool = False
17 |     dtype: jnp.dtype = jnp.float32
18 | 
19 |     @nn.compact
20 |     def __call__(self, inputs, is_training: bool):
21 |         x = nn.LayerNorm(dtype=self.dtype)(inputs)
22 |         x = SelfAttentionBlock(
23 |             num_heads=self.num_heads,
24 |             attn_dropout_rate=self.attn_dropout_rate,
25 |             out_dropout_rate=self.dropout_rate,
26 |             rotary_qk=self.rotary_qk,
27 |             rotary_v=self.rotary_v,
28 |             dtype=self.dtype,
29 |         )(x, is_training=is_training)
30 |         x = x + inputs  # residual around pre-norm attention
31 | 
32 |         y = nn.LayerNorm(dtype=self.dtype)(x)
33 |         y = FFBlock(
34 |             expand_ratio=self.expand_ratio,
35 |             dropout_rate=self.dropout_rate,
36 |             activation_fn=self.activation_fn,
37 |             dtype=self.dtype,
38 |         )(y, is_training=is_training)
39 |         output = x + y  # residual around pre-norm feed-forward
40 |         return output
41 | 
42 | 
43 | class Encoder(nn.Module):
44 |     num_layers: int
45 |     num_heads: int
46 |     expand_ratio: float = 4
47 |     attn_dropout_rate: float = 0.0
48 |     dropout_rate: float = 0.0
49 |     activation_fn: Callable = nn.activation.gelu
50 |     rotary_qk: bool = False
51 |     rotary_v: bool = False
52 |     dtype: jnp.dtype = jnp.float32
53 | 
54 |     @nn.compact
55 |     def __call__(self, inputs, is_training: bool):
56 |         if not self.rotary_qk and not self.rotary_v:
57 |             x = AddAbsPosEmbed()(inputs)
58 |         else:
59 |             x = inputs  # rotary position information is applied inside attention
60 | 
61 |         x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not is_training)
62 | 
63 |         for _ in range(self.num_layers):
64 |             x = EncoderBlock(
65 |                 num_heads=self.num_heads,
66 |                 expand_ratio=self.expand_ratio,
67 |                 attn_dropout_rate=self.attn_dropout_rate,
68 |                 dropout_rate=self.dropout_rate,
69 |                 activation_fn=self.activation_fn,
70 |                 rotary_qk=self.rotary_qk,
71 |                 rotary_v=self.rotary_v,
72 |                 dtype=self.dtype,
73 |             )(x, is_training=is_training)
74 | 
75 |         output = nn.LayerNorm(dtype=self.dtype)(x)
76 |         return output
77 | 
78 | 
79 | class Transformer(nn.Module):
80 |     output_dim: int
81 |     num_layers: int
82 |     num_heads: int
83 |     embed_dim: int
84 |     expand_ratio: float = 4
85 |     attn_dropout_rate: float = 0.0
86 |     dropout_rate: float = 0.0
87 |     activation_fn: Callable = nn.activation.gelu
88 |     rotary_qk: bool = False
89 |     rotary_v: bool = False
90 |     dtype: jnp.dtype = jnp.float32
91 | 
92 |     @nn.compact
93 |     def __call__(self, inputs, is_training: bool):
94 |         assert self.embed_dim % self.num_heads == 0
95 | 
96 |         x = inputs  # expected to be pre-embedded tokens of width embed_dim
97 | 
98 |         b, _, _ = x.shape
99 |         cls_shape = (1, 1, self.embed_dim)
100 |         cls_token = self.param("cls", nn.initializers.zeros, cls_shape)
101 |         cls_token = jnp.tile(cls_token, [b, 1, 1])
102 |         x = jnp.concatenate([cls_token, x], axis=1)  # prepend learnable CLS token
103 | 
104 |         x = Encoder(
105 |             num_layers=self.num_layers,
106 |             num_heads=self.num_heads,
107 |             expand_ratio=self.expand_ratio,
108 |             attn_dropout_rate=self.attn_dropout_rate,
109 |             dropout_rate=self.dropout_rate,
110 |             activation_fn=self.activation_fn,
111 |             rotary_qk=self.rotary_qk,
112 |             rotary_v=self.rotary_v,
113 |             dtype=self.dtype,
114 |         )(x, is_training=is_training)
115 | 
116 |         cls_token = x[:, 0]  # read the trunk output off the CLS position
117 |         output = nn.Dense(
118 |             features=self.output_dim,
119 |             dtype=self.dtype,
120 |             kernel_init=nn.initializers.orthogonal(),
121 |         )(cls_token)
122 |         return output
123 | 
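The text trunk above contains no token embedding of its own: concatenating a CLS token of width embed_dim requires inputs that are already embed_dim wide, so embedding the 256-entry byte vocabulary from configs/model/text/transformer.yaml presumably happens upstream (e.g., in clap/models.py). A minimal sketch (not part of the repo) under that assumption:

```python
from jax import numpy as jnp, random

from clap.trunks.transformer import Transformer

trunk = Transformer(
    output_dim=512,
    num_layers=8,
    num_heads=8,
    embed_dim=512,
    rotary_qk=True,
)

dummy = jnp.zeros((2, 64, 512))  # (batch, tokens, embed_dim), already embedded
variables = trunk.init(random.PRNGKey(0), dummy, is_training=False)
embeddings = trunk.apply(variables, dummy, is_training=False)  # (2, 512)
```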
--------------------------------------------------------------------------------
/clap/trunks/vit.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Callable
2 | 
3 | from flax import linen as nn
4 | from jax import numpy as jnp
5 | 
6 | from ..layers import SelfAttentionBlock, FFBlock, AddAbsPosEmbed, PatchEmbedBlock
7 | 
8 | 
9 | class EncoderBlock(nn.Module):
10 |     num_heads: int
11 |     expand_ratio: float = 4
12 |     attn_dropout_rate: float = 0.0
13 |     dropout_rate: float = 0.0
14 |     activation_fn: Callable = nn.activation.gelu
15 |     rotary_qk: bool = False
16 |     rotary_v: bool = False
17 |     dtype: jnp.dtype = jnp.float32
18 | 
19 |     @nn.compact
20 |     def __call__(self, inputs, is_training: bool):
21 |         x = nn.LayerNorm(dtype=self.dtype)(inputs)
22 |         x = SelfAttentionBlock(
23 |             num_heads=self.num_heads,
24 |             attn_dropout_rate=self.attn_dropout_rate,
25 |             out_dropout_rate=self.dropout_rate,
26 |             rotary_qk=self.rotary_qk,
27 |             rotary_v=self.rotary_v,
28 |             dtype=self.dtype,
29 |         )(x, is_training=is_training)
30 |         x = x + inputs  # residual around pre-norm attention
31 | 
32 |         y = nn.LayerNorm(dtype=self.dtype)(x)
33 |         y = FFBlock(
34 |             expand_ratio=self.expand_ratio,
35 |             dropout_rate=self.dropout_rate,
36 |             activation_fn=self.activation_fn,
37 |             dtype=self.dtype,
38 |         )(y, is_training=is_training)
39 |         output = x + y  # residual around pre-norm feed-forward
40 |         return output
41 | 
42 | 
43 | class Encoder(nn.Module):
44 |     num_layers: int
45 |     num_heads: int
46 |     expand_ratio: float = 4
47 |     attn_dropout_rate: float = 0.0
48 |     dropout_rate: float = 0.0
49 |     activation_fn: Callable = nn.activation.gelu
50 |     rotary_qk: bool = False
51 |     rotary_v: bool = False
52 |     dtype: jnp.dtype = jnp.float32
53 | 
54 |     @nn.compact
55 |     def __call__(self, inputs, is_training: bool):
56 |         if not self.rotary_qk and not self.rotary_v:
57 |             x = AddAbsPosEmbed()(inputs)
58 |         else:
59 |             x = inputs  # rotary position information is applied inside attention
60 |         x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not is_training)
61 | 
62 |         for _ in range(self.num_layers):
63 |             x = EncoderBlock(
64 |                 num_heads=self.num_heads,
65 |                 expand_ratio=self.expand_ratio,
66 |                 attn_dropout_rate=self.attn_dropout_rate,
67 |                 dropout_rate=self.dropout_rate,
68 |                 activation_fn=self.activation_fn,
69 |                 rotary_qk=self.rotary_qk,
70 |                 rotary_v=self.rotary_v,
71 |                 dtype=self.dtype,
72 |             )(x, is_training=is_training)
73 | 
74 |         output = nn.LayerNorm(dtype=self.dtype)(x)
75 |         return output
76 | 
77 | 
78 | class ViT(nn.Module):
79 |     output_dim: int
80 |     num_layers: int
81 |     num_heads: int
82 |     embed_dim: int
83 |     patch_shape: Tuple[int, int]
84 |     expand_ratio: float = 4
85 |     attn_dropout_rate: float = 0.0
86 |     dropout_rate: float = 0.0
87 |     activation_fn: Callable = nn.activation.gelu
88 |     rotary_qk: bool = False
89 |     rotary_v: bool = False
90 |     dtype: jnp.dtype = jnp.float32
91 | 
92 |     @nn.compact
93 |     def __call__(self, inputs, is_training: bool):
94 |         assert self.embed_dim % self.num_heads == 0
95 | 
96 |         x = PatchEmbedBlock(
97 |             patch_shape=self.patch_shape, embed_dim=self.embed_dim, dtype=self.dtype
98 |         )(inputs)
99 | 
100 |         b, _, _ = x.shape
101 |         cls_shape = (1, 1, self.embed_dim)
102 |         cls_token = self.param("cls", nn.initializers.zeros, cls_shape)
103 |         cls_token = jnp.tile(cls_token, [b, 1, 1])
104 |         x = jnp.concatenate([cls_token, x], axis=1)  # prepend learnable CLS token
105 | 
106 |         x = Encoder(
107 |             num_layers=self.num_layers,
108 |             num_heads=self.num_heads,
109 |             expand_ratio=self.expand_ratio,
110 |             attn_dropout_rate=self.attn_dropout_rate,
111 |             dropout_rate=self.dropout_rate,
112 |             activation_fn=self.activation_fn,
113 |             rotary_qk=self.rotary_qk,
114 |             rotary_v=self.rotary_v,
115 |             dtype=self.dtype,
116 |         )(x, is_training=is_training)
117 | 
118 |         cls_token = x[:, 0]  # read the trunk output off the CLS position
119 |         output = nn.Dense(
120 |             features=self.output_dim,
121 |             dtype=self.dtype,
122 |             kernel_init=nn.initializers.orthogonal(),
123 |         )(cls_token)
124 |         return output
125 | 
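For reference, a minimal sketch (not part of the repo) of the ViT trunk with the settings from configs/model/audio/vit.yaml: each patch spans 4 time frames and the full 80 mel bins, so a 16-frame clip yields 4 patch tokens plus the CLS token. The (batch, time frames, mel bins) input layout is again an assumption; PatchEmbedBlock in clap/layers defines the real contract.

```python
from jax import numpy as jnp, random

from clap.trunks.vit import ViT

trunk = ViT(
    output_dim=512,
    num_layers=8,
    num_heads=8,
    embed_dim=512,
    patch_shape=(4, 80),
    rotary_qk=True,
)

dummy = jnp.zeros((2, 16, 80))  # assumed (batch, time frames, mel bins)
variables = trunk.init(random.PRNGKey(0), dummy, is_training=False)
embeddings = trunk.apply(variables, dummy, is_training=False)  # (2, 512)
```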
--------------------------------------------------------------------------------
/configs/model/audio/cait.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | kind: 'cait'
3 | depth: 8
4 | token_only_depth: 2
5 | dim: 512
6 | heads: 8
7 | patch_shape:
8 |   - 4
9 |   - 80
10 | projection_dim: 512
11 | rotary_qk: True
12 | stochastic_depth_rate: 0.4
13 | layerscale_eps: 0.000001
--------------------------------------------------------------------------------
/configs/model/audio/mixer.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | kind: 'mixer'
3 | depth: 8
4 | dim: 512
5 | patch_shape:
6 |   - 4
7 |   - 80
8 | projection_dim: 512
--------------------------------------------------------------------------------
/configs/model/audio/tnt.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | kind: 'tnt'
3 | depth: 8
4 | outer:
5 |   dim: 512
6 |   heads: 8
7 |   patch_shape:
8 |     - 4
9 |     - 80
10 | inner:
11 |   dim: 128
12 |   heads: 4
13 |   patch_shape:
14 |     - 1
15 |     - 80
16 | projection_dim: 512
17 | rotary_qk: True
--------------------------------------------------------------------------------
/configs/model/audio/vit.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | kind: 'vit'
3 | depth: 8
4 | dim: 512
5 | heads: 8
6 | patch_shape:
7 |   - 4
8 |   - 80
9 | projection_dim: 512
10 | rotary_qk: True
--------------------------------------------------------------------------------
/configs/model/text/transformer.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | kind: 'transformer'
3 | depth: 8
4 | dim: 512
5 | heads: 8
6 | vocab: 256
7 | projection_dim: 512
8 | rotary_qk: True
--------------------------------------------------------------------------------
/configs/optimizer/standard.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | learning_rate: 0.0003
3 | weight_decay: 0.1
4 | max_norm: 0.5
--------------------------------------------------------------------------------
/configs/preprocessing/dataset/commonvoice.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | data_folder: "./data"
3 | tsv_filename: 'subset.tsv'
4 | mel_bins: 80
--------------------------------------------------------------------------------
/configs/training/standard.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | data_folder: "./data"
3 | batch_size: 16
4 | epochs: 100
5 | seed: 0
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import csv
2 | from clap.datasets import PairTextSpectrogramTFRecords
3 | from omegaconf import DictConfig
4 | import hydra
5 | 
6 | import torchaudio
7 | 
8 | 
9 | # helpers
10 | 
11 | 
12 | def tsv_to_dicts(path):
13 |     with open(path) as fd:
14 |         rd = csv.DictReader(fd, delimiter="\t", quotechar='"')
15 |         return list(rd)
16 | 
17 | 
18 | # script
19 | 
20 | 
21 | @hydra.main(config_path="configs")
22 | def preprocess(cfg: DictConfig) -> None:
23 |     data_folder = (
24 |         hydra.utils.get_original_cwd() + "/" + cfg.preprocessing.dataset.data_folder
25 |     )
26 | 
27 |     voice_clips = tsv_to_dicts(f"{data_folder}/{cfg.preprocessing.dataset.tsv_filename}")
28 | 
29 |     def extract_spectrogram(filename):
30 |         waveform, sample_rate = torchaudio.load(f"{data_folder}/clips/{filename}")
31 | 
32 |         output = torchaudio.transforms.MelSpectrogram(
33 |             sample_rate, n_mels=cfg.preprocessing.dataset.mel_bins, f_min=0, f_max=8000
34 |         )(waveform)[0]
35 | 
36 |         return output.t().numpy()  # (time, mel_bins)
37 | 
38 |     spectrograms = (extract_spectrogram(clip["path"]) for clip in voice_clips)
39 |     captions = (clip["sentence"] for clip in voice_clips)
40 |     save_path = data_folder + "/data.tfrecord"
41 |     PairTextSpectrogramTFRecords.write(spectrograms, captions, fname=save_path)
42 | 
43 | 
44 | if __name__ == "__main__":
45 |     preprocess()
46 | 
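For reference, a standalone sketch of the extraction above, outside of Hydra, with the settings from configs/preprocessing/dataset/commonvoice.yaml (80 mel bins, f_max of 8 kHz). "clip.mp3" is a placeholder path, not a file shipped with the repo:

```python
import torchaudio

# Common Voice clips live as mp3 files under <data_folder>/clips/.
waveform, sample_rate = torchaudio.load("clip.mp3")

mel = torchaudio.transforms.MelSpectrogram(
    sample_rate, n_mels=80, f_min=0, f_max=8000
)(waveform)[0]  # drop the channel dimension

spectrogram = mel.t().numpy()  # (time, mel_bins), the layout written to the TFRecord
```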
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 |     name="clap-jax",
5 |     packages=find_packages(),
6 |     version="0.0.1",
7 |     license="BSD-3-Clause",
8 |     description="CLAP - Contrastive Language-Audio Pretraining",
9 |     author="Charles Foster",
10 |     author_email="",
11 |     url="https://github.com/cfoster0/CLAP",
12 |     keywords=[
13 |         "artificial intelligence",
14 |         "deep learning",
15 |         "contrastive learning",
16 |         "audio",
17 |     ],
18 |     install_requires=[
19 |         "einops>=0.3",
20 |         "flax",
21 |         "hydra-core",
22 |         "jax",
23 |         "jaxlib",
24 |         "lm_dataformat",
25 |         "optax",
26 |         "tensorflow",
27 |         "torch",
28 |         "torchaudio",
29 |         "tqdm",
30 |     ],
31 |     classifiers=[
32 |         "Development Status :: 4 - Beta",
33 |         "Intended Audience :: Developers",
34 |         "Topic :: Scientific/Engineering :: Artificial Intelligence",
35 |         "License :: OSI Approved :: BSD License",
36 |         "Programming Language :: Python :: 3.6",
37 |     ],
38 | )
39 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from omegaconf import DictConfig, OmegaConf
2 | import hydra
3 | 
4 | from jax import random, value_and_grad, jit, tree_util
5 | from optax import (
6 |     chain,
7 |     clip_by_global_norm,
8 |     scale_by_adam,
9 |     scale,
10 |     apply_updates,
11 |     add_decayed_weights,
12 | )
13 | 
14 | from clap.models import CLAP
15 | 
16 | # data
17 | 
18 | from clap.datasets import PairTextSpectrogramTFRecords
19 | 
20 | 
21 | @hydra.main(config_path="configs")
22 | def train(cfg: DictConfig) -> None:
23 | 
24 |     print(OmegaConf.to_yaml(cfg))
25 | 
26 |     # rng
27 | 
28 |     rng_key = random.PRNGKey(cfg.training.seed)
29 | 
30 |     # data
31 | 
32 |     training_data_path = hydra.utils.get_original_cwd() + "/" + cfg.training.data_folder
33 |     dataloader = PairTextSpectrogramTFRecords(
34 |         training_data_path,
35 |         cfg.training.batch_size,
36 |     )
37 | 
38 |     # model
39 | 
40 |     model = CLAP(
41 |         text_config=cfg.model.text,
42 |         audio_config=cfg.model.audio,
43 |     )
44 | 
45 |     # optimizer: AdamW-style, with weight decay masked off biases and norm scales
46 | 
47 |     exclude_bias = lambda params: tree_util.tree_map(lambda x: x.ndim != 1, params)
48 | 
49 |     optim = chain(
50 |         clip_by_global_norm(cfg.optimizer.max_norm),
51 |         scale_by_adam(eps=1e-4),
52 |         add_decayed_weights(cfg.optimizer.weight_decay, exclude_bias),
53 |         scale(-cfg.optimizer.learning_rate),
54 |     )
55 | 
56 |     # init
57 | 
58 |     batch = next(iter(dataloader))
59 | 
60 |     text = batch["text"]
61 |     audio = batch["audio"]
62 | 
63 |     params = model.init(rng_key, text, audio)
64 |     optim_state = optim.init(params)
65 | 
66 |     # loss function, for use with value_and_grad
67 | 
68 |     @jit
69 |     @value_and_grad
70 |     def loss_fn(params, text, audio):
71 |         return model.apply(
72 |             params,
73 |             text,
74 |             audio,
75 |             return_loss=True,
76 |             is_training=True,
77 |         )
78 | 
79 |     # train loop
80 | 
81 |     for _ in range(cfg.training.epochs):
82 |         for batch in dataloader:
83 |             text = batch["text"]
84 |             audio = batch["audio"]
85 |             loss, grads = loss_fn(params, text, audio)
86 |             updates, optim_state = optim.update(grads, optim_state, params)
87 |             params = apply_updates(params, updates)
88 |             print(f"loss: {loss}")
89 | 
90 |     # finished
91 | 
92 | 
93 | if __name__ == "__main__":
94 |     train()
95 | 
--------------------------------------------------------------------------------
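For reference, a small sketch (not from the repo) of what the `exclude_bias` mask in train.py computes: `add_decayed_weights` receives a pytree of booleans and only decays leaves marked True, so kernels get weight decay while 1-D parameters (biases, norm scales) do not. Note also that if any `dropout_rate` were raised above zero in the configs, the `model.apply(..., is_training=True)` call in `loss_fn` would likely need an additional `rngs={"dropout": ...}` argument.

```python
from jax import numpy as jnp, tree_util

# Toy parameter tree: the kernel is 2-D, the bias is 1-D.
params = {"dense": {"kernel": jnp.ones((4, 4)), "bias": jnp.ones((4,))}}

# Same rule as exclude_bias in train.py: decay anything that is not 1-D.
mask = tree_util.tree_map(lambda x: x.ndim != 1, params)
print(mask)  # {'dense': {'kernel': True, 'bias': False}}
```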