├── .DS_Store
├── LICENSE
├── README.md
├── images
│   ├── ats.png
│   ├── cait.png
│   ├── cross_vit.png
│   ├── crossformer.png
│   ├── crossformer2.png
│   ├── cvt.png
│   ├── dino.png
│   ├── distill.png
│   ├── levit.png
│   ├── mae.png
│   ├── mbvit.png
│   ├── nest.png
│   ├── parallel-vit.png
│   ├── patch_merger.png
│   ├── pit.png
│   ├── regionvit.png
│   ├── regionvit2.png
│   ├── scalable-vit-1.png
│   ├── scalable-vit-2.png
│   ├── simmim.png
│   ├── t2t.png
│   ├── twins_svt.png
│   ├── vit.gif
│   └── vit_for_small_datasets.png
└── vit_tensorflow
    ├── ats_vit.py
    ├── cait.py
    ├── cct.py
    ├── cross_vit.py
    ├── crossformer.py
    ├── cvt.py
    ├── deepvit.py
    ├── distill.py
    ├── efficient.py
    ├── levit.py
    ├── mae.py
    ├── mobile_vit.py
    ├── mpp.py
    ├── nest.py
    ├── parallel_vit.py
    ├── pit.py
    ├── regionvit.py
    ├── scalable_vit.py
    ├── simmim.py
    ├── t2t.py
    ├── twins_svt.py
    ├── vit.py
    ├── vit_for_small_dataset.py
    └── vit_with_patch_merger.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/.DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Junho Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /images/ats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/ats.png -------------------------------------------------------------------------------- /images/cait.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/cait.png -------------------------------------------------------------------------------- /images/cross_vit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/cross_vit.png -------------------------------------------------------------------------------- /images/crossformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/crossformer.png -------------------------------------------------------------------------------- /images/crossformer2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/crossformer2.png -------------------------------------------------------------------------------- /images/cvt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/cvt.png -------------------------------------------------------------------------------- /images/dino.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/dino.png -------------------------------------------------------------------------------- /images/distill.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/distill.png -------------------------------------------------------------------------------- /images/levit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/levit.png -------------------------------------------------------------------------------- /images/mae.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/mae.png -------------------------------------------------------------------------------- /images/mbvit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/mbvit.png -------------------------------------------------------------------------------- /images/nest.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/nest.png -------------------------------------------------------------------------------- /images/parallel-vit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/parallel-vit.png -------------------------------------------------------------------------------- /images/patch_merger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/patch_merger.png -------------------------------------------------------------------------------- /images/pit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/pit.png -------------------------------------------------------------------------------- /images/regionvit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/regionvit.png -------------------------------------------------------------------------------- /images/regionvit2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/regionvit2.png -------------------------------------------------------------------------------- /images/scalable-vit-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/scalable-vit-1.png -------------------------------------------------------------------------------- /images/scalable-vit-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/scalable-vit-2.png -------------------------------------------------------------------------------- /images/simmim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/simmim.png -------------------------------------------------------------------------------- /images/t2t.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/t2t.png -------------------------------------------------------------------------------- /images/twins_svt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/twins_svt.png -------------------------------------------------------------------------------- /images/vit.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/vit.gif -------------------------------------------------------------------------------- 
/images/vit_for_small_datasets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/vit-tensorflow/f16972989d9df478d7ee3ab1c45332d412801646/images/vit_for_small_datasets.png -------------------------------------------------------------------------------- /vit_tensorflow/ats_vit.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | from tensorflow import einsum 7 | from tensorflow.keras.preprocessing.sequence import pad_sequences 8 | 9 | from einops import rearrange, repeat 10 | from einops.layers.tensorflow import Rearrange 11 | 12 | import numpy as np 13 | 14 | def exists(val): 15 | return val is not None 16 | 17 | def pair(t): 18 | return t if isinstance(t, tuple) else (t, t) 19 | 20 | # adaptive token sampling functions and classes 21 | 22 | def log(t, eps = 1e-6): 23 | return tf.math.log(t + eps) 24 | 25 | def sample_gumbel(shape, dtype, eps = 1e-6): 26 | u = tf.random.uniform(shape, dtype=dtype) 27 | return -log(-log(u, eps), eps) 28 | 29 | def torch_gather(x, indices, gather_axis): 30 | # if pytorch gather indices are 31 | # [[[0, 10, 20], [0, 10, 20], [0, 10, 20]], 32 | # [[0, 10, 20], [0, 10, 20], [0, 10, 20]]] 33 | # tf nd_gather needs to be 34 | # [[0,0,0], [0,0,10], [0,0,20], [0,1,0], [0,1,10], [0,1,20], [0,2,0], [0,2,10], [0,2,20], 35 | # [1,0,0], [1,0,10], [1,0,20], [1,1,0], [1,1,10], [1,1,20], [1,2,0], [1,2,10], [1,2,20]] 36 | 37 | indices = tf.cast(indices, tf.int64) 38 | # create a tensor containing indices of each element 39 | all_indices = tf.where(tf.fill(indices.shape, True)) 40 | gather_locations = tf.reshape(indices, [indices.shape.num_elements()]) 41 | 42 | # splice in our pytorch style index at the correct axis 43 | gather_indices = [] 44 | for axis in range(len(indices.shape)): 45 | if axis == gather_axis: 46 | gather_indices.append(gather_locations) 47 | else: 48 | gather_indices.append(all_indices[:, axis]) 49 | 50 | gather_indices = tf.stack(gather_indices, axis=-1) 51 | gathered = tf.gather_nd(x, gather_indices) 52 | reshaped = tf.reshape(gathered, indices.shape) 53 | return reshaped 54 | 55 | def batched_index_select(values, indices, dim = 1): 56 | value_dims = values.shape[(dim + 1):] 57 | values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices)) 58 | indices = indices[(..., *((None,) * len(value_dims)))] 59 | indices = tf.tile(indices, multiples=[1] * len(indices_shape) + [*value_dims]) 60 | value_expand_len = len(indices_shape) - (dim + 1) 61 | values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)] 62 | 63 | value_expand_shape = [-1] * len(values.shape) 64 | expand_slice = slice(dim, (dim + value_expand_len)) 65 | value_expand_shape[expand_slice] = indices.shape[expand_slice] 66 | dim += value_expand_len 67 | 68 | values = torch_gather(values, indices, dim) 69 | return values 70 | 71 | class AdaptiveTokenSampling(Layer): 72 | def __init__(self, output_num_tokens, eps=1e-6): 73 | super(AdaptiveTokenSampling, self).__init__() 74 | self.eps = eps 75 | self.output_num_tokens = output_num_tokens 76 | 77 | def call(self, attn, value=None, mask=None, training=True): 78 | heads, output_num_tokens, eps, dtype = attn.shape[1], self.output_num_tokens, self.eps, attn.dtype 79 | 80 | # first get the attention values for CLS 
token to all other tokens 81 | cls_attn = attn[..., 0, 1:] 82 | 83 | # calculate the norms of the values, for weighting the scores, as described in the paper 84 | value_norms = tf.norm(value[..., 1:, :], axis=-1) 85 | 86 | # weigh the attention scores by the norm of the values, sum across all heads 87 | cls_attn = einsum('b h n, b h n -> b n', cls_attn, value_norms) 88 | 89 | # normalize to 1 90 | normed_cls_attn = cls_attn / (tf.reduce_sum(cls_attn, axis=-1, keepdims=True) + eps) 91 | 92 | # instead of using inverse transform sampling, going to invert the softmax and use gumbel-max sampling instead 93 | pseudo_logits = log(normed_cls_attn) 94 | 95 | # mask out pseudo logits for gumbel-max sampling 96 | mask_without_cls = mask[:, 1:] 97 | mask_value = -np.finfo(attn.dtype.as_numpy_dtype).max / 2 98 | pseudo_logits = tf.where(~mask_without_cls, mask_value, pseudo_logits) 99 | 100 | # expand k times, k being the adaptive sampling number 101 | pseudo_logits = repeat(pseudo_logits, 'b n -> b k n', k=output_num_tokens) 102 | pseudo_logits = pseudo_logits + sample_gumbel(pseudo_logits.shape, dtype=dtype) 103 | 104 | # gumble-max and add one to reserve 0 for padding / mask 105 | sampled_token_ids = tf.argmax(pseudo_logits, axis=-1) + 1 106 | 107 | # calculate unique using torch.unique and then pad the sequence from the right 108 | unique_sampled_token_ids_list = [] 109 | for t in tf.unstack(sampled_token_ids): 110 | t = tf.cast(t, tf.int32) 111 | t, _ = tf.unique(t) 112 | x = tf.sort(t) 113 | unique_sampled_token_ids_list.append(x) 114 | 115 | 116 | unique_sampled_token_ids = pad_sequences(unique_sampled_token_ids_list) 117 | 118 | # calculate the new mask, based on the padding 119 | new_mask = unique_sampled_token_ids != 0 120 | 121 | # CLS token never gets masked out (gets a value of True) 122 | new_mask = tf.pad(new_mask, paddings=[[0, 0], [1, 0]], constant_values=True) 123 | 124 | # prepend a 0 token id to keep the CLS attention scores 125 | unique_sampled_token_ids = tf.pad(unique_sampled_token_ids, paddings=[[0, 0], [1, 0]]) 126 | expanded_unique_sampled_token_ids = repeat(unique_sampled_token_ids, 'b n -> b h n', h=heads) 127 | 128 | # gather the new attention scores 129 | new_attn = batched_index_select(attn, expanded_unique_sampled_token_ids, dim=2) 130 | 131 | # return the sampled attention scores, new mask (denoting padding), as well as the sampled token indices (for the residual) 132 | return new_attn, new_mask, unique_sampled_token_ids 133 | 134 | def gelu(x, approximate=False): 135 | if approximate: 136 | coeff = tf.cast(0.044715, x.dtype) 137 | return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3)))) 138 | else: 139 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype))) 140 | 141 | class GELU(Layer): 142 | def __init__(self, approximate=False): 143 | super(GELU, self).__init__() 144 | self.approximate = approximate 145 | 146 | def call(self, x, training=True): 147 | return gelu(x, self.approximate) 148 | 149 | class PreNorm(Layer): 150 | def __init__(self, fn): 151 | super(PreNorm, self).__init__() 152 | 153 | self.norm = nn.LayerNormalization() 154 | self.fn = fn 155 | 156 | def call(self, x, **kwargs): 157 | return self.fn(self.norm(x), **kwargs) 158 | 159 | class MLP(Layer): 160 | def __init__(self, dim, hidden_dim, dropout=0.0): 161 | super(MLP, self).__init__() 162 | self.net = Sequential([ 163 | nn.Dense(units=hidden_dim), 164 | GELU(), 165 | nn.Dropout(rate=dropout), 166 | nn.Dense(units=dim), 167 | nn.Dropout(rate=dropout) 
168 | ]) 169 | 170 | def call(self, x, training=True): 171 | return self.net(x, training=training) 172 | 173 | class Attention(Layer): 174 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, output_num_tokens=None): 175 | super(Attention, self).__init__() 176 | inner_dim = dim_head * heads 177 | self.heads = heads 178 | self.scale = dim_head ** -0.5 179 | 180 | self.attend = nn.Softmax() 181 | self.to_qkv = nn.Dense(units=inner_dim * 3, use_bias=False) 182 | 183 | self.output_num_tokens = output_num_tokens 184 | self.ats = AdaptiveTokenSampling(output_num_tokens) if exists(output_num_tokens) else None 185 | 186 | self.to_out = Sequential([ 187 | nn.Dense(units=dim), 188 | nn.Dropout(rate=dropout) 189 | ]) 190 | 191 | def call(self, x, mask=None, training=True): 192 | num_tokens = x.shape[1] 193 | 194 | qkv = self.to_qkv(x) 195 | qkv = tf.split(qkv, num_or_size_splits=3, axis=-1) 196 | q, k, v = map(lambda t: rearrange(t, 'b n (h d)-> b h n d', h=self.heads), qkv) 197 | 198 | dots = tf.matmul(q, tf.transpose(k, perm=[0, 1, 3, 2])) * self.scale 199 | 200 | if exists(mask): 201 | mask_f = tf.cast(mask, tf.float32) 202 | dots_mask = rearrange(mask_f, 'b i -> b 1 i 1') * rearrange(mask_f, 'b j -> b 1 1 j') 203 | dots_mask = tf.cast(dots_mask, tf.bool) 204 | mask_value = -np.finfo(dots.dtype.as_numpy_dtype).max 205 | dots = tf.where(~dots_mask, mask_value, dots) 206 | 207 | attn = self.attend(dots) 208 | 209 | sampled_token_ids = None 210 | 211 | # if adaptive token sampling is enabled 212 | # and number of tokens is greater than the number of output tokens 213 | if exists(self.output_num_tokens) and (num_tokens - 1) > self.output_num_tokens: 214 | attn, mask, sampled_token_ids = self.ats(attn, v, mask=mask) 215 | 216 | out = tf.matmul(attn, v) 217 | out = rearrange(out, 'b h n d -> b n (h d)') 218 | out = self.to_out(out, training=training) 219 | 220 | return out, mask, sampled_token_ids 221 | 222 | class Transformer(Layer): 223 | def __init__(self, dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout=0.0): 224 | super(Transformer, self).__init__() 225 | assert len(max_tokens_per_depth) == depth, 'max_tokens_per_depth must be a tuple of length that is equal to the depth of the transformer' 226 | assert sorted(max_tokens_per_depth, reverse=True) == list(max_tokens_per_depth), 'max_tokens_per_depth must be in decreasing order' 227 | assert min(max_tokens_per_depth) > 0, 'max_tokens_per_depth must have at least 1 token at any layer' 228 | 229 | self.layers = [] 230 | for _, output_num_tokens in zip(range(depth), max_tokens_per_depth): 231 | self.layers.append([ 232 | PreNorm(Attention(dim, output_num_tokens=output_num_tokens, heads=heads, dim_head=dim_head, dropout=dropout)), 233 | PreNorm(MLP(dim, mlp_dim, dropout=dropout)) 234 | ]) 235 | 236 | def call(self, x, training=True): 237 | b, n = x.shape[:2] 238 | 239 | # use mask to keep track of the paddings when sampling tokens 240 | # as the duplicates (when sampling) are just removed, as mentioned in the paper 241 | mask = tf.ones([b, n], dtype=tf.bool) 242 | 243 | token_ids = tf.range(n) 244 | token_ids = repeat(token_ids, 'n -> b n', b = b) 245 | 246 | for attn, ff in self.layers: 247 | attn_out, mask, sampled_token_ids = attn(x, mask=mask, training=training) 248 | 249 | # when token sampling, one needs to then gather the residual tokens with the sampled token ids 250 | if exists(sampled_token_ids): 251 | x = batched_index_select(x, sampled_token_ids, dim=1) 252 | token_ids = batched_index_select(token_ids, 
sampled_token_ids, dim=1) 253 | 254 | x = x + attn_out 255 | 256 | x = ff(x, training=training) + x 257 | 258 | return x, token_ids 259 | 260 | class ViT(Model): 261 | def __init__(self, 262 | image_size, 263 | patch_size, 264 | num_classes, 265 | dim, 266 | depth, 267 | max_tokens_per_depth, 268 | heads, 269 | mlp_dim, 270 | dim_head=64, 271 | dropout=0.0, 272 | emb_dropout=0.0 273 | ): 274 | super(ViT, self).__init__() 275 | 276 | image_height, image_width = pair(image_size) 277 | patch_height, patch_width = pair(patch_size) 278 | 279 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 280 | 281 | num_patches = (image_height // patch_height) * (image_width // patch_width) 282 | 283 | self.patch_embedding = Sequential([ 284 | Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width), 285 | nn.Dense(units=dim) 286 | ]) 287 | 288 | self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim])) 289 | self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim])) 290 | self.dropout = nn.Dropout(rate=emb_dropout) 291 | 292 | self.transformer = Transformer(dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout) 293 | 294 | self.mlp_head = Sequential([ 295 | nn.LayerNormalization(), 296 | nn.Dense(units=num_classes) 297 | ]) 298 | 299 | 300 | def call(self, img, return_sampled_token_ids=False, training=True, **kwargs): 301 | x = self.patch_embedding(img) 302 | b, n, _ = x.shape 303 | 304 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b) 305 | x = tf.concat([cls_tokens, x], axis=1) 306 | x += self.pos_embedding[:, :(n + 1)] 307 | x = self.dropout(x, training=training) 308 | 309 | x, token_ids = self.transformer(x, training=training) 310 | 311 | logits = self.mlp_head(x[:, 0]) 312 | 313 | if return_sampled_token_ids: 314 | # remove CLS token and decrement by 1 to make -1 the padding 315 | token_ids = token_ids[:, 1:] - 1 316 | return logits, token_ids 317 | 318 | return logits 319 | 320 | v = ViT( 321 | image_size = 256, 322 | patch_size = 16, 323 | num_classes = 1000, 324 | dim = 1024, 325 | depth = 6, 326 | max_tokens_per_depth = (256, 128, 64, 32, 16, 8), # a tuple that denotes the maximum number of tokens that any given layer should have. 
if the layer has greater than this amount, it will undergo adaptive token sampling 327 | heads = 16, 328 | mlp_dim = 2048, 329 | dropout = 0.1, 330 | emb_dropout = 0.1 331 | ) 332 | 333 | img = tf.random.normal(shape=[4, 256, 256, 3]) 334 | preds = v(img) # (1, 1000) 335 | print(preds.shape) -------------------------------------------------------------------------------- /vit_tensorflow/cait.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import einsum 3 | from tensorflow.keras import Model 4 | from tensorflow.keras.layers import Layer 5 | from tensorflow.keras import Sequential 6 | import tensorflow.keras.layers as nn 7 | 8 | from einops import rearrange, repeat 9 | from einops.layers.tensorflow import Rearrange 10 | 11 | from random import randrange 12 | import numpy as np 13 | 14 | def exists(val): 15 | return val is not None 16 | 17 | def dropout_layers(layers, dropout): 18 | if dropout == 0: 19 | return layers 20 | 21 | num_layers = len(layers) 22 | 23 | to_drop = np.random.uniform(low=0.0, high=1.0, size=[num_layers]) < dropout 24 | 25 | # make sure at least one layer makes it 26 | if all(to_drop): 27 | rand_index = randrange(num_layers) 28 | to_drop[rand_index] = False 29 | 30 | layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop] 31 | return layers 32 | 33 | class LayerScale(Layer): 34 | def __init__(self, dim, fn, depth): 35 | super(LayerScale, self).__init__() 36 | if depth <= 18: # epsilon detailed in section 2 of paper 37 | init_eps = 0.1 38 | elif depth > 18 and depth <= 24: 39 | init_eps = 1e-5 40 | else: 41 | init_eps = 1e-6 42 | 43 | scale = tf.fill(dims=[1, 1, dim], value=init_eps) 44 | self.scale = tf.Variable(scale) 45 | self.fn = fn 46 | 47 | def call(self, x, training=True, **kwargs): 48 | return self.fn(x, training=training, **kwargs) * self.scale 49 | 50 | class PreNorm(Layer): 51 | def __init__(self, fn): 52 | super(PreNorm, self).__init__() 53 | 54 | self.norm = nn.LayerNormalization() 55 | self.fn = fn 56 | 57 | def call(self, x, training=True, **kwargs): 58 | return self.fn(self.norm(x), training=training, **kwargs) 59 | 60 | class MLP(Layer): 61 | def __init__(self, dim, hidden_dim, dropout=0.0): 62 | super(MLP, self).__init__() 63 | def GELU(): 64 | def gelu(x, approximate=False): 65 | if approximate: 66 | coeff = tf.cast(0.044715, x.dtype) 67 | return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3)))) 68 | else: 69 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype))) 70 | 71 | return nn.Activation(gelu) 72 | 73 | self.net = [ 74 | nn.Dense(units=hidden_dim), 75 | GELU(), 76 | nn.Dropout(rate=dropout), 77 | nn.Dense(units=dim), 78 | nn.Dropout(rate=dropout) 79 | ] 80 | self.net = Sequential(self.net) 81 | 82 | def call(self, x, training=True): 83 | return self.net(x, training=training) 84 | 85 | class Attention(Layer): 86 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): 87 | super(Attention, self).__init__() 88 | inner_dim = dim_head * heads 89 | 90 | self.heads = heads 91 | self.scale = dim_head ** -0.5 92 | 93 | self.attend = nn.Softmax() 94 | self.to_q = nn.Dense(units=inner_dim, use_bias=False) 95 | self.to_kv = nn.Dense(units=inner_dim * 2, use_bias=False) 96 | 97 | self.mix_heads_pre_attn = tf.Variable(initial_value=tf.random.normal([heads, heads])) 98 | self.mix_heads_post_attn = tf.Variable(initial_value=tf.random.normal([heads, heads])) 99 | 100 | self.to_out = [ 101 | 
nn.Dense(units=dim), 102 | nn.Dropout(rate=dropout) 103 | ] 104 | 105 | self.to_out = Sequential(self.to_out) 106 | 107 | def call(self, x, context=None, training=True): 108 | 109 | if not exists(context): 110 | context = x 111 | else: 112 | context = tf.concat([x, context], axis=1) 113 | 114 | q = self.to_q(x) 115 | kv = self.to_kv(context) 116 | k, v = tf.split(kv, num_or_size_splits=2, axis=-1) 117 | qkv = (q, k, v) 118 | 119 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) 120 | 121 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 122 | 123 | dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn) # talking heads, pre-softmax 124 | attn = self.attend(dots) 125 | attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn) # talking heads, post-softmax 126 | 127 | x = tf.matmul(attn, v) 128 | x = rearrange(x, 'b h n d -> b n (h d)') 129 | x = self.to_out(x, training=training) 130 | 131 | return x 132 | 133 | class Transformer(Layer): 134 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0, layer_dropout=0.0): 135 | super(Transformer, self).__init__() 136 | 137 | self.layers = [] 138 | self.layer_dropout = layer_dropout 139 | 140 | for ind in range(depth): 141 | self.layers.append([ 142 | LayerScale(dim, PreNorm(Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)), depth=ind+1), 143 | LayerScale(dim, PreNorm(MLP(dim, mlp_dim, dropout=dropout)), depth=ind+1) 144 | ]) 145 | 146 | def call(self, x, context=None, training=True): 147 | layers = dropout_layers(self.layers, dropout=self.layer_dropout) 148 | 149 | for attn, mlp in layers: 150 | x = attn(x, context=context, training=training) + x 151 | x = mlp(x, training=training) + x 152 | 153 | return x 154 | 155 | class CaiT(Model): 156 | def __init__(self, image_size, patch_size, num_classes, dim, depth, cls_depth, heads, mlp_dim, 157 | dim_head=64, dropout=0.0, emb_dropout=0.0, layer_dropout=0.0): 158 | super(CaiT, self).__init__() 159 | 160 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 
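        # clarifying note (figures taken from the commented usage example at the bottom of this file):
        # with image_size=256 and patch_size=32, the line below gives (256 // 32) ** 2 = 64 patch tokens.
        # Unlike the ViT in ats_vit.py, pos_embedding here is sized [1, num_patches, dim] with no "+ 1" slot,
        # because the CLS token is only introduced later, inside call(), where it attends to the patch
        # tokens through cls_transformer rather than being part of the patch sequence.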
161 | num_patches = (image_size // patch_size) ** 2 162 | 163 | self.patch_embedding = Sequential([ 164 | Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size), 165 | nn.Dense(units=dim) 166 | ], name='patch_embedding') 167 | 168 | self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches, dim])) 169 | self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim])) 170 | self.dropout = nn.Dropout(rate=emb_dropout) 171 | 172 | self.patch_transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, layer_dropout) 173 | self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout) 174 | 175 | self.mlp_head = Sequential([ 176 | nn.LayerNormalization(), 177 | nn.Dense(units=num_classes) 178 | ], name='mlp_head') 179 | 180 | def call(self, img, training=True, **kwargs): 181 | x = self.patch_embedding(img) 182 | b, n, d = x.shape 183 | 184 | x += self.pos_embedding[:, :n] 185 | x = self.dropout(x, training=training) 186 | 187 | x = self.patch_transformer(x, training=training) 188 | 189 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b) 190 | x = self.cls_transformer(cls_tokens, context=x, training=training) 191 | 192 | x = self.mlp_head(x[:, 0]) 193 | 194 | return x 195 | 196 | """ Usage 197 | v = CaiT( 198 | image_size = 256, 199 | patch_size = 32, 200 | num_classes = 1000, 201 | dim = 1024, 202 | depth = 12, # depth of transformer for patch to patch attention only 203 | cls_depth = 2, # depth of cross attention of CLS tokens to patch 204 | heads = 16, 205 | mlp_dim = 2048, 206 | dropout = 0.1, 207 | emb_dropout = 0.1, 208 | layer_dropout = 0.05 # randomly dropout 5% of the layers 209 | ) 210 | 211 | img = tf.random.normal(shape=[1, 256, 256, 3]) 212 | preds = v(img) # (1, 1000) 213 | """ 214 | -------------------------------------------------------------------------------- /vit_tensorflow/cct.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | 7 | from einops import rearrange 8 | 9 | def pair(t): 10 | return t if isinstance(t, tuple) else (t, t) 11 | 12 | # Pre-defined CCT Models 13 | __all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16'] 14 | 15 | 16 | def cct_2(*args, **kwargs): 17 | return _cct(num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128, 18 | *args, **kwargs) 19 | 20 | 21 | def cct_4(*args, **kwargs): 22 | return _cct(num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128, 23 | *args, **kwargs) 24 | 25 | 26 | def cct_6(*args, **kwargs): 27 | return _cct(num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256, 28 | *args, **kwargs) 29 | 30 | 31 | def cct_7(*args, **kwargs): 32 | return _cct(num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256, 33 | *args, **kwargs) 34 | 35 | 36 | def cct_8(*args, **kwargs): 37 | return _cct(num_layers=8, num_heads=4, mlp_ratio=2, embedding_dim=256, 38 | *args, **kwargs) 39 | 40 | 41 | def cct_14(*args, **kwargs): 42 | return _cct(num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384, 43 | *args, **kwargs) 44 | 45 | 46 | def cct_16(*args, **kwargs): 47 | return _cct(num_layers=16, num_heads=6, mlp_ratio=3, embedding_dim=384, 48 | *args, **kwargs) 49 | 50 | 51 | def _cct(num_layers, num_heads, mlp_ratio, embedding_dim, 52 | kernel_size=3, stride=None, 53 | 
*args, **kwargs): 54 | stride = stride if stride is not None else max(1, (kernel_size // 2) - 1) 55 | return CCT(num_layers=num_layers, 56 | num_heads=num_heads, 57 | mlp_ratio=mlp_ratio, 58 | embedding_dim=embedding_dim, 59 | kernel_size=kernel_size, 60 | stride=stride, 61 | *args, **kwargs) 62 | 63 | 64 | def GELU(): 65 | def gelu(x, approximate=False): 66 | if approximate: 67 | coeff = tf.cast(0.044715, x.dtype) 68 | return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3)))) 69 | else: 70 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype))) 71 | 72 | return nn.Activation(gelu) 73 | 74 | def drop_path(x, drop_prob=0.0, training=False): 75 | """ 76 | Obtained from: github.com:rwightman/pytorch-image-models 77 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 78 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 79 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 80 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 81 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 82 | 'survival rate' as the argument. 83 | """ 84 | if drop_prob == 0.0 or not training: 85 | return x 86 | keep_prob = 1 - drop_prob 87 | shape = [x.shape[0]] + [1] * (tf.rank(x).numpy() - 1) # work with diff dim tensors, not just 2D ConvNets 88 | random_tensor = keep_prob + tf.random.uniform(shape=shape, minval=0.0, maxval=1.0, dtype=x.dtype) 89 | random_tensor = tf.floor(random_tensor) # binarize 90 | x = tf.divide(x, keep_prob) * random_tensor 91 | return x 92 | 93 | class DropPath(Layer): 94 | """ 95 | Obtained from: github.com:rwightman/pytorch-image-models 96 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
97 | """ 98 | def __init__(self, drop_prob=None): 99 | super(DropPath, self).__init__() 100 | self.drop_prob = drop_prob 101 | 102 | def call(self, x, training=True): 103 | return drop_path(x, self.drop_prob, training=training) 104 | 105 | class Attention(Layer): 106 | def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1): 107 | super(Attention, self).__init__() 108 | 109 | self.num_heads = num_heads 110 | head_dim = dim // self.num_heads 111 | self.scale = head_dim ** -0.5 112 | 113 | self.to_qkv = nn.Dense(units=dim * 3, use_bias=False) 114 | self.attend = nn.Softmax() 115 | self.attn_drop = nn.Dropout(rate=attention_dropout) 116 | 117 | self.proj = [ 118 | nn.Dense(units=dim), 119 | nn.Dropout(rate=projection_dropout) 120 | ] 121 | 122 | self.proj = Sequential(self.proj) 123 | 124 | def call(self, x, training=True): 125 | qkv = self.to_qkv(x) 126 | qkv = tf.split(qkv, num_or_size_splits=3, axis=-1) 127 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads), qkv) 128 | 129 | dots = tf.matmul(q, tf.transpose(k, perm=[0, 1, 3, 2])) * self.scale 130 | 131 | attn = self.attend(dots) 132 | attn = self.attn_drop(attn, training=training) 133 | 134 | x = tf.matmul(attn, v) 135 | x = rearrange(x, 'b h n d -> b n (h d)') 136 | x = self.proj(x, training=training) 137 | return x 138 | 139 | class TransformerEncoderLayer(Layer): 140 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 141 | attention_dropout=0.1, drop_path_rate=0.1): 142 | super(TransformerEncoderLayer, self).__init__() 143 | 144 | self.pre_norm = nn.LayerNormalization() 145 | self.self_attn = Attention(dim=d_model, num_heads=nhead, attention_dropout=attention_dropout, projection_dropout=dropout) 146 | 147 | self.linear1 = nn.Dense(units=dim_feedforward) 148 | self.dropout1 = nn.Dropout(rate=dropout) 149 | self.norm1 = nn.LayerNormalization() 150 | self.linear2 = nn.Dense(units=d_model) 151 | self.dropout2 = nn.Dropout(rate=dropout) 152 | self.drop_path_rate = drop_path_rate 153 | 154 | if drop_path_rate > 0: 155 | self.drop_path = DropPath(drop_path_rate) 156 | 157 | self.activation = GELU() 158 | 159 | def call(self, src, training=True): 160 | if self.drop_path_rate > 0.0: 161 | src = src + self.drop_path(self.self_attn(self.pre_norm(src))) 162 | else: 163 | src = src + self.self_attn(self.pre_norm(src)) 164 | 165 | src = self.norm1(src) 166 | src2 = self.linear2(self.dropout1(self.activation(self.linear1(src)), training=training)) 167 | src2 = self.dropout2(src2, training=training) 168 | 169 | if self.drop_path_rate > 0.0: 170 | src = src + self.drop_path(src2, training=training) 171 | else: 172 | src = src + src2 173 | 174 | return src 175 | 176 | class Tokenizer(Layer): 177 | def __init__(self, 178 | kernel_size, stride, 179 | pooling_kernel_size=3, pooling_stride=2, 180 | n_conv_layers=1, 181 | n_output_channels=64, 182 | in_planes=64, 183 | activation=None, 184 | max_pool=True, 185 | conv_bias=False): 186 | super(Tokenizer, self).__init__() 187 | 188 | conv_layers = [] 189 | 190 | for i in range(n_conv_layers): 191 | if i == n_conv_layers-1: 192 | channels = n_output_channels 193 | else: 194 | channels = in_planes 195 | 196 | conv_layers += [nn.Conv2D(filters=channels, kernel_size=kernel_size, strides=stride, padding='SAME', use_bias=conv_bias)] 197 | if activation is not None: 198 | conv_layers += [activation()] 199 | if max_pool: 200 | conv_layers += [nn.MaxPool2D(pool_size=pooling_kernel_size, strides=pooling_stride, padding='SAME')] 201 | 202 | 
self.conv_layers = Sequential(conv_layers) 203 | 204 | def sequence_length(self, n_channels=3, height=224, width=224): 205 | x = tf.zeros(shape=[1, height, width, n_channels]) 206 | x = self.call(x) 207 | x = x.shape[1] 208 | 209 | return x 210 | 211 | def call(self, x, **kwargs): 212 | x = self.conv_layers(x) 213 | x = tf.reshape(x, shape=[x.shape[0], -1, x.shape[-1]]) 214 | 215 | return x 216 | 217 | class TransformerClassifier(Layer): 218 | def __init__(self, 219 | seq_pool=True, 220 | embedding_dim=768, 221 | num_layers=12, 222 | num_heads=12, 223 | mlp_ratio=4.0, 224 | num_classes=1000, 225 | dropout_rate=0.1, 226 | attention_dropout=0.1, 227 | stochastic_depth_rate=0.1, 228 | positional_embedding='sine', 229 | sequence_length=None, 230 | *args, **kwargs): 231 | super(TransformerClassifier, self).__init__() 232 | 233 | positional_embedding = positional_embedding if \ 234 | positional_embedding in ['sine', 'learnable', 'none'] else 'sine' 235 | dim_feedforward = int(embedding_dim * mlp_ratio) 236 | self.embedding_dim = embedding_dim 237 | self.sequence_length = sequence_length 238 | self.seq_pool = seq_pool 239 | 240 | assert sequence_length is not None or positional_embedding == 'none', \ 241 | f"Positional embedding is set to {positional_embedding} and" \ 242 | f" the sequence length was not specified." 243 | 244 | if not seq_pool: 245 | sequence_length += 1 246 | self.class_emb = tf.Variable(tf.zeros([1, 1, self.embedding_dim])) 247 | else: 248 | self.attention_pool = nn.Dense(units=1) 249 | 250 | if positional_embedding != 'none': 251 | if positional_embedding == 'learnable': 252 | self.positional_emb = tf.Variable(tf.random.truncated_normal(shape=[1, sequence_length, embedding_dim], stddev=0.2)) 253 | else: 254 | self.positional_emb = tf.Variable(self.sinusoidal_embedding(sequence_length, embedding_dim), trainable=False) 255 | else: 256 | self.positional_emb = None 257 | 258 | self.dropout = nn.Dropout(rate=dropout_rate) 259 | dpr = [x.numpy() for x in tf.linspace(0.0, stochastic_depth_rate, num_layers)] 260 | 261 | self.blocks = Sequential() 262 | for i in range(num_layers): 263 | self.blocks.add(TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, 264 | dim_feedforward=dim_feedforward, dropout=dropout_rate, 265 | attention_dropout=attention_dropout, drop_path_rate=dpr[i])) 266 | self.norm = nn.LayerNormalization() 267 | self.fc = nn.Dense(units=num_classes) 268 | 269 | def sinusoidal_embedding(self, n_channels, dim): 270 | pe = tf.cast(([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)] for p in range(n_channels)]), tf.float32) 271 | pe[:, 0::2] = tf.sin(pe[:, 0::2]) 272 | pe[:, 1::2] = tf.cos(pe[:, 1::2]) 273 | pe = tf.expand_dims(pe, axis=0) 274 | 275 | return pe 276 | 277 | def call(self, x, training=True): 278 | if self.positional_emb is None and x.shape[1] < self.sequence_length : 279 | 280 | x = tf.pad(x, [[0, 0], [0, self.sequence_length - x.shape[1]], [0, 0]]) 281 | if not self.seq_pool: 282 | cls_token = tf.tile(self.class_emb, multiples=[x.shape[0], 1, 1]) 283 | x = tf.concat([cls_token, x], axis=1) 284 | 285 | if self.positional_emb is not None: 286 | x += self.positional_emb 287 | 288 | x = self.dropout(x, training=training) 289 | 290 | x = self.blocks(x, training=training) 291 | x = self.norm(x) 292 | 293 | if self.seq_pool: 294 | x_init = x 295 | x = self.attention_pool(x) 296 | x = tf.nn.softmax(x, axis=1) 297 | x = tf.transpose(x, perm=[0, 2, 1]) 298 | x = tf.matmul(x, x_init) 299 | x = tf.squeeze(x, axis=1) 300 | else: 301 | x = x[:, 0] 302 | 
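        # shape walk-through of the seq_pool branch above (b = batch, n = tokens, d = embedding_dim):
        #   attention_pool(x)            : (b, n, d) -> (b, n, 1)  per-token score
        #   softmax over axis=1          : scores sum to 1 across the n tokens
        #   transpose, perm=[0, 2, 1]    : (b, n, 1) -> (b, 1, n)
        #   matmul with x_init (b, n, d) : (b, 1, n) @ (b, n, d) -> (b, 1, d), then squeeze -> (b, d)
        # i.e. a learned weighted average of all token embeddings replaces the usual CLS-token readout;
        # the else branch above keeps the conventional readout by taking token 0.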
303 | x = self.fc(x) 304 | 305 | return x 306 | 307 | class CCT(Model): 308 | def __init__(self, 309 | img_size=224, 310 | embedding_dim=768, 311 | n_input_channels=3, 312 | n_conv_layers=1, 313 | kernel_size=7, 314 | stride=2, 315 | pooling_kernel_size=3, 316 | pooling_stride=2, 317 | *args, **kwargs): 318 | super(CCT, self).__init__() 319 | img_height, img_width = pair(img_size) 320 | self.tokenizer = Tokenizer(n_output_channels=embedding_dim, 321 | kernel_size=kernel_size, 322 | stride=stride, 323 | pooling_kernel_size=pooling_kernel_size, 324 | pooling_stride=pooling_stride, 325 | max_pool=True, 326 | activation=nn.ReLU, 327 | n_conv_layers=n_conv_layers, 328 | conv_bias=False) 329 | 330 | self.classifier = TransformerClassifier( 331 | sequence_length=self.tokenizer.sequence_length(n_channels=n_input_channels, 332 | height=img_height, 333 | width=img_width), 334 | embedding_dim=embedding_dim, 335 | seq_pool=True, 336 | dropout_rate=0., 337 | attention_dropout=0.1, 338 | stochastic_depth_rate=0.1, 339 | *args, **kwargs) 340 | 341 | 342 | def call(self, img, training=None, **kwargs): 343 | x = self.tokenizer(img, training=training) 344 | x = self.classifier(x, training=training) 345 | return x 346 | 347 | """ Usage 348 | v = CCT( 349 | img_size = (224, 448), 350 | embedding_dim = 384, 351 | n_conv_layers = 2, 352 | kernel_size = 7, 353 | stride = 2, 354 | padding = 3, 355 | pooling_kernel_size = 3, 356 | pooling_stride = 2, 357 | pooling_padding = 1, 358 | num_layers = 14, 359 | num_heads = 6, 360 | mlp_radio = 3., 361 | num_classes = 1000, 362 | positional_embedding = 'learnable', # ['sine', 'learnable', 'none'] 363 | ) 364 | 365 | # cct = cct_2( 366 | # img_size = 224, 367 | # n_conv_layers = 1, 368 | # kernel_size = 7, 369 | # stride = 2, 370 | # padding = 3, 371 | # pooling_kernel_size = 3, 372 | # pooling_stride = 2, 373 | # pooling_padding = 1, 374 | # num_classes = 1000, 375 | # positional_embedding = 'learnable', # ['sine', 'learnable', 'none'] 376 | # ) 377 | 378 | img = tf.random.normal(shape=[5, 224, 224, 3]) 379 | preds = v(img) # (1, 1000) 380 | """ -------------------------------------------------------------------------------- /vit_tensorflow/cross_vit.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import einsum 3 | from tensorflow.keras import Model 4 | from tensorflow.keras.layers import Layer 5 | from tensorflow.keras import Sequential 6 | import tensorflow.keras.layers as nn 7 | 8 | from einops import rearrange, repeat 9 | from einops.layers.tensorflow import Rearrange 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | def default(val, d): 15 | return val if exists(val) else d 16 | 17 | class PreNorm(Layer): 18 | def __init__(self, fn): 19 | super(PreNorm, self).__init__() 20 | 21 | self.norm = nn.LayerNormalization() 22 | self.fn = fn 23 | 24 | def call(self, x, training=True, **kwargs): 25 | return self.fn(self.norm(x), training=training, **kwargs) 26 | 27 | class MLP(Layer): 28 | def __init__(self, dim, hidden_dim, dropout=0.0): 29 | super(MLP, self).__init__() 30 | def GELU(): 31 | def gelu(x, approximate=False): 32 | if approximate: 33 | coeff = tf.cast(0.044715, x.dtype) 34 | return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3)))) 35 | else: 36 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype))) 37 | 38 | return nn.Activation(gelu) 39 | 40 | self.net = [ 41 | nn.Dense(units=hidden_dim), 42 | GELU(), 43 | 
nn.Dropout(rate=dropout), 44 | nn.Dense(units=dim), 45 | nn.Dropout(rate=dropout) 46 | ] 47 | self.net = Sequential(self.net) 48 | 49 | def call(self, x, training=True): 50 | return self.net(x, training=training) 51 | 52 | class Attention(Layer): 53 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): 54 | super(Attention, self).__init__() 55 | inner_dim = dim_head * heads 56 | 57 | self.heads = heads 58 | self.scale = dim_head ** -0.5 59 | 60 | self.attend = nn.Softmax() 61 | self.to_q = nn.Dense(units=inner_dim, use_bias=False) 62 | self.to_kv = nn.Dense(units=inner_dim * 2, use_bias=False) 63 | 64 | self.to_out = [ 65 | nn.Dense(units=dim), 66 | nn.Dropout(rate=dropout) 67 | ] 68 | 69 | self.to_out = Sequential(self.to_out) 70 | 71 | def call(self, x, context=None, kv_include_self=False, training=True): 72 | 73 | context = default(context, x) 74 | 75 | if kv_include_self: 76 | context = tf.concat([x, context], axis=1) # cross attention requires CLS token includes itself as key / value 77 | 78 | q = self.to_q(x) 79 | kv = self.to_kv(context) 80 | k, v = tf.split(kv, num_or_size_splits=2, axis=-1) 81 | qkv = (q, k, v) 82 | 83 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) 84 | 85 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 86 | 87 | attn = self.attend(dots) 88 | 89 | x = einsum('b h i j, b h j d -> b h i d', attn, v) 90 | x = rearrange(x, 'b h n d -> b n (h d)') 91 | x = self.to_out(x, training=training) 92 | 93 | return x 94 | 95 | class Transformer(Layer): 96 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0): 97 | super(Transformer, self).__init__() 98 | 99 | self.layers = [] 100 | self.norm = nn.LayerNormalization() 101 | 102 | for _ in range(depth): 103 | self.layers.append([ 104 | PreNorm(Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)), 105 | PreNorm(MLP(dim, mlp_dim, dropout=dropout)) 106 | ]) 107 | 108 | def call(self, x, training=True): 109 | for attn, mlp in self.layers: 110 | x = attn(x, training=training) + x 111 | x = mlp(x, training=training) + x 112 | 113 | x = self.norm(x) 114 | 115 | return x 116 | 117 | # projecting CLS tokens, in the case that small and large patch tokens have different dimensions 118 | class ProjectInOut(Layer): 119 | def __init__(self, dim_in, dim_out, fn): 120 | super(ProjectInOut, self).__init__() 121 | self.fn = fn 122 | 123 | self.need_projection = dim_in != dim_out 124 | if self.need_projection: 125 | self.project_in = nn.Dense(units=dim_out) 126 | self.project_out = nn.Dense(units=dim_in) 127 | 128 | def call(self, x, training=True, *args, **kwargs): 129 | # args check 130 | if self.need_projection: 131 | x = self.project_in(x) 132 | 133 | x = self.fn(x, training=training, *args, **kwargs) 134 | 135 | if self.need_projection: 136 | x = self.project_out(x) 137 | 138 | return x 139 | 140 | # cross attention transformer 141 | class CrossTransformer(Layer): 142 | def __init__(self, sm_dim, lg_dim, depth, heads, dim_head, dropout): 143 | super(CrossTransformer, self).__init__() 144 | 145 | self.layers = [] 146 | 147 | for _ in range(depth): 148 | self.layers.append([ProjectInOut(sm_dim, lg_dim, PreNorm(Attention(lg_dim, heads=heads, dim_head=dim_head, dropout=dropout))), 149 | ProjectInOut(lg_dim, sm_dim, PreNorm(Attention(sm_dim, heads=heads, dim_head=dim_head, dropout=dropout)))] 150 | ) 151 | 152 | def call(self, inputs, training=True): 153 | sm_tokens, lg_tokens = inputs 154 | (sm_cls, sm_patch_tokens), (lg_cls, lg_patch_tokens) = map(lambda 
t: (t[:, :1], t[:, 1:]), (sm_tokens, lg_tokens)) 155 | 156 | for sm_attend_lg, lg_attend_sm in self.layers: 157 | sm_cls = sm_attend_lg(sm_cls, context=lg_patch_tokens, kv_include_self=True, training=training) + sm_cls 158 | lg_cls = lg_attend_sm(lg_cls, context=sm_patch_tokens, kv_include_self=True, training=training) + lg_cls 159 | 160 | sm_tokens = tf.concat([sm_cls, sm_patch_tokens], axis=1) 161 | lg_tokens = tf.concat([lg_cls, lg_patch_tokens], axis=1) 162 | 163 | return sm_tokens, lg_tokens 164 | 165 | # multi-scale encoder 166 | class MultiScaleEncoder(Layer): 167 | def __init__(self, 168 | depth, 169 | sm_dim, 170 | lg_dim, 171 | sm_enc_params, 172 | lg_enc_params, 173 | cross_attn_heads, 174 | cross_attn_depth, 175 | cross_attn_dim_head=64, 176 | dropout=0.0): 177 | super(MultiScaleEncoder, self).__init__() 178 | 179 | self.layers = [] 180 | 181 | for _ in range(depth): 182 | self.layers.append([Transformer(dim=sm_dim, dropout=dropout, **sm_enc_params), 183 | Transformer(dim=lg_dim, dropout=dropout, **lg_enc_params), 184 | CrossTransformer(sm_dim=sm_dim, lg_dim=lg_dim, 185 | depth=cross_attn_depth, heads=cross_attn_heads, dim_head=cross_attn_dim_head, dropout=dropout) 186 | ] 187 | ) 188 | 189 | 190 | def call(self, inputs, training=True): 191 | sm_tokens, lg_tokens = inputs 192 | for sm_enc, lg_enc, cross_attend in self.layers: 193 | sm_tokens, lg_tokens = sm_enc(sm_tokens, training=training), lg_enc(lg_tokens, training=training) 194 | sm_tokens, lg_tokens = cross_attend([sm_tokens, lg_tokens], training=training) 195 | 196 | return sm_tokens, lg_tokens 197 | 198 | # patch-based image to token embedder 199 | class ImageEmbedder(Layer): 200 | def __init__(self, 201 | dim, 202 | image_size, 203 | patch_size, 204 | dropout=0.0): 205 | super(ImageEmbedder, self).__init__() 206 | 207 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 
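        # clarifying note (figures taken from the commented usage example at the bottom of this file):
        # CrossViT builds two of these embedders over the same image. With image_size=256,
        # sm_patch_size=16 gives (256 // 16) ** 2 = 256 small-patch tokens and lg_patch_size=64 gives
        # (256 // 64) ** 2 = 16 large-patch tokens; each branch prepends its own CLS token, which is why
        # pos_embedding below is sized with num_patches + 1.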
208 | num_patches = (image_size // patch_size) ** 2 209 | 210 | self.patch_embedding = Sequential([ 211 | Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size), 212 | nn.Dense(units=dim) 213 | ], name='patch_embedding') 214 | 215 | self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim])) 216 | self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim])) 217 | self.dropout = nn.Dropout(rate=dropout) 218 | 219 | def call(self, x, training=True): 220 | x = self.patch_embedding(x) 221 | 222 | b, n, d = x.shape 223 | 224 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b) 225 | x = tf.concat([cls_tokens, x], axis=1) 226 | x += self.pos_embedding[:, :(n + 1)] 227 | x = self.dropout(x, training=training) 228 | 229 | return x 230 | 231 | # cross ViT class 232 | class CrossViT(Model): 233 | def __init__(self, 234 | image_size, 235 | num_classes, 236 | sm_dim, 237 | lg_dim, 238 | sm_patch_size=12, 239 | sm_enc_depth=1, 240 | sm_enc_heads=8, 241 | sm_enc_mlp_dim=2048, 242 | sm_enc_dim_head=64, 243 | lg_patch_size=16, 244 | lg_enc_depth=4, 245 | lg_enc_heads=8, 246 | lg_enc_mlp_dim=2048, 247 | lg_enc_dim_head=64, 248 | cross_attn_depth=2, 249 | cross_attn_heads=8, 250 | cross_attn_dim_head=64, 251 | depth=3, 252 | dropout=0.1, 253 | emb_dropout=0.1): 254 | super(CrossViT, self).__init__() 255 | self.sm_image_embedder = ImageEmbedder(dim=sm_dim, image_size=image_size, patch_size=sm_patch_size, dropout=emb_dropout) 256 | self.lg_image_embedder = ImageEmbedder(dim=lg_dim, image_size=image_size, patch_size=lg_patch_size, dropout=emb_dropout) 257 | 258 | self.multi_scale_encoder = MultiScaleEncoder( 259 | depth=depth, 260 | sm_dim=sm_dim, 261 | lg_dim=lg_dim, 262 | cross_attn_heads=cross_attn_heads, 263 | cross_attn_dim_head=cross_attn_dim_head, 264 | cross_attn_depth=cross_attn_depth, 265 | sm_enc_params=dict( 266 | depth=sm_enc_depth, 267 | heads=sm_enc_heads, 268 | mlp_dim=sm_enc_mlp_dim, 269 | dim_head=sm_enc_dim_head 270 | ), 271 | lg_enc_params=dict( 272 | depth=lg_enc_depth, 273 | heads=lg_enc_heads, 274 | mlp_dim=lg_enc_mlp_dim, 275 | dim_head=lg_enc_dim_head 276 | ), 277 | dropout=dropout 278 | ) 279 | 280 | self.sm_mlp_head = Sequential([ 281 | nn.LayerNormalization(), 282 | nn.Dense(units=num_classes) 283 | ], name='sm_mlp_head') 284 | 285 | self.lg_mlp_head = Sequential([ 286 | nn.LayerNormalization(), 287 | nn.Dense(units=num_classes) 288 | ], name='lg_mlp_head') 289 | 290 | def call(self, img, training=True, **kwargs): 291 | sm_tokens = self.sm_image_embedder(img, training=training) 292 | lg_tokens = self.lg_image_embedder(img, training=training) 293 | 294 | sm_tokens, lg_tokens = self.multi_scale_encoder([sm_tokens, lg_tokens], training=training) 295 | 296 | sm_cls, lg_cls = map(lambda t: t[:, 0], (sm_tokens, lg_tokens)) 297 | 298 | sm_logits = self.sm_mlp_head(sm_cls) 299 | lg_logits = self.lg_mlp_head(lg_cls) 300 | 301 | x = sm_logits + lg_logits 302 | 303 | return x 304 | 305 | """ Usage 306 | v = CrossViT( 307 | image_size = 256, 308 | num_classes = 1000, 309 | depth = 4, # number of multi-scale encoding blocks 310 | sm_dim = 192, # high res dimension 311 | sm_patch_size = 16, # high res patch size (should be smaller than lg_patch_size) 312 | sm_enc_depth = 2, # high res depth 313 | sm_enc_heads = 8, # high res heads 314 | sm_enc_mlp_dim = 2048, # high res feedforward dimension 315 | lg_dim = 384, # low res dimension 316 | lg_patch_size = 64, # low res patch size 317 | lg_enc_depth = 3, # low res depth 
318 | lg_enc_heads = 8, # low res heads 319 | lg_enc_mlp_dim = 2048, # low res feedforward dimensions 320 | cross_attn_depth = 2, # cross attention rounds 321 | cross_attn_heads = 8, # cross attention heads 322 | dropout = 0.1, 323 | emb_dropout = 0.1 324 | ) 325 | 326 | img = tf.random.normal(shape=[1, 256, 256, 3]) 327 | preds = v(img) # (1, 1000) 328 | """ -------------------------------------------------------------------------------- /vit_tensorflow/crossformer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | from tensorflow import einsum 7 | 8 | from einops import rearrange 9 | from einops.layers.tensorflow import Rearrange, Reduce 10 | 11 | def cast_tuple(val, length = 1): 12 | return val if isinstance(val, tuple) else ((val,) * length) 13 | 14 | def gelu(x, approximate=False): 15 | if approximate: 16 | coeff = tf.cast(0.044715, x.dtype) 17 | return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3)))) 18 | else: 19 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype))) 20 | 21 | class GELU(Layer): 22 | def __init__(self, approximate=False): 23 | super(GELU, self).__init__() 24 | self.approximate = approximate 25 | 26 | def call(self, x, training=True): 27 | return gelu(x, self.approximate) 28 | 29 | # cross embed layer 30 | class CrossEmbedLayer(Layer): 31 | def __init__(self, dim, kernel_sizes, stride=2): 32 | super(CrossEmbedLayer, self).__init__() 33 | 34 | kernel_sizes = sorted(kernel_sizes) 35 | num_scales = len(kernel_sizes) 36 | 37 | # calculate the dimension at each scale 38 | dim_scales = [int(dim / (2 ** i)) for i in range(1, num_scales)] 39 | dim_scales = [*dim_scales, dim - sum(dim_scales)] 40 | 41 | self.convs = [] 42 | for kernel, dim_scale in zip(kernel_sizes, dim_scales): 43 | self.convs.append(nn.Conv2D(filters=dim_scale, kernel_size=kernel, strides=stride, padding='SAME')) 44 | 45 | def call(self, x, training=True): 46 | fmaps = tuple(map(lambda conv: conv(x), self.convs)) 47 | x = tf.concat(fmaps, axis=-1) 48 | return x 49 | 50 | # dynamic positional bias 51 | class DynamicPositionBias(Layer): 52 | def __init__(self, dim): 53 | super(DynamicPositionBias, self).__init__() 54 | 55 | self.dpb_layers = Sequential([ 56 | nn.Dense(units=dim), 57 | nn.LayerNormalization(), 58 | nn.ReLU(), 59 | nn.Dense(units=dim), 60 | nn.LayerNormalization(), 61 | nn.ReLU(), 62 | nn.Dense(units=dim), 63 | nn.LayerNormalization(), 64 | nn.ReLU(), 65 | nn.Dense(units=1), 66 | Rearrange('... 
() -> ...') 67 | ]) 68 | 69 | def call(self, x, training=True): 70 | x = self.dpb_layers(x) 71 | return x 72 | 73 | # transformer classes 74 | class LayerNorm(Layer): 75 | def __init__(self, dim, eps=1e-5): 76 | super(LayerNorm, self).__init__() 77 | self.eps = eps 78 | 79 | self.g = tf.Variable(tf.ones([1, 1, 1, dim])) 80 | self.b = tf.Variable(tf.zeros([1, 1, 1, dim])) 81 | 82 | def call(self, x, training=True): 83 | var = tf.math.reduce_variance(x, axis=-1, keepdims=True) 84 | mean = tf.reduce_mean(x, axis=-1, keepdims=True) 85 | 86 | x = (x - mean) / tf.sqrt((var + self.eps)) * self.g + self.b 87 | return x 88 | 89 | class MLP(Layer): 90 | def __init__(self, dim, mult=4, dropout=0.0): 91 | super(MLP, self).__init__() 92 | 93 | self.net = Sequential([ 94 | LayerNorm(dim), 95 | nn.Conv2D(filters=dim*mult, kernel_size=1, strides=1), 96 | GELU(), 97 | nn.Dropout(rate=dropout), 98 | nn.Conv2D(filters=dim, kernel_size=1, strides=1) 99 | ]) 100 | 101 | def call(self, x, training=True): 102 | return self.net(x, training=training) 103 | 104 | class Attention(Layer): 105 | def __init__(self, dim, attn_type, window_size, dim_head=32, dropout=0.0): 106 | super(Attention, self).__init__() 107 | 108 | assert attn_type in {'short', 'long'}, 'attention type must be one of local or distant' 109 | heads = dim // dim_head 110 | self.heads = heads 111 | self.scale = dim_head ** -0.5 112 | inner_dim = dim_head * heads 113 | 114 | self.attn_type = attn_type 115 | self.window_size = window_size 116 | 117 | self.norm = LayerNorm(dim) 118 | self.to_qkv = nn.Conv2D(filters=inner_dim * 3, kernel_size=1, strides=1, use_bias=False) 119 | self.to_out = nn.Conv2D(filters=dim, kernel_size=1, strides=1) 120 | 121 | # positions 122 | self.dpb = DynamicPositionBias(dim // 4) 123 | self.attend = nn.Softmax() 124 | 125 | # calculate and store indices for retrieving bias 126 | pos = tf.range(window_size) 127 | grid = tf.stack(tf.meshgrid(pos, pos, indexing='ij')) 128 | grid = rearrange(grid, 'c i j -> (i j) c') 129 | rel_pos = grid[:, None] - grid[None, :] 130 | rel_pos += window_size - 1 131 | self.rel_pos_indices = tf.reduce_sum(rel_pos * tf.convert_to_tensor([2 * window_size - 1, 1]), axis=-1) 132 | 133 | def call(self, x, training=True): 134 | _, height, width, _ = x.shape 135 | heads = self.heads 136 | wsz = self.window_size 137 | 138 | # prenorm 139 | x = self.norm(x) 140 | 141 | # rearrange for short or long distance attention 142 | 143 | if self.attn_type == 'short': 144 | x = rearrange(x, 'b (h s1) (w s2) d -> (b h w) s1 s2 d', s1=wsz, s2=wsz) 145 | elif self.attn_type == 'long': 146 | x = rearrange(x, 'b (l1 h) (l2 w) d -> (b h w) l1 l2 d', l1=wsz, l2=wsz) 147 | 148 | # queries / keys / values 149 | qkv = self.to_qkv(x) 150 | q, k, v = tf.split(qkv, num_or_size_splits=3, axis=-1) 151 | 152 | # split heads 153 | q, k, v = map(lambda t: rearrange(t, 'b x y (h d) -> b h (x y) d', h=heads), (q, k, v)) 154 | q = q * self.scale 155 | 156 | sim = einsum('b h i d, b h j d -> b h i j', q, k) 157 | 158 | # add dynamic positional bias 159 | pos = tf.range(-wsz, wsz + 1) 160 | rel_pos = tf.stack(tf.meshgrid(pos, pos, indexing='ij')) 161 | rel_pos = rearrange(rel_pos, 'c i j -> (i j) c') 162 | biases = self.dpb(tf.cast(rel_pos, tf.float32)) 163 | rel_pos_bias = biases.numpy()[self.rel_pos_indices.numpy()] 164 | 165 | sim = sim + rel_pos_bias 166 | 167 | # attend 168 | attn = self.attend(sim) 169 | 170 | # merge heads 171 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 172 | out = rearrange(out, 'b h (x y) d -> b x y (h 
d) ', x=wsz, y=wsz) 173 | out = self.to_out(out) 174 | # rearrange back for long or short distance attention 175 | if self.attn_type == 'short': 176 | out = rearrange(out, '(b h w) s1 s2 d -> b (h s1) (w s2) d', h=height // wsz, w=width // wsz) 177 | elif self.attn_type == 'long': 178 | out = rearrange(out, '(b h w) l1 l2 d -> b (l1 h) (l2 w) d', h=height // wsz, w=width // wsz) 179 | 180 | return out 181 | 182 | class Transformer(Layer): 183 | def __init__(self, dim, local_window_size, global_window_size, depth=4, dim_head=32, attn_dropout=0.0, ff_dropout=0.0): 184 | super(Transformer, self).__init__() 185 | 186 | self.layers = [] 187 | 188 | for _ in range(depth): 189 | self.layers.append([ 190 | Attention(dim, attn_type='short', window_size=local_window_size, dim_head=dim_head, dropout=attn_dropout), 191 | MLP(dim, dropout=ff_dropout), 192 | Attention(dim, attn_type='long', window_size=global_window_size, dim_head=dim_head, dropout=attn_dropout), 193 | MLP(dim, dropout=ff_dropout) 194 | ]) 195 | 196 | def call(self, x, training=True): 197 | for short_attn, short_ff, long_attn, long_ff in self.layers: 198 | x = short_attn(x) + x 199 | x = short_ff(x, training=training) + x 200 | x = long_attn(x) + x 201 | x = long_ff(x, training=training) + x 202 | 203 | return x 204 | 205 | class CrossFormer(Model): 206 | def __init__(self, 207 | dim=(64, 128, 256, 512), 208 | depth=(2, 2, 8, 2), 209 | global_window_size=(8, 4, 2, 1), 210 | local_window_size=7, 211 | cross_embed_kernel_sizes=((4, 8, 16, 32), (2, 4), (2, 4), (2, 4)), 212 | cross_embed_strides=(4, 2, 2, 2), 213 | num_classes=1000, 214 | attn_dropout=0.0, 215 | ff_dropout=0.0, 216 | ): 217 | super(CrossFormer, self).__init__() 218 | dim = cast_tuple(dim, 4) 219 | depth = cast_tuple(depth, 4) 220 | global_window_size = cast_tuple(global_window_size, 4) 221 | local_window_size = cast_tuple(local_window_size, 4) 222 | cross_embed_kernel_sizes = cast_tuple(cross_embed_kernel_sizes, 4) 223 | cross_embed_strides = cast_tuple(cross_embed_strides, 4) 224 | 225 | assert len(dim) == 4 226 | assert len(depth) == 4 227 | assert len(global_window_size) == 4 228 | assert len(local_window_size) == 4 229 | assert len(cross_embed_kernel_sizes) == 4 230 | assert len(cross_embed_strides) == 4 231 | 232 | # layers 233 | self.crossformer_layers = [] 234 | 235 | for dim_out, layers, global_wsz, local_wsz, cel_kernel_sizes, cel_stride in zip(dim, depth, 236 | global_window_size, local_window_size, 237 | cross_embed_kernel_sizes, cross_embed_strides): 238 | self.crossformer_layers.append([ 239 | CrossEmbedLayer(dim_out, cel_kernel_sizes, stride=cel_stride), 240 | Transformer(dim_out, local_window_size=local_wsz, global_window_size=global_wsz, depth=layers, 241 | attn_dropout=attn_dropout, ff_dropout=ff_dropout) 242 | ]) 243 | 244 | # final logits 245 | self.to_logits = Sequential([ 246 | Reduce('b h w c -> b c', 'mean'), 247 | nn.Dense(units=num_classes) 248 | ]) 249 | 250 | def call(self, x, training=True, **kwargs): 251 | for cel, transformer in self.crossformer_layers: 252 | x = cel(x) 253 | x = transformer(x, training=training) 254 | 255 | x = self.to_logits(x) 256 | 257 | return x 258 | """ Usage 259 | v = CrossFormer( 260 | num_classes = 1000, # number of output classes 261 | dim = (64, 128, 256, 512), # dimension at each stage 262 | depth = (2, 2, 8, 2), # depth of transformer at each stage 263 | global_window_size = (8, 4, 2, 1), # global window sizes at each stage 264 | local_window_size = 7, # local window size (can be customized for each stage, but in 
paper, held constant at 7 for all stages) 265 | ) 266 | 267 | img = tf.random.normal(shape=[1, 224, 224, 3]) 268 | preds = v(img) # (1, 1000) 269 | """ -------------------------------------------------------------------------------- /vit_tensorflow/cvt.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import einsum 3 | from tensorflow.keras import Model 4 | from tensorflow.keras.layers import Layer 5 | from tensorflow.keras import Sequential 6 | import tensorflow.keras.layers as nn 7 | 8 | from einops import rearrange 9 | 10 | def group_dict_by_key(cond, d): 11 | return_val = [dict(), dict()] 12 | for key in d.keys(): 13 | match = bool(cond(key)) 14 | ind = int(not match) 15 | return_val[ind][key] = d[key] 16 | return (*return_val,) 17 | 18 | def group_by_key_prefix_and_remove_prefix(prefix, d): 19 | kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d) 20 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) 21 | return kwargs_without_prefix, kwargs 22 | 23 | def gelu(x, approximate=False): 24 | if approximate: 25 | coeff = tf.cast(0.044715, x.dtype) 26 | return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3)))) 27 | else: 28 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype))) 29 | 30 | class LayerNorm(Layer): # layernorm, but done in the channel dimension #1 31 | def __init__(self, dim, eps=1e-5): 32 | super(LayerNorm, self).__init__() 33 | self.eps = eps 34 | 35 | self.g = tf.Variable(tf.ones([1, 1, 1, dim])) 36 | self.b = tf.Variable(tf.zeros([1, 1, 1, dim])) 37 | 38 | def call(self, x, training=True): 39 | var = tf.math.reduce_variance(x, axis=-1, keepdims=True) 40 | mean = tf.reduce_mean(x, axis=-1, keepdims=True) 41 | 42 | x = (x - mean) / tf.sqrt((var + self.eps)) * self.g + self.b 43 | return x 44 | 45 | class PreNorm(Layer): 46 | def __init__(self, dim, fn): 47 | super(PreNorm, self).__init__() 48 | 49 | self.norm = LayerNorm(dim) 50 | self.fn = fn 51 | 52 | def call(self, x, training=True): 53 | return self.fn(self.norm(x), training=training) 54 | 55 | class GELU(Layer): 56 | def __init__(self, approximate=False): 57 | super(GELU, self).__init__() 58 | self.approximate = approximate 59 | 60 | def call(self, x, training=True): 61 | return gelu(x, self.approximate) 62 | 63 | class MLP(Layer): 64 | def __init__(self, dim, mult=4, dropout=0.0): 65 | super(MLP, self).__init__() 66 | 67 | self.net = [ 68 | nn.Conv2D(filters=dim * mult, kernel_size=1, strides=1), 69 | GELU(), 70 | nn.Dropout(rate=dropout), 71 | nn.Conv2D(filters=dim, kernel_size=1, strides=1), 72 | nn.Dropout(rate=dropout) 73 | ] 74 | self.net = Sequential(self.net) 75 | 76 | def call(self, x, training=True): 77 | return self.net(x, training=training) 78 | 79 | class DepthWiseConv2d(Layer): 80 | def __init__(self, dim_in, dim_out, kernel_size, stride, bias=True): 81 | super(DepthWiseConv2d, self).__init__() 82 | 83 | net = [] 84 | net += [nn.Conv2D(filters=dim_in, kernel_size=kernel_size, strides=stride, padding='SAME', groups=dim_in, use_bias=bias)] 85 | net += [nn.BatchNormalization(momentum=0.9, epsilon=1e-5)] 86 | net += [nn.Conv2D(filters=dim_out, kernel_size=1, strides=1, use_bias=bias)] 87 | 88 | self.net = Sequential(net) 89 | 90 | def call(self, x, training=True): 91 | x = self.net(x, training=training) 92 | return x 93 | 94 | class Attention(Layer): 95 | def __init__(self, dim, proj_kernel, 
kv_proj_stride, heads=8, dim_head=64, dropout=0.0): 96 | super(Attention, self).__init__() 97 | inner_dim = dim_head * heads 98 | self.heads = heads 99 | self.scale = dim_head ** -0.5 100 | 101 | self.attend = nn.Softmax() 102 | 103 | self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, stride=1, bias=False) 104 | self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, stride=kv_proj_stride, bias=False) 105 | 106 | self.to_out = Sequential([ 107 | nn.Conv2D(filters=dim, kernel_size=1, strides=1), 108 | nn.Dropout(rate=dropout) 109 | ]) 110 | 111 | def call(self, x, training=True): 112 | b, _, y, n = x.shape 113 | h = self.heads 114 | q = self.to_q(x, training=training) 115 | kv = self.to_kv(x, training=training) 116 | k, v = tf.split(kv, num_or_size_splits=2, axis=-1) 117 | qkv = (q, k, v) 118 | q, k, v = map(lambda t: rearrange(t, 'b x y (h d) -> (b h) (x y) d', h=h), qkv) 119 | 120 | dots = einsum('b i d, b j d -> b i j', q, k) * self.scale 121 | attn = self.attend(dots) 122 | 123 | x = einsum('b i j, b j d -> b i d', attn, v) 124 | x = rearrange(x, '(b h) (x y) d -> b x y (h d)', h=h, y=y) 125 | x = self.to_out(x, training=training) 126 | 127 | return x 128 | 129 | class Transformer(Layer): 130 | def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head=64, mlp_mult=4, dropout=0.): 131 | super(Transformer, self).__init__() 132 | 133 | self.layers = [] 134 | 135 | for _ in range(depth): 136 | self.layers.append([ 137 | PreNorm(dim, Attention(dim, proj_kernel=proj_kernel, kv_proj_stride=kv_proj_stride, heads=heads, dim_head=dim_head, dropout=dropout)), 138 | PreNorm(dim, MLP(dim, mlp_mult, dropout=dropout)) 139 | ]) 140 | 141 | 142 | def call(self, x, training=True): 143 | for i, (attn, ff) in enumerate(self.layers): 144 | x = attn(x, training=training) + x 145 | x = ff(x, training=training) + x 146 | 147 | return x 148 | 149 | class CvT(Model): 150 | def __init__(self, 151 | num_classes, 152 | s1_emb_dim=64, 153 | s1_emb_kernel=7, 154 | s1_emb_stride=4, 155 | s1_proj_kernel=3, 156 | s1_kv_proj_stride=2, 157 | s1_heads=1, 158 | s1_depth=1, 159 | s1_mlp_mult=4, 160 | s2_emb_dim=192, 161 | s2_emb_kernel=3, 162 | s2_emb_stride=2, 163 | s2_proj_kernel=3, 164 | s2_kv_proj_stride=2, 165 | s2_heads=3, 166 | s2_depth=2, 167 | s2_mlp_mult=4, 168 | s3_emb_dim=384, 169 | s3_emb_kernel=3, 170 | s3_emb_stride=2, 171 | s3_proj_kernel=3, 172 | s3_kv_proj_stride=2, 173 | s3_heads=6, 174 | s3_depth=10, 175 | s3_mlp_mult=4, 176 | dropout=0. 
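# editor's note: the s1_/s2_/s3_ prefixes on the arguments above are significant —
# __init__ below copies locals() into kwargs and calls group_by_key_prefix_and_remove_prefix()
# once per prefix, so each stage's embedding, projection, head, depth and mlp settings are
# pulled from its own prefixed keyword arguments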
177 | ): 178 | 179 | super(CvT, self).__init__() 180 | kwargs = dict(locals()) 181 | 182 | self.cvt_layers = Sequential() 183 | 184 | for prefix in ('s1', 's2', 's3'): 185 | config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs) 186 | self.cvt_layers.add(Sequential([ 187 | nn.Conv2D(filters=config['emb_dim'], kernel_size=config['emb_kernel'], padding='SAME', strides=config['emb_stride']), 188 | LayerNorm(config['emb_dim']), 189 | Transformer(dim=config['emb_dim'], proj_kernel=config['proj_kernel'], 190 | kv_proj_stride=config['kv_proj_stride'], depth=config['depth'], heads=config['heads'], 191 | mlp_mult=config['mlp_mult'], dropout=dropout) 192 | ])) 193 | 194 | 195 | self.cvt_layers.add(Sequential([ 196 | nn.GlobalAvgPool2D(), 197 | nn.Dense(units=num_classes) 198 | ])) 199 | 200 | def call(self, img, training=True, **kwargs): 201 | x = self.cvt_layers(img, training=training) 202 | return x 203 | 204 | """ Usage 205 | v = CvT( 206 | num_classes = 1000, 207 | s1_emb_dim = 64, # stage 1 - dimension 208 | s1_emb_kernel = 7, # stage 1 - conv kernel 209 | s1_emb_stride = 4, # stage 1 - conv stride 210 | s1_proj_kernel = 3, # stage 1 - attention ds-conv kernel size 211 | s1_kv_proj_stride = 2, # stage 1 - attention key / value projection stride 212 | s1_heads = 1, # stage 1 - heads 213 | s1_depth = 1, # stage 1 - depth 214 | s1_mlp_mult = 4, # stage 1 - feedforward expansion factor 215 | s2_emb_dim = 192, # stage 2 - (same as above) 216 | s2_emb_kernel = 3, 217 | s2_emb_stride = 2, 218 | s2_proj_kernel = 3, 219 | s2_kv_proj_stride = 2, 220 | s2_heads = 3, 221 | s2_depth = 2, 222 | s2_mlp_mult = 4, 223 | s3_emb_dim = 384, # stage 3 - (same as above) 224 | s3_emb_kernel = 3, 225 | s3_emb_stride = 2, 226 | s3_proj_kernel = 3, 227 | s3_kv_proj_stride = 2, 228 | s3_heads = 4, 229 | s3_depth = 10, 230 | s3_mlp_mult = 4, 231 | dropout = 0. 
232 | ) 233 | 234 | img = tf.random.normal(shape=[1, 224, 224, 3]) 235 | preds = v(img) # (1, 1000) 236 | """ 237 | -------------------------------------------------------------------------------- /vit_tensorflow/deepvit.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import einsum 3 | from tensorflow.keras import Model 4 | from tensorflow.keras.layers import Layer 5 | from tensorflow.keras import Sequential 6 | import tensorflow.keras.layers as nn 7 | 8 | from einops import rearrange, repeat 9 | from einops.layers.tensorflow import Rearrange 10 | 11 | class PreNorm(Layer): 12 | def __init__(self, fn): 13 | super(PreNorm, self).__init__() 14 | 15 | self.norm = nn.LayerNormalization() 16 | self.fn = fn 17 | 18 | def call(self, x, training=True): 19 | return self.fn(self.norm(x), training=training) 20 | 21 | class MLP(Layer): 22 | def __init__(self, dim, hidden_dim, dropout=0.0): 23 | super(MLP, self).__init__() 24 | def GELU(): 25 | def gelu(x, approximate=False): 26 | if approximate: 27 | coeff = tf.cast(0.044715, x.dtype) 28 | return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3)))) 29 | else: 30 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype))) 31 | 32 | return nn.Activation(gelu) 33 | 34 | self.net = [ 35 | nn.Dense(units=hidden_dim), 36 | GELU(), 37 | nn.Dropout(rate=dropout), 38 | nn.Dense(units=dim), 39 | nn.Dropout(rate=dropout) 40 | ] 41 | self.net = Sequential(self.net) 42 | 43 | def call(self, x, training=True): 44 | return self.net(x, training=training) 45 | 46 | class Attention(Layer): 47 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): 48 | super(Attention, self).__init__() 49 | inner_dim = dim_head * heads 50 | 51 | self.heads = heads 52 | self.scale = dim_head ** -0.5 53 | 54 | self.attend = nn.Softmax() 55 | self.to_qkv = nn.Dense(units=inner_dim * 3, use_bias=False) 56 | 57 | self.reattn_weights = tf.Variable(initial_value=tf.random.normal([heads, heads])) 58 | 59 | self.reattn_norm = [ 60 | Rearrange('b h i j -> b i j h'), 61 | nn.LayerNormalization(), 62 | Rearrange('b i j h -> b h i j') 63 | ] 64 | 65 | self.to_out = [ 66 | nn.Dense(units=dim), 67 | nn.Dropout(rate=dropout) 68 | ] 69 | 70 | self.reattn_norm = Sequential(self.reattn_norm) 71 | self.to_out = Sequential(self.to_out) 72 | 73 | def call(self, x, training=True): 74 | qkv = self.to_qkv(x) 75 | qkv = tf.split(qkv, num_or_size_splits=3, axis=-1) 76 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) 77 | 78 | # attention 79 | dots = tf.matmul(q, tf.transpose(k, perm=[0, 1, 3, 2])) * self.scale 80 | attn = self.attend(dots) 81 | 82 | # re-attention 83 | attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights) 84 | attn = self.reattn_norm(attn) 85 | 86 | # aggregate and out 87 | x = tf.matmul(attn, v) 88 | x = rearrange(x, 'b h n d -> b n (h d)') 89 | x = self.to_out(x, training=training) 90 | 91 | return x 92 | 93 | class Transformer(Layer): 94 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0): 95 | super(Transformer, self).__init__() 96 | 97 | self.layers = [] 98 | 99 | for _ in range(depth): 100 | self.layers.append([ 101 | PreNorm(Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)), 102 | PreNorm(MLP(dim, mlp_dim, dropout=dropout)) 103 | ]) 104 | 105 | def call(self, x, training=True): 106 | for attn, mlp in self.layers: 107 | x = attn(x, training=training) + x 108 | x = mlp(x, 
training=training) + x 109 | 110 | return x 111 | 112 | class DeepViT(Model): 113 | def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, 114 | pool='cls', dim_head=64, dropout=0.0, emb_dropout=0.0): 115 | super(DeepViT, self).__init__() 116 | 117 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 118 | num_patches = (image_size // patch_size) ** 2 119 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 120 | 121 | self.patch_embedding = Sequential([ 122 | Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size), 123 | nn.Dense(units=dim) 124 | ], name='patch_embedding') 125 | 126 | self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim])) 127 | self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim])) 128 | self.dropout = nn.Dropout(rate=emb_dropout) 129 | 130 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 131 | 132 | self.pool = pool 133 | 134 | self.mlp_head = Sequential([ 135 | nn.LayerNormalization(), 136 | nn.Dense(units=num_classes) 137 | ], name='mlp_head') 138 | 139 | def call(self, img, training=True, **kwargs): 140 | x = self.patch_embedding(img) 141 | b, n, d = x.shape 142 | 143 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b) 144 | x = tf.concat([cls_tokens, x], axis=1) 145 | x += self.pos_embedding[:, :(n + 1)] 146 | x = self.dropout(x, training=training) 147 | 148 | x = self.transformer(x, training=training) 149 | 150 | if self.pool == 'mean': 151 | x = tf.reduce_mean(x, axis=1) 152 | else: 153 | x = x[:, 0] 154 | 155 | x = self.mlp_head(x) 156 | 157 | return x 158 | 159 | """ Usage 160 | 161 | v = DeepViT( 162 | image_size = 256, 163 | patch_size = 32, 164 | num_classes = 1000, 165 | dim = 1024, 166 | depth = 6, 167 | heads = 16, 168 | mlp_dim = 2048, 169 | dropout = 0.1, 170 | emb_dropout = 0.1 171 | ) 172 | 173 | img = tf.random.normal(shape=[1, 256, 256, 3]) 174 | preds = v(img) # (1, 1000) 175 | 176 | """ -------------------------------------------------------------------------------- /vit_tensorflow/distill.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | 7 | from einops import rearrange, repeat 8 | 9 | from vit import ViT 10 | from t2t import T2TViT 11 | from efficient import ViT as EfficientViT 12 | 13 | def exists(val): 14 | return val is not None 15 | 16 | class DistillMixin: 17 | def call(self, img, distill_token=None, training=True): 18 | distilling = exists(distill_token) 19 | x = self.patch_embedding(img) 20 | b, n, d = x.shape 21 | 22 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b) 23 | x = tf.concat([cls_tokens, x], axis=1) 24 | x += self.pos_embedding[:, :(n + 1)] 25 | 26 | if distilling: 27 | distill_tokens = repeat(distill_token, '() n d -> b n d', b = b) 28 | x = tf.concat([x, distill_tokens], axis=1) 29 | 30 | x = self._attend(x, training=training) 31 | 32 | if distilling: 33 | x, distill_tokens = x[:, :-1], x[:, -1] 34 | 35 | if self.pool == 'mean': 36 | x = tf.reduce_mean(x, axis=1) 37 | else: 38 | x = x[:, 0] 39 | 40 | x = self.mlp_head(x) 41 | 42 | if distilling: 43 | return x, distill_tokens 44 | else: 45 | return x 46 | 47 | class DistillableViT(DistillMixin, 
ViT): 48 | def __init__(self, *args, **kwargs): 49 | super(DistillableViT, self).__init__(*args, **kwargs) 50 | self.args = args 51 | self.kwargs = kwargs 52 | self.dim = kwargs['dim'] 53 | self.num_classes = kwargs['num_classes'] 54 | 55 | def _attend(self, x, training=True): 56 | x = self.dropout(x, training=training) 57 | x = self.transformer(x, training=training) 58 | return x 59 | 60 | 61 | class DistillableT2TViT(DistillMixin, T2TViT): 62 | def __init__(self, *args, **kwargs): 63 | super(DistillableT2TViT, self).__init__(*args, **kwargs) 64 | self.args = args 65 | self.kwargs = kwargs 66 | self.dim = kwargs['dim'] 67 | self.num_classes = kwargs['num_classes'] 68 | 69 | def _attend(self, x, training=True): 70 | x = self.dropout(x, training=training) 71 | x = self.transformer(x, training=training) 72 | return x 73 | 74 | class DistillableEfficientViT(DistillMixin, EfficientViT): 75 | def __init__(self, *args, **kwargs): 76 | super(DistillableEfficientViT, self).__init__(*args, **kwargs) 77 | self.args = args 78 | self.kwargs = kwargs 79 | self.dim = kwargs['dim'] 80 | self.num_classes = kwargs['num_classes'] 81 | 82 | def _attend(self, x, training=True): 83 | x = self.dropout(x, training=training) 84 | x = self.transformer(x, training=training) 85 | return x 86 | 87 | class DistillWrapper(Model): 88 | def __init__(self, teacher, student, temperature=1.0, alpha=0.5, hard=False): 89 | super(DistillWrapper, self).__init__() 90 | 91 | assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))), 'student must be a vision transformer' 92 | 93 | self.teacher = teacher 94 | self.student = student 95 | 96 | dim = student.dim 97 | num_classes = student.num_classes 98 | self.temperature = temperature 99 | self.alpha = alpha 100 | self.hard = hard 101 | self.distillation_token = tf.Variable(tf.random.normal([1, 1, dim])) 102 | 103 | self.distill_mlp = Sequential([ 104 | nn.LayerNormalization(), 105 | nn.Dense(units=num_classes) 106 | ], name='distill_mlp') 107 | 108 | def call(self, inputs, temperature=None, alpha=None, training=True, **kwargs): 109 | img, labels = inputs 110 | b, *_ = img.shape 111 | alpha = alpha if exists(alpha) else self.alpha 112 | T = temperature if exists(temperature) else self.temperature 113 | 114 | teacher_logits = tf.stop_gradient(self.teacher(img, training=training)) 115 | 116 | student_logits, distill_tokens = self.student(img, distill_token=self.distillation_token, training=training) 117 | distill_logits = self.distill_mlp(distill_tokens) 118 | 119 | loss = tf.keras.losses.categorical_crossentropy(y_true=labels, y_pred=student_logits, from_logits=True) 120 | 121 | if not self.hard: 122 | x = tf.nn.log_softmax(distill_logits / T, axis=-1) 123 | y = tf.nn.softmax(teacher_logits / T, axis=-1) 124 | distill_loss = tf.keras.losses.KLDivergence(reduction=tf.keras.losses.Reduction.NONE)(y_true=y, y_pred=x) 125 | 126 | batch = distill_loss.shape[0] 127 | distill_loss = tf.reduce_sum(distill_loss) / batch 128 | 129 | distill_loss *= T ** 2 130 | else: 131 | teacher_labels = tf.argmax(teacher_logits, axis=-1) 132 | distill_loss = tf.keras.losses.categorical_crossentropy(y_true=teacher_labels, y_pred=distill_logits, from_logits=True) 133 | 134 | return loss * (1 - alpha) + distill_loss * alpha 135 | 136 | 137 | """ Usage 138 | teacher = tf.keras.applications.resnet50.ResNet50() 139 | 140 | v = DistillableViT( 141 | image_size = 256, 142 | patch_size = 32, 143 | num_classes = 1000, 144 | dim = 1024, 145 | depth = 6, 146 | heads = 8, 147 | mlp_dim = 
2048, 148 | dropout = 0.1, 149 | emb_dropout = 0.1 150 | ) 151 | 152 | distiller = DistillWrapper( 153 | student = v, 154 | teacher = teacher, 155 | temperature = 3, # temperature of distillation 156 | alpha = 0.5, # trade between main loss and distillation loss 157 | hard = False # whether to use soft or hard distillation 158 | ) 159 | 160 | img = tf.random.normal([2, 256, 256, 3]) 161 | labels = tf.random.uniform(shape=[2, ], minval=0, maxval=1000, dtype=tf.int32) 162 | labels = tf.one_hot(labels, depth=1000, axis=-1) 163 | 164 | loss = distiller([img, labels]) 165 | """ 166 | 167 | -------------------------------------------------------------------------------- /vit_tensorflow/efficient.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | 7 | from einops import rearrange, repeat 8 | from einops.layers.tensorflow import Rearrange 9 | 10 | def pair(t): 11 | return t if isinstance(t, tuple) else (t, t) 12 | 13 | class ViT(Model): 14 | def __init__(self, image_size, patch_size, num_classes, dim, transformer, pool='cls'): 15 | super(ViT, self).__init__() 16 | 17 | image_size_h, image_size_w = pair(image_size) 18 | assert image_size_h % patch_size == 0 and image_size_w % patch_size == 0, 'image dimensions must be divisible by the patch size' 19 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 20 | num_patches = (image_size_h // patch_size) * (image_size_w // patch_size) 21 | 22 | self.patch_embedding = Sequential([ 23 | Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size), 24 | nn.Dense(units=dim) 25 | ], name='patch_embedding') 26 | 27 | self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim])) 28 | self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim])) 29 | 30 | self.transformer = transformer 31 | 32 | self.pool = pool 33 | 34 | self.mlp_head = Sequential([ 35 | nn.LayerNormalization(), 36 | nn.Dense(units=num_classes) 37 | ], name='mlp_head') 38 | 39 | def call(self, img, training=True, **kwargs): 40 | x = self.patch_embedding(img) 41 | b, n, d = x.shape 42 | 43 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b) 44 | x = tf.concat([cls_tokens, x], axis=1) 45 | x += self.pos_embedding[:, :(n + 1)] 46 | x = self.transformer(x, training=training) 47 | 48 | if self.pool == 'mean': 49 | x = tf.reduce_mean(x, axis=1) 50 | else: 51 | x = x[:, 0] 52 | 53 | x = self.mlp_head(x) 54 | 55 | return x -------------------------------------------------------------------------------- /vit_tensorflow/levit.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | 7 | from tensorflow import einsum 8 | from einops import rearrange 9 | 10 | from math import ceil 11 | 12 | def exists(val): 13 | return val is not None 14 | 15 | def default(val, d): 16 | return val if exists(val) else d 17 | 18 | def cast_tuple(val, l = 3): 19 | val = val if isinstance(val, tuple) else (val,) 20 | return (*val, *((val[-1],) * max(l - len(val), 0))) 21 | 22 | def always(val): 23 | return lambda *args, **kwargs: val 24 | 25 | def 
gelu(x, approximate=False): 26 | if approximate: 27 | coeff = tf.cast(0.044715, x.dtype) 28 | return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3)))) 29 | else: 30 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype))) 31 | 32 | class HardSwish(Layer): 33 | def __init__(self): 34 | super(HardSwish, self).__init__() 35 | 36 | def call(self, x, training=True): 37 | x = x * tf.nn.relu6(x + 3.0) / 6.0 38 | return x 39 | 40 | class GELU(Layer): 41 | def __init__(self, approximate=False): 42 | super(GELU, self).__init__() 43 | self.approximate = approximate 44 | 45 | def call(self, x, training=True): 46 | return gelu(x, self.approximate) 47 | 48 | class MLP(Layer): 49 | def __init__(self, dim, mult, dropout=0.0): 50 | super(MLP, self).__init__() 51 | 52 | self.net = [ 53 | nn.Conv2D(filters=dim * mult, kernel_size=1, strides=1), 54 | HardSwish(), 55 | nn.Dropout(rate=dropout), 56 | nn.Conv2D(filters=dim, kernel_size=1, strides=1), 57 | nn.Dropout(rate=dropout) 58 | ] 59 | self.net = Sequential(self.net) 60 | 61 | def call(self, x, training=True): 62 | return self.net(x, training=training) 63 | 64 | class Attention(Layer): 65 | def __init__(self, dim, fmap_size, heads=8, dim_key=32, dim_value=64, dropout=0.0, dim_out=None, downsample=False): 66 | super(Attention, self).__init__() 67 | inner_dim_key = dim_key * heads 68 | inner_dim_value = dim_value * heads 69 | dim_out = default(dim_out, dim) 70 | 71 | self.heads = heads 72 | self.scale = dim_key ** -0.5 73 | 74 | self.to_q = Sequential([ 75 | nn.Conv2D(filters=inner_dim_key, kernel_size=1, strides=(2 if downsample else 1), use_bias=False), 76 | nn.BatchNormalization(momentum=0.9, epsilon=1e-05), 77 | ]) 78 | 79 | self.to_k = Sequential([ 80 | nn.Conv2D(filters=inner_dim_key, kernel_size=1, strides=1, use_bias=False), 81 | nn.BatchNormalization(momentum=0.9, epsilon=1e-05), 82 | ]) 83 | 84 | self.to_v = Sequential([ 85 | nn.Conv2D(filters=inner_dim_value, kernel_size=1, strides=1, use_bias=False), 86 | nn.BatchNormalization(momentum=0.9, epsilon=1e-05), 87 | ]) 88 | 89 | self.attend = nn.Softmax() 90 | 91 | out_batch_norm = nn.BatchNormalization(momentum=0.9, epsilon=1e-05, gamma_initializer='zeros') 92 | 93 | self.to_out = Sequential([ 94 | GELU(), 95 | nn.Conv2D(filters=dim_out, kernel_size=1, strides=1), 96 | out_batch_norm, 97 | nn.Dropout(rate=dropout) 98 | ]) 99 | 100 | # positional bias 101 | self.pos_bias = nn.Embedding(input_dim=fmap_size * fmap_size, output_dim=heads) 102 | q_range = tf.range(0, fmap_size, delta=(2 if downsample else 1)) 103 | k_range = tf.range(fmap_size) 104 | 105 | q_pos = tf.stack(tf.meshgrid(q_range, q_range, indexing='ij'), axis=-1) 106 | k_pos = tf.stack(tf.meshgrid(k_range, k_range, indexing='ij'), axis=-1) 107 | 108 | q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos)) 109 | rel_pos = tf.abs((q_pos[:, None, ...] - k_pos[None, :, ...])) 110 | 111 | x_rel, y_rel = tf.unstack(rel_pos, axis=-1) 112 | self.pos_indices = (x_rel * fmap_size) + y_rel 113 | 114 | def apply_pos_bias(self, fmap): 115 | bias = self.pos_bias(self.pos_indices) 116 | bias = rearrange(bias, 'i j h -> () h i j') 117 | return fmap + (bias / self.scale) 118 | 119 | def call(self, x, training=True): 120 | b, height, width, n = x.shape 121 | q = self.to_q(x) 122 | 123 | h = self.heads 124 | y = q.shape[1] # height 125 | 126 | qkv = (q, self.to_k(x), self.to_v(x)) 127 | q, k, v = map(lambda t: rearrange(t, 'b ... (h d) -> b h (...) 
d', h=h), qkv) 128 | 129 | # i,j = height*width 130 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 131 | dots = self.apply_pos_bias(dots) 132 | 133 | attn = self.attend(dots) 134 | 135 | x = einsum('b h i j, b h j d -> b h i d', attn, v) 136 | x = rearrange(x, 'b h (x y) d -> b x y (h d)', h=h, y=y) 137 | x = self.to_out(x, training=training) 138 | 139 | return x 140 | 141 | class Transformer(Layer): 142 | def __init__(self, dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult=2, dropout=0.0, dim_out=None, downsample=False): 143 | super(Transformer, self).__init__() 144 | 145 | dim_out = default(dim_out, dim) 146 | self.attn_residual = (not downsample) and dim == dim_out 147 | self.layers = [] 148 | 149 | for _ in range(depth): 150 | self.layers.append([ 151 | Attention(dim, fmap_size=fmap_size, heads=heads, dim_key=dim_key, dim_value=dim_value, 152 | dropout=dropout, downsample=downsample, dim_out=dim_out), 153 | MLP(dim_out, mlp_mult, dropout=dropout) 154 | ]) 155 | 156 | def call(self, x, training=True): 157 | for attn, mlp in self.layers: 158 | attn_res = (x if self.attn_residual else 0) 159 | x = attn(x, training=training) + attn_res 160 | x = mlp(x, training=training) + x 161 | 162 | return x 163 | 164 | class LeViT(Model): 165 | def __init__(self, 166 | image_size, 167 | num_classes, 168 | dim, 169 | depth, 170 | heads, 171 | mlp_mult, 172 | stages=3, 173 | dim_key=32, 174 | dim_value=64, 175 | dropout=0.0, 176 | num_distill_classes=None 177 | ): 178 | super(LeViT, self).__init__() 179 | 180 | dims = cast_tuple(dim, stages) 181 | depths = cast_tuple(depth, stages) 182 | layer_heads = cast_tuple(heads, stages) 183 | 184 | assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), \ 185 | 'dimensions, depths, and heads must be a tuple that is less than the designated number of stages' 186 | 187 | self.conv_embedding = Sequential([ 188 | nn.Conv2D(filters=32, kernel_size=3, strides=2, padding='SAME'), 189 | nn.Conv2D(filters=64, kernel_size=3, strides=2, padding='SAME'), 190 | nn.Conv2D(filters=128, kernel_size=3, strides=2, padding='SAME'), 191 | nn.Conv2D(filters=dims[0], kernel_size=3, strides=2, padding='SAME') 192 | ]) 193 | 194 | fmap_size = image_size // (2 ** 4) 195 | self.backbone = Sequential() 196 | 197 | for ind, dim, depth, heads in zip(range(stages), dims, depths, layer_heads): 198 | is_last = ind == (stages - 1) 199 | self.backbone.add(Transformer(dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult, dropout)) 200 | 201 | if not is_last: 202 | next_dim = dims[ind + 1] 203 | self.backbone.add(Transformer(dim, fmap_size, 1, heads * 2, dim_key, dim_value, dim_out=next_dim, downsample=True)) 204 | fmap_size = ceil(fmap_size / 2) 205 | 206 | self.pool = Sequential([ 207 | nn.GlobalAvgPool2D() 208 | ]) 209 | 210 | self.distill_head = nn.Dense(units=num_distill_classes) if exists(num_distill_classes) else always(None) 211 | self.mlp_head = nn.Dense(units=num_classes) 212 | 213 | 214 | def call(self, img, training=True, **kwargs): 215 | x = self.conv_embedding(img) 216 | 217 | x = self.backbone(x) 218 | 219 | x = self.pool(x) 220 | out = self.mlp_head(x) 221 | distill = self.distill_head(x) 222 | 223 | if exists(distill): 224 | return out, distill 225 | 226 | return out 227 | 228 | # """ Usage 229 | levit = LeViT( 230 | image_size = 224, 231 | num_classes = 1000, 232 | stages = 3, # number of stages 233 | dim = (256, 384, 512), # dimensions at each stage 234 | depth = 4, # transformer of depth 4 at each stage 235 | heads = (4, 6, 8), 
# heads at each stage 236 | mlp_mult = 2, 237 | dropout = 0.1 238 | ) 239 | 240 | img = tf.random.normal(shape=[1, 224, 224, 3]) 241 | preds = levit(img) # (1, 1000) 242 | # """ 243 | -------------------------------------------------------------------------------- /vit_tensorflow/mae.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | 7 | from einops import repeat 8 | from vit import Transformer, ViT 9 | 10 | class Identity(Layer): 11 | def __init__(self): 12 | super(Identity, self).__init__() 13 | 14 | def call(self, x, training=True): 15 | return tf.identity(x) 16 | 17 | class MAE(Model): 18 | def __init__(self, 19 | image_size, 20 | encoder, 21 | decoder_dim, 22 | masking_ratio=0.75, 23 | decoder_depth=1, 24 | decoder_heads=8, 25 | decoder_dim_head=64 26 | ): 27 | super(MAE, self).__init__() 28 | assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1' 29 | self.masking_ratio = masking_ratio 30 | 31 | # build 32 | encoder.build(input_shape=(1, image_size, image_size, 3)) 33 | 34 | # extract some hyperparameters and functions from encoder (vision transformer to be trained) 35 | self.encoder = encoder 36 | num_patches, encoder_dim = encoder.pos_embedding.shape[-2:] 37 | self.to_patch, self.patch_to_emb = encoder.patch_embedding.layers[:2] 38 | pixel_values_per_patch = self.patch_to_emb.weights[0].shape[0] 39 | 40 | # decoder parameters 41 | self.enc_to_dec = nn.Dense(units=decoder_dim) if encoder_dim != decoder_dim else Identity() 42 | self.mask_token = tf.Variable(tf.random.normal([decoder_dim])) 43 | self.decoder = Transformer(dim=decoder_dim, depth=decoder_depth, heads=decoder_heads, dim_head=decoder_dim_head, mlp_dim=decoder_dim * 4) 44 | self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim) 45 | self.to_pixels = nn.Dense(units=pixel_values_per_patch) 46 | 47 | def call(self, img, training=True, **kwargs): 48 | # get patches 49 | patches = self.to_patch(img, training=training) 50 | batch, num_patches, *_ = patches.shape 51 | 52 | # patch to encoder tokens and add positions 53 | tokens = self.patch_to_emb(patches, training=training) 54 | tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)] 55 | 56 | # calculate of patches needed to be masked, and get random indices, dividing it up for mask vs unmasked 57 | num_masked = int(self.masking_ratio * num_patches) 58 | rand_indices = tf.argsort(tf.random.uniform([batch, num_patches]), axis=-1) 59 | masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:] 60 | 61 | # get the unmasked tokens to be encoded 62 | batch_range = tf.range(batch)[:, None] 63 | tokens = tokens.numpy()[batch_range, unmasked_indices] 64 | 65 | # get the patches to be masked for the final reconstruction loss 66 | masked_patches = patches.numpy()[batch_range, masked_indices] 67 | 68 | # attend with vision transformer 69 | encoded_tokens = self.encoder.transformer(tokens, training=training) 70 | 71 | # project encoder to decoder dimensions, if they are not equal - the paper says you can get away with a smaller dimension for decoder 72 | decoder_tokens = self.enc_to_dec(encoded_tokens, training=training) 73 | 74 | # reapply decoder position embedding to unmasked tokens 75 | decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices, 
training=training) 76 | 77 | # repeat mask tokens for number of masked, and add the positions using the masked indices derived above 78 | mask_tokens = repeat(self.mask_token, 'd -> b n d', b=batch, n=num_masked) 79 | mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices, training=training) 80 | 81 | # concat the masked tokens to the decoder tokens and attend with decoder 82 | decoder_tokens = tf.concat([mask_tokens, decoder_tokens], axis=1) 83 | decoded_tokens = self.decoder(decoder_tokens, training=training) 84 | 85 | # splice out the mask tokens and project to pixel values 86 | mask_tokens = decoded_tokens[:, :num_masked] 87 | pred_pixel_values = self.to_pixels(mask_tokens, training=training) 88 | 89 | # calculate reconstruction loss 90 | recon_loss = tf.reduce_mean(tf.square(pred_pixel_values, masked_patches)) 91 | 92 | return recon_loss 93 | 94 | v = ViT( 95 | image_size = 256, 96 | patch_size = 32, 97 | num_classes = 1000, 98 | dim = 1024, 99 | depth = 6, 100 | heads = 8, 101 | mlp_dim = 2048 102 | ) 103 | 104 | mae = MAE( 105 | image_size = 256, 106 | encoder = v, 107 | masking_ratio = 0.75, # the paper recommended 75% masked patches 108 | decoder_dim = 512, # paper showed good results with just 512 109 | decoder_depth = 6 # anywhere from 1 to 8 110 | ) 111 | 112 | img = tf.random.normal(shape=[8, 256, 256, 3]) 113 | loss = mae(img) 114 | print(loss) -------------------------------------------------------------------------------- /vit_tensorflow/mobile_vit.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | 7 | from einops import rearrange 8 | from einops.layers.tensorflow import Reduce 9 | 10 | 11 | def gelu(x, approximate=False): 12 | if approximate: 13 | coeff = tf.cast(0.044715, x.dtype) 14 | return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3)))) 15 | else: 16 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype))) 17 | 18 | 19 | class GELU(Layer): 20 | def __init__(self, approximate=False): 21 | super(GELU, self).__init__() 22 | self.approximate = approximate 23 | 24 | def call(self, x, training=True): 25 | return gelu(x, self.approximate) 26 | 27 | 28 | class Swish(Layer): 29 | def __init__(self): 30 | super(Swish, self).__init__() 31 | 32 | def call(self, x, training=True): 33 | x = tf.keras.activations.swish(x) 34 | return x 35 | 36 | 37 | class Conv_NxN_BN(Layer): 38 | def __init__(self, dim, kernel_size=1, stride=1): 39 | super(Conv_NxN_BN, self).__init__() 40 | 41 | self.layers = Sequential([ 42 | nn.Conv2D(filters=dim, kernel_size=kernel_size, strides=stride, padding='SAME', use_bias=False), 43 | nn.BatchNormalization(momentum=0.9, epsilon=1e-5), 44 | Swish() 45 | ]) 46 | 47 | def call(self, x, training=True): 48 | x = self.layers(x, training=training) 49 | return x 50 | 51 | 52 | class PreNorm(Layer): 53 | def __init__(self, fn): 54 | super(PreNorm, self).__init__() 55 | 56 | self.norm = nn.LayerNormalization() 57 | self.fn = fn 58 | 59 | def call(self, x, training=True): 60 | return self.fn(self.norm(x), training=training) 61 | 62 | 63 | class MLP(Layer): 64 | def __init__(self, dim, hidden_dim, dropout=0.0): 65 | super(MLP, self).__init__() 66 | 67 | self.net = Sequential([ 68 | nn.Dense(units=hidden_dim), 69 | Swish(), 70 | nn.Dropout(rate=dropout), 71 | nn.Dense(units=dim), 
72 | nn.Dropout(rate=dropout) 73 | ]) 74 | 75 | def call(self, x, training=True): 76 | return self.net(x, training=training) 77 | 78 | 79 | class Attention(Layer): 80 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): 81 | super(Attention, self).__init__() 82 | 83 | inner_dim = dim_head * heads 84 | self.heads = heads 85 | self.scale = dim_head ** -0.5 86 | 87 | self.attend = nn.Softmax() 88 | self.to_qkv = nn.Dense(units=inner_dim * 3, use_bias=False) 89 | 90 | self.to_out = Sequential([ 91 | nn.Dense(units=dim), 92 | nn.Dropout(rate=dropout) 93 | ]) 94 | 95 | def call(self, x, training=True): 96 | qkv = self.to_qkv(x) 97 | qkv = tf.split(qkv, num_or_size_splits=3, axis=-1) 98 | 99 | q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv) 100 | 101 | dots = tf.matmul(q, tf.transpose(k, perm=[0, 1, 3, 2])) * self.scale 102 | attn = self.attend(dots) 103 | out = tf.matmul(attn, v) 104 | out = rearrange(out, 'b p h n d -> b p n (h d)') 105 | out = self.to_out(out, training=training) 106 | 107 | return out 108 | 109 | 110 | class Transformer(Layer): 111 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0): 112 | super(Transformer, self).__init__() 113 | 114 | self.layers = [] 115 | 116 | for _ in range(depth): 117 | self.layers.append([ 118 | PreNorm(Attention(dim, heads, dim_head, dropout)), 119 | PreNorm(MLP(dim, mlp_dim, dropout)) 120 | ]) 121 | 122 | def call(self, x, training=True): 123 | for attn, ff in self.layers: 124 | x = attn(x, training=training) + x 125 | x = ff(x, training=training) + x 126 | 127 | return x 128 | 129 | 130 | class MV2Block(Layer): 131 | def __init__(self, dim_in, dim_out, stride=1, expansion=4): 132 | super(MV2Block, self).__init__() 133 | 134 | assert stride in [1, 2] 135 | 136 | hidden_dim = int(dim_in * expansion) 137 | self.use_res_connect = stride == 1 and dim_in == dim_out 138 | 139 | if expansion == 1: 140 | self.conv = Sequential([ 141 | # dw 142 | nn.Conv2D(filters=hidden_dim, kernel_size=3, strides=stride, padding='SAME', groups=hidden_dim, 143 | use_bias=False), 144 | nn.BatchNormalization(momentum=0.9, epsilon=1e-5), 145 | Swish(), 146 | # pw-linear 147 | nn.Conv2D(filters=dim_out, kernel_size=1, strides=1, use_bias=False), 148 | nn.BatchNormalization(momentum=0.9, epsilon=1e-5) 149 | ]) 150 | else: 151 | self.conv = Sequential([ 152 | # pw 153 | nn.Conv2D(filters=hidden_dim, kernel_size=1, strides=1, use_bias=False), 154 | nn.BatchNormalization(momentum=0.9, epsilon=1e-5), 155 | Swish(), 156 | # dw 157 | nn.Conv2D(filters=hidden_dim, kernel_size=3, strides=stride, padding='SAME', groups=hidden_dim, 158 | use_bias=False), 159 | nn.BatchNormalization(momentum=0.9, epsilon=1e-5), 160 | Swish(), 161 | # pw-linear 162 | nn.Conv2D(filters=dim_out, kernel_size=1, strides=1, use_bias=False), 163 | nn.BatchNormalization(momentum=0.9, epsilon=1e-5) 164 | ]) 165 | 166 | def call(self, x, training=True): 167 | out = self.conv(x, training=training) 168 | if self.use_res_connect: 169 | out = out + x 170 | return out 171 | 172 | 173 | class MobileViTBlock(Layer): 174 | def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.0): 175 | super(MobileViTBlock, self).__init__() 176 | 177 | self.ph, self.pw = patch_size 178 | 179 | self.conv1 = Conv_NxN_BN(channel, kernel_size=kernel_size, stride=1) 180 | self.conv2 = Conv_NxN_BN(dim, kernel_size=1, stride=1) 181 | 182 | self.transformer = Transformer(dim=dim, depth=depth, heads=4, dim_head=8, mlp_dim=mlp_dim, dropout=dropout) 183 | 
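# fusion path used in call() below: conv3 (1x1) projects the transformer output from `dim`
# back to `channel`, the result is concatenated with the residual copy of the block input
# along the channel axis, and conv4 (kernel_size x kernel_size) merges the doubled channels
# back down to `channel`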
184 | self.conv3 = Conv_NxN_BN(channel, kernel_size=1, stride=1) 185 | self.conv4 = Conv_NxN_BN(channel, kernel_size=kernel_size, stride=1) 186 | 187 | def call(self, x, training=True): 188 | y = tf.identity(x) 189 | 190 | # Local representations 191 | x = self.conv1(x, training=training) 192 | x = self.conv2(x, training=training) 193 | 194 | # Global representations 195 | _, h, w, c = x.shape 196 | x = rearrange(x, 'b (h ph) (w pw) d -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw) 197 | x = self.transformer(x, training=training) 198 | x = rearrange(x, 'b (ph pw) (h w) d -> b (h ph) (w pw) d', h=h // self.ph, w=w // self.pw, ph=self.ph, 199 | pw=self.pw) 200 | 201 | # Fusion 202 | x = self.conv3(x, training=training) 203 | x = tf.concat([x, y], axis=-1) 204 | x = self.conv4(x, training=training) 205 | 206 | return x 207 | 208 | 209 | class MobileViT(Model): 210 | def __init__(self, 211 | image_size, 212 | dims, 213 | channels, 214 | num_classes, 215 | expansion=4, 216 | kernel_size=3, 217 | patch_size=(2, 2), 218 | depths=(2, 4, 3) 219 | ): 220 | super(MobileViT, self).__init__() 221 | assert len(dims) == 3, 'dims must be a tuple of 3' 222 | assert len(depths) == 3, 'depths must be a tuple of 3' 223 | 224 | ih, iw = image_size 225 | ph, pw = patch_size 226 | assert ih % ph == 0 and iw % pw == 0 227 | 228 | init_dim, *_, last_dim = channels 229 | 230 | self.conv1 = Conv_NxN_BN(init_dim, kernel_size=3, stride=2) 231 | 232 | self.stem = Sequential() 233 | self.stem.add(MV2Block(channels[0], channels[1], stride=1, expansion=expansion)) 234 | self.stem.add(MV2Block(channels[1], channels[2], stride=2, expansion=expansion)) 235 | self.stem.add(MV2Block(channels[2], channels[3], stride=1, expansion=expansion)) 236 | self.stem.add(MV2Block(channels[2], channels[3], stride=1, expansion=expansion)) 237 | 238 | self.trunk = [] 239 | self.trunk.append([ 240 | MV2Block(channels[3], channels[4], stride=2, expansion=expansion), 241 | MobileViTBlock(dims[0], depths[0], channels[5], kernel_size, patch_size, mlp_dim=int(dims[0] * 2)) 242 | ]) 243 | 244 | self.trunk.append([ 245 | MV2Block(channels[5], channels[6], stride=2, expansion=expansion), 246 | MobileViTBlock(dims[1], depths[1], channels[7], kernel_size, patch_size, mlp_dim=int(dims[1] * 4)) 247 | ]) 248 | 249 | self.trunk.append([ 250 | MV2Block(channels[7], channels[8], stride=2, expansion=expansion), 251 | MobileViTBlock(dims[2], depths[2], channels[9], kernel_size, patch_size, mlp_dim=int(dims[2] * 4)) 252 | ]) 253 | 254 | self.to_logits = Sequential([ 255 | Conv_NxN_BN(last_dim, kernel_size=1, stride=1), 256 | Reduce('b h w c -> b c', 'mean'), 257 | nn.Dense(units=num_classes, use_bias=False) 258 | ]) 259 | 260 | def call(self, x, training=True, **kwargs): 261 | x = self.conv1(x, training=training) 262 | 263 | x = self.stem(x, training=training) 264 | 265 | for conv, attn in self.trunk: 266 | x = conv(x, training=training) 267 | x = attn(x, training=training) 268 | 269 | x = self.to_logits(x, training=training) 270 | 271 | return x 272 | 273 | """ Usage 274 | v = MobileViT( 275 | image_size=(256, 256), 276 | dims=[96, 120, 144], 277 | channels=[16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384], 278 | num_classes=1000 279 | ) 280 | 281 | img = tf.random.normal(shape=[1, 256, 256, 3]) 282 | preds = v(img) # (1, 1000) 283 | """ -------------------------------------------------------------------------------- /vit_tensorflow/mpp.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from 
tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | 7 | from einops import rearrange, repeat, reduce 8 | import numpy as np 9 | import math 10 | from vit import ViT 11 | 12 | def scatter_numpy(x, dim, index, src): 13 | """ 14 | Writes all values from the Tensor src into x at the indices specified in the index Tensor. 15 | 16 | :param dim: The axis along which to index 17 | :param index: The indices of elements to scatter 18 | :param src: The source element(s) to scatter 19 | :return: x 20 | """ 21 | 22 | if x.ndim != index.ndim: 23 | raise ValueError("Index should have the same number of dimensions as output") 24 | if dim >= x.ndim or dim < -x.ndim: 25 | raise IndexError("dim is out of range") 26 | if dim < 0: 27 | # Not sure why scatter should accept dim < 0, but that is the behavior in PyTorch's scatter 28 | dim = x.ndim + dim 29 | idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1:] 30 | self_xsection_shape = x.shape[:dim] + x.shape[dim + 1:] 31 | if idx_xsection_shape != self_xsection_shape: 32 | raise ValueError("Except for dimension " + str(dim) + 33 | ", all dimensions of index and output should be the same size") 34 | if (index >= x.shape[dim]).any() or (index < 0).any(): 35 | raise IndexError("The values of index must be between 0 and (self.shape[dim] -1)") 36 | 37 | def make_slice(arr, dim, i): 38 | slc = [slice(None)] * arr.ndim 39 | slc[dim] = i 40 | slc = tuple(slc) 41 | return slc 42 | 43 | # We use index and dim parameters to create idx 44 | # idx is in a form that can be used as a NumPy advanced index for scattering of src param. in self 45 | idx = [[*np.indices(idx_xsection_shape).reshape(index.ndim - 1, -1), 46 | index[make_slice(index, dim, i)].reshape(1, -1)[0]] for i in range(index.shape[dim])] 47 | idx = list(np.concatenate(idx, axis=1)) 48 | idx.insert(dim, idx.pop()) 49 | 50 | if not np.isscalar(src): 51 | if index.shape[dim] > src.shape[dim]: 52 | raise IndexError("Dimension " + str(dim) + "of index can not be bigger than that of src ") 53 | src_xsection_shape = src.shape[:dim] + src.shape[dim + 1:] 54 | if idx_xsection_shape != src_xsection_shape: 55 | raise ValueError("Except for dimension " + 56 | str(dim) + ", all dimensions of index and src should be the same size") 57 | # src_idx is a NumPy advanced index for indexing of elements in the src 58 | src_idx = list(idx) 59 | src_idx.pop(dim) 60 | src_idx.insert(dim, np.repeat(np.arange(index.shape[dim]), np.prod(idx_xsection_shape))) 61 | idx = tuple(idx) 62 | x[idx] = src[src_idx] 63 | 64 | else: 65 | idx = tuple(idx) 66 | x[idx] = src 67 | 68 | return x 69 | 70 | def exists(val): 71 | return val is not None 72 | 73 | def prob_mask_like(t, prob): 74 | batch, seq_length, _ = t.shape 75 | x = tf.random.uniform([batch, seq_length], dtype=tf.float32) < prob 76 | return x 77 | 78 | def get_mask_subset_with_prob(patched_input, prob): 79 | batch, seq_len, _ = patched_input.shape 80 | max_masked = math.ceil(prob * seq_len) 81 | 82 | rand = tf.random.uniform([batch, seq_len]) 83 | _, sampled_indices = tf.math.top_k(rand, k=max_masked) 84 | 85 | new_mask = tf.zeros([batch, seq_len]) 86 | new_mask = scatter_numpy(new_mask.numpy(), 1, sampled_indices.numpy(), 1) 87 | new_mask = tf.cast(new_mask, tf.bool) 88 | return new_mask 89 | 90 | class MPPLoss(Layer): 91 | def __init__(self, 92 | patch_size, 93 | channels, 94 | output_channel_bits, 95 | max_pixel_val, 96 | mean, 97 | std 98 | ): 99 | super(MPPLoss, 
self).__init__() 100 | self.patch_size = patch_size 101 | self.channels = channels 102 | self.output_channel_bits = output_channel_bits 103 | self.max_pixel_val = max_pixel_val 104 | 105 | self.mean = tf.reshape(tf.convert_to_tensor(mean, dtype=tf.float32), [-1, 1, 1]) if mean else None 106 | self.std = tf.reshape(tf.convert_to_tensor(std, dtype=tf.float32), [-1, 1, 1]) if std else None 107 | 108 | def call(self, predicted_patches, target=None, mask=None, training=True): 109 | p, c, mpv, bits = self.patch_size, self.channels, self.max_pixel_val, self.output_channel_bits 110 | bin_size = mpv / (2 ** bits) 111 | 112 | # un-normalize input 113 | if exists(self.mean) and exists(self.std): 114 | target = target * self.std + self.mean 115 | 116 | # reshape target to patches 117 | target = tf.clip_by_value(target, clip_value_min=tf.reduce_min(mpv), clip_value_max=mpv) # clamp just in case 118 | avg_target = reduce(target, 'b (h p1) (w p2) c -> b (h w) c', 'mean', p1=p, p2=p) 119 | 120 | channel_bins = np.arange(bin_size, mpv, bin_size) 121 | discretized_target = tf.compat.v1.raw_ops.Bucketize(input=avg_target, boundaries=channel_bins) 122 | 123 | bin_mask = (2 ** bits) ** tf.range(0, c) 124 | bin_mask = rearrange(bin_mask, 'c -> () () c') 125 | 126 | target_label = tf.reduce_sum(bin_mask * discretized_target, axis=-1, keepdims=True) 127 | 128 | loss = tf.nn.softmax_cross_entropy_with_logits(tf.cast(predicted_patches[mask], tf.float32), tf.cast(target_label[mask], tf.float32)) 129 | loss = tf.reduce_mean(loss) 130 | 131 | return loss 132 | 133 | class MPP(Model): 134 | def __init__(self, 135 | image_size, 136 | transformer, 137 | patch_size, 138 | output_channel_bits=3, 139 | channels=3, 140 | max_pixel_val=1.0, 141 | mask_prob=0.15, 142 | replace_prob=0.5, 143 | random_patch_prob=0.5, 144 | mean=None, 145 | std=None 146 | ): 147 | super(MPP, self).__init__() 148 | # build 149 | transformer.build(input_shape=(20, image_size, image_size, 3)) 150 | 151 | self.transformer = transformer 152 | self.loss = MPPLoss(patch_size, channels, output_channel_bits, max_pixel_val, mean, std) 153 | 154 | # output transformation 155 | self.to_bits = nn.Dense(units=2 ** (output_channel_bits * channels)) 156 | 157 | # vit related dimensions 158 | self.patch_size = patch_size 159 | 160 | # mpp related probabilities 161 | self.mask_prob = mask_prob 162 | self.replace_prob = replace_prob 163 | self.random_patch_prob = random_patch_prob 164 | 165 | # token ids 166 | self.mask_token = tf.Variable(tf.random.normal([1, 1, channels * patch_size ** 2])) 167 | 168 | def call(self, inputs, training=True, **kwargs): 169 | 170 | transformer = self.transformer 171 | # clone original image for loss 172 | img = tf.stop_gradient(tf.identity(inputs)) 173 | 174 | # reshape raw image to patches 175 | p = self.patch_size 176 | inputs = rearrange(inputs,'b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=p, p2=p) 177 | 178 | mask = get_mask_subset_with_prob(inputs, self.mask_prob) 179 | 180 | # mask input with mask patches with probability of `replace_prob` (keep patches the same with probability 1 - replace_prob) 181 | masked_input = tf.stop_gradient(tf.identity(inputs)) 182 | 183 | # if random token probability > 0 for mpp 184 | if self.random_patch_prob > 0: 185 | random_patch_sampling_prob = self.random_patch_prob / (1 - self.replace_prob) 186 | random_patch_prob = prob_mask_like(inputs, random_patch_sampling_prob) 187 | 188 | bool_random_patch_prob = tf.cast(tf.cast(mask, tf.float32) * tf.cast((random_patch_prob == True), tf.float32), 
tf.bool).numpy() 189 | random_patches = tf.random.uniform(shape=[inputs.shape[0], inputs.shape[1]], minval=0, maxval=inputs.shape[1], dtype=tf.int32) 190 | 191 | randomized_input = masked_input.numpy()[tf.expand_dims(tf.range(masked_input.shape[0]), axis=-1), random_patches] 192 | masked_input.numpy()[bool_random_patch_prob] = randomized_input[bool_random_patch_prob] 193 | 194 | # [mask] input 195 | replace_prob = prob_mask_like(inputs, self.replace_prob) 196 | bool_mask_replace = tf.cast(((tf.cast(mask, tf.float32) * tf.cast(replace_prob, tf.float32)) == True), tf.int32) 197 | masked_input.numpy()[bool_mask_replace.numpy()] = self.mask_token.numpy() 198 | 199 | # linear embedding of patches 200 | masked_input = transformer.patch_embedding.layers[-1](masked_input, training=training) 201 | 202 | # add cls token to input sequence 203 | b, n, _ = masked_input.shape 204 | cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b) 205 | masked_input = tf.concat([cls_tokens, masked_input], axis=1) 206 | 207 | # add positional embeddings to input 208 | masked_input += transformer.pos_embedding[:, :(n + 1)] 209 | masked_input = transformer.dropout(masked_input, training=training) 210 | 211 | # get generator output and get mpp loss 212 | masked_input = transformer.transformer(masked_input, training=training) 213 | cls_logits = self.to_bits(masked_input) 214 | logits = cls_logits[:, 1:, :] 215 | 216 | mpp_loss = self.loss(logits, img, mask) 217 | 218 | return mpp_loss 219 | 220 | 221 | model = ViT( 222 | image_size=256, 223 | patch_size=32, 224 | num_classes=1000, 225 | dim=1024, 226 | depth=6, 227 | heads=8, 228 | mlp_dim=2048, 229 | dropout=0.1, 230 | emb_dropout=0.1 231 | ) 232 | 233 | mpp_trainer = MPP( 234 | image_size=256, 235 | transformer=model, 236 | patch_size=32, 237 | mask_prob=0.15, # probability of using token in masked prediction task 238 | random_patch_prob=0.30, # probability of randomly replacing a token being used for mpp 239 | replace_prob=0.50, # probability of replacing a token being used for mpp with the mask token 240 | ) 241 | 242 | 243 | """ Usage 244 | def sample_unlabelled_images(): 245 | return tf.random.normal([20, 256, 256, 3]) 246 | 247 | for _ in range(100): 248 | with tf.GradientTape() as tape: 249 | images = sample_unlabelled_images() 250 | loss = mpp_trainer(images) 251 | """ -------------------------------------------------------------------------------- /vit_tensorflow/nest.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | from tensorflow import einsum 7 | 8 | from einops import rearrange 9 | from einops.layers.tensorflow import Rearrange, Reduce 10 | 11 | def cast_tuple(val, depth): 12 | return val if isinstance(val, tuple) else ((val,) * depth) 13 | 14 | def gelu(x, approximate=False): 15 | if approximate: 16 | coeff = tf.cast(0.044715, x.dtype) 17 | return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3)))) 18 | else: 19 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype))) 20 | 21 | class Identity(Layer): 22 | def __init__(self): 23 | super(Identity, self).__init__() 24 | 25 | def call(self, x, training=True): 26 | return tf.identity(x) 27 | 28 | class LayerNorm(Layer): 29 | def __init__(self, dim, eps=1e-5): 30 | super(LayerNorm, self).__init__() 31 | self.eps = eps 
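# channel-last LayerNorm: call() normalizes over the final (channel) axis of NHWC feature
# maps, then applies the learnable scale g and shift b, each broadcast over the batch and
# spatial dimensions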
32 | 33 | self.g = tf.Variable(tf.ones([1, 1, 1, dim])) 34 | self.b = tf.Variable(tf.zeros([1, 1, 1, dim])) 35 | 36 | def call(self, x, training=True): 37 | var = tf.math.reduce_variance(x, axis=-1, keepdims=True) 38 | mean = tf.reduce_mean(x, axis=-1, keepdims=True) 39 | 40 | x = (x - mean) / tf.sqrt((var + self.eps)) * self.g + self.b 41 | return x 42 | 43 | class PreNorm(Layer): 44 | def __init__(self, dim, fn): 45 | super(PreNorm, self).__init__() 46 | 47 | self.norm = LayerNorm(dim) 48 | self.fn = fn 49 | 50 | def call(self, x, training=True): 51 | return self.fn(self.norm(x), training=training) 52 | 53 | class GELU(Layer): 54 | def __init__(self, approximate=False): 55 | super(GELU, self).__init__() 56 | self.approximate = approximate 57 | 58 | def call(self, x, training=True): 59 | return gelu(x, self.approximate) 60 | 61 | class MLP(Layer): 62 | def __init__(self, dim, mult=4, dropout=0.0): 63 | super(MLP, self).__init__() 64 | 65 | self.net = [ 66 | nn.Conv2D(filters=dim * mult, kernel_size=1, strides=1), 67 | GELU(), 68 | nn.Dropout(rate=dropout), 69 | nn.Conv2D(filters=dim, kernel_size=1, strides=1), 70 | nn.Dropout(rate=dropout) 71 | ] 72 | self.net = Sequential(self.net) 73 | 74 | def call(self, x, training=True): 75 | return self.net(x, training=training) 76 | 77 | class Attention(Layer): 78 | def __init__(self, dim, heads=8, dropout=0.0): 79 | super(Attention, self).__init__() 80 | dim_head = dim // heads 81 | inner_dim = dim_head * heads 82 | self.heads = heads 83 | self.scale = dim_head ** -0.5 84 | 85 | self.attend = nn.Softmax() 86 | self.to_qkv = nn.Conv2D(filters=inner_dim * 3, kernel_size=1, strides=1, use_bias=False) 87 | 88 | self.to_out = Sequential([ 89 | nn.Conv2D(filters=dim, kernel_size=1, strides=1), 90 | nn.Dropout(rate=dropout) 91 | ]) 92 | 93 | def call(self, x, training=True): 94 | b, h, w, c = x.shape 95 | heads = self.heads 96 | 97 | qkv = self.to_qkv(x) 98 | qkv = tf.split(qkv, num_or_size_splits=3, axis=-1) 99 | q, k, v = map(lambda t: rearrange(t, 'b x y (h d) -> b h (x y) d', h=heads), qkv) 100 | 101 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 102 | 103 | attn = self.attend(dots) 104 | 105 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 106 | out = rearrange(out, 'b h (x y) d -> b x y (h d)', x=h, y=w) 107 | out = self.to_out(out, training=training) 108 | 109 | return out 110 | 111 | class Aggregate(Layer): 112 | def __init__(self, dim): 113 | super(Aggregate, self).__init__() 114 | 115 | self.ag_layers = Sequential([ 116 | nn.Conv2D(filters=dim, kernel_size=3, strides=1, padding='SAME'), 117 | LayerNorm(dim), 118 | nn.MaxPool2D(pool_size=3, strides=2, padding='SAME') 119 | ]) 120 | 121 | def call(self, x, training=True): 122 | x = self.ag_layers(x) 123 | return x 124 | 125 | class Transformer(Layer): 126 | def __init__(self, dim, seq_len, depth, heads, mlp_mult, dropout=0.0): 127 | super(Transformer, self).__init__() 128 | self.layers = [] 129 | self.pos_emb = tf.Variable(tf.random.normal([seq_len])) 130 | 131 | for _ in range(depth): 132 | self.layers.append([ 133 | PreNorm(dim, Attention(dim, heads=heads, dropout=dropout)), 134 | PreNorm(dim, MLP(dim, mlp_mult, dropout=dropout)) 135 | ]) 136 | 137 | def call(self, x, training=True): 138 | _, h, w, c = x.shape 139 | 140 | pos_emb = self.pos_emb[:(h * w)] 141 | pos_emb = rearrange(pos_emb, '(h w) -> () h w ()', h = h, w = w) 142 | x = x + pos_emb 143 | 144 | for attn, ff in self.layers: 145 | x = attn(x, training=training) + x 146 | x = ff(x, training=training) + x 147 | 
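# the residual attention / conv-MLP blocks above keep the (batch, height, width, channels)
# layout, so NesT can block-merge and aggregate this output at the next hierarchy level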
148 | return x 149 | 150 | class NesT(Model): 151 | def __init__(self, 152 | image_size, 153 | patch_size, 154 | num_classes, 155 | dim, 156 | heads, 157 | num_hierarchies, 158 | block_repeats, 159 | mlp_mult=4, 160 | dropout=0.0 161 | ): 162 | super(NesT, self).__init__() 163 | assert (image_size % patch_size) == 0, 'Image dimensions must be divisible by the patch size.' 164 | fmap_size = image_size // patch_size 165 | blocks = 2 ** (num_hierarchies - 1) 166 | 167 | seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across heirarchy 168 | hierarchies = list(reversed(range(num_hierarchies))) 169 | mults = [2 ** i for i in reversed(hierarchies)] 170 | 171 | layer_heads = list(map(lambda t: t * heads, mults)) 172 | layer_dims = list(map(lambda t: t * dim, mults)) 173 | last_dim = layer_dims[-1] 174 | 175 | layer_dims = [*layer_dims, layer_dims[-1]] 176 | dim_pairs = zip(layer_dims[:-1], layer_dims[1:]) 177 | 178 | self.patch_embedding = Sequential([ 179 | Rearrange('b (h p1) (w p2) c -> b h w (p1 p2 c) ', p1=patch_size, p2=patch_size), 180 | nn.Conv2D(filters=layer_dims[0], kernel_size=1, strides=1) 181 | ]) 182 | 183 | block_repeats = cast_tuple(block_repeats, num_hierarchies) 184 | 185 | self.nest_layers = [] 186 | 187 | for level, heads, (dim_in, dim_out), block_repeat in zip(hierarchies, layer_heads, dim_pairs, block_repeats): 188 | is_last = level == 0 189 | depth = block_repeat 190 | 191 | self.nest_layers.append([ 192 | Transformer(dim_in, seq_len, depth, heads, mlp_mult, dropout), 193 | Aggregate(dim_out) if not is_last else Identity() 194 | ]) 195 | 196 | self.mlp_head = Sequential([ 197 | LayerNorm(last_dim), 198 | Reduce('b h w c -> b c', 'mean'), 199 | nn.Dense(units=num_classes) 200 | ]) 201 | 202 | def call(self, img, training=True, **kwargs): 203 | x = self.patch_embedding(img) 204 | 205 | num_hierarchies = len(self.nest_layers) 206 | 207 | for level, (transformer, aggregate) in zip(reversed(range(num_hierarchies)), self.nest_layers): 208 | block_size = 2 ** level 209 | x = rearrange(x, 'b (b1 h) (b2 w) c -> (b b1 b2) h w c', b1 = block_size, b2 = block_size) 210 | x = transformer(x, training=training) 211 | x = rearrange(x, '(b b1 b2) h w c -> b (b1 h) (b2 w) c', b1 = block_size, b2 = block_size) 212 | x = aggregate(x) 213 | 214 | x = self.mlp_head(x) 215 | 216 | return x 217 | 218 | """ Usage 219 | v = NesT( 220 | image_size = 224, 221 | patch_size = 4, 222 | dim = 96, 223 | heads = 3, 224 | num_hierarchies = 3, # number of hierarchies 225 | block_repeats = (2, 2, 8), # the number of transformer blocks at each heirarchy, starting from the bottom 226 | num_classes = 1000 227 | ) 228 | 229 | img = tf.random.normal(shape=[1, 224, 224, 3]) 230 | preds = v(img) # (1, 1000) 231 | """ 232 | 233 | 234 | -------------------------------------------------------------------------------- /vit_tensorflow/parallel_vit.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | 7 | from einops import rearrange, repeat 8 | from einops.layers.tensorflow import Rearrange 9 | 10 | def pair(t): 11 | return t if isinstance(t, tuple) else (t, t) 12 | 13 | def gelu(x, approximate=False): 14 | if approximate: 15 | coeff = tf.cast(0.044715, x.dtype) 16 | return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3)))) 17 | else: 18 | return 
0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype))) 19 | 20 | 21 | class GELU(Layer): 22 | def __init__(self, approximate=False): 23 | super(GELU, self).__init__() 24 | self.approximate = approximate 25 | 26 | def call(self, x, training=True): 27 | return gelu(x, self.approximate) 28 | 29 | class Identity(Layer): 30 | def __init__(self): 31 | super(Identity, self).__init__() 32 | 33 | def call(self, x, training=True): 34 | return tf.identity(x) 35 | 36 | class Parallel(Layer): 37 | def __init__(self, fns): 38 | super(Parallel, self).__init__() 39 | self.fns = fns 40 | 41 | def call(self, x, training=True): 42 | return sum([fn(x, training=training) for fn in self.fns]) 43 | 44 | class PreNorm(Layer): 45 | def __init__(self, fn): 46 | super(PreNorm, self).__init__() 47 | 48 | self.norm = nn.LayerNormalization() 49 | self.fn = fn 50 | 51 | def call(self, x, **kwargs): 52 | return self.fn(self.norm(x), **kwargs) 53 | 54 | class MLP(Layer): 55 | def __init__(self, dim, hidden_dim, dropout=0.0): 56 | super(MLP, self).__init__() 57 | self.net = Sequential([ 58 | nn.Dense(units=hidden_dim), 59 | GELU(), 60 | nn.Dropout(rate=dropout), 61 | nn.Dense(units=dim), 62 | nn.Dropout(rate=dropout) 63 | ]) 64 | 65 | def call(self, x, training=True): 66 | return self.net(x, training=training) 67 | 68 | class Attention(Layer): 69 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): 70 | super(Attention, self).__init__() 71 | inner_dim = dim_head * heads 72 | project_out = not (heads == 1 and dim_head == dim) 73 | 74 | self.heads = heads 75 | self.scale = dim_head ** -0.5 76 | 77 | self.attend = nn.Softmax() 78 | self.to_qkv = nn.Dense(units=inner_dim * 3, use_bias=False) 79 | 80 | self.to_out = Sequential([ 81 | nn.Dense(units=dim), 82 | nn.Dropout(rate=dropout) 83 | ]) if project_out else Identity() 84 | 85 | def call(self, x, training=True): 86 | qkv = self.to_qkv(x) 87 | qkv = tf.split(qkv, num_or_size_splits=3, axis=-1) 88 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) 89 | 90 | dots = tf.matmul(q, tf.transpose(k, perm=[0, 1, 3, 2])) * self.scale 91 | attn = self.attend(dots) 92 | 93 | x = tf.matmul(attn, v) 94 | x = rearrange(x, 'b h n d -> b n (h d)') 95 | x = self.to_out(x, training=training) 96 | 97 | return x 98 | 99 | class Transformer(Layer): 100 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_parallel_branches=2, dropout=0.0): 101 | super(Transformer, self).__init__() 102 | self.layers = [] 103 | 104 | # attn_block = lambda: PreNorm(Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)) 105 | # ff_block = lambda: PreNorm(MLP(dim, mlp_dim, dropout=dropout)) 106 | 107 | for _ in range(depth): 108 | self.layers.append([ 109 | Parallel([PreNorm(Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)) for _ in range(num_parallel_branches)]), 110 | Parallel([PreNorm(MLP(dim, mlp_dim, dropout=dropout)) for _ in range(num_parallel_branches)]) 111 | ]) 112 | 113 | def call(self, x, training=True): 114 | for attns, ffs in self.layers: 115 | x = attns(x, training=training) + x 116 | x = ffs(x, training=training) + x 117 | return x 118 | 119 | class ViT(Model): 120 | def __init__(self, 121 | image_size, 122 | patch_size, 123 | num_classes, 124 | dim, 125 | depth, 126 | heads, 127 | mlp_dim, 128 | pool='cls', 129 | num_parallel_branches=2, 130 | dim_head=64, 131 | dropout=0.0, 132 | emb_dropout=0.0 133 | ): 134 | super(ViT, self).__init__() 135 | image_height, image_width = pair(image_size) 136 | patch_height, 
patch_width = pair(patch_size) 137 | 138 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 139 | 140 | num_patches = (image_height // patch_height) * (image_width // patch_width) 141 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 142 | 143 | self.patch_embedding = Sequential([ 144 | Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width), 145 | nn.Dense(units=dim) 146 | ]) 147 | 148 | self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim])) 149 | self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim])) 150 | self.dropout = nn.Dropout(rate=emb_dropout) 151 | 152 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_parallel_branches, dropout) 153 | 154 | self.pool = pool 155 | 156 | self.mlp_head = Sequential([ 157 | nn.LayerNormalization(), 158 | nn.Dense(units=num_classes) 159 | ]) 160 | 161 | def call(self, img, training=True, **kwargs): 162 | x = self.patch_embedding(img) 163 | b, n, d = x.shape 164 | 165 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b) 166 | x = tf.concat([cls_tokens, x], axis=1) 167 | x += self.pos_embedding[:, :(n + 1)] 168 | x = self.dropout(x, training=training) 169 | x = self.transformer(x, training=training) 170 | 171 | if self.pool == 'mean': 172 | x = tf.reduce_mean(x, axis=1) 173 | else: 174 | x = x[:, 0] 175 | 176 | x = self.mlp_head(x) 177 | 178 | return x 179 | 180 | """ Usage 181 | v = ViT( 182 | image_size = 256, 183 | patch_size = 16, 184 | num_classes = 1000, 185 | dim = 1024, 186 | depth = 6, 187 | heads = 8, 188 | mlp_dim = 2048, 189 | num_parallel_branches = 2, # in paper, they claimed 2 was optimal 190 | dropout = 0.1, 191 | emb_dropout = 0.1 192 | ) 193 | 194 | img = tf.random.normal(shape=[4, 256, 256, 3]) 195 | preds = v(img) # (4, 1000) 196 | """ -------------------------------------------------------------------------------- /vit_tensorflow/pit.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | 7 | from tensorflow import einsum 8 | from einops import rearrange, repeat 9 | 10 | from math import sqrt 11 | 12 | 13 | def cast_tuple(val, num): 14 | return val if isinstance(val, tuple) else (val,) * num 15 | 16 | def conv_output_size(image_size, kernel_size, stride, padding = 0): 17 | return int(((image_size - kernel_size + (2 * padding)) / stride) + 1) 18 | 19 | class PreNorm(Layer): 20 | def __init__(self, fn): 21 | super(PreNorm, self).__init__() 22 | 23 | self.norm = nn.LayerNormalization() 24 | self.fn = fn 25 | 26 | def call(self, x, training=True): 27 | return self.fn(self.norm(x), training=training) 28 | 29 | class MLP(Layer): 30 | def __init__(self, dim, hidden_dim, dropout=0.0): 31 | super(MLP, self).__init__() 32 | 33 | def GELU(): 34 | def gelu(x, approximate=False): 35 | if approximate: 36 | coeff = tf.cast(0.044715, x.dtype) 37 | return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3)))) 38 | else: 39 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype))) 40 | 41 | return nn.Activation(gelu) 42 | 43 | self.net = [ 44 | nn.Dense(units=hidden_dim), 45 | GELU(), 46 | nn.Dropout(rate=dropout), 47 | 
nn.Dense(units=dim), 48 | nn.Dropout(rate=dropout) 49 | ] 50 | self.net = Sequential(self.net) 51 | 52 | def call(self, x, training=True): 53 | return self.net(x, training=training) 54 | 55 | class Attention(Layer): 56 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): 57 | super(Attention, self).__init__() 58 | inner_dim = dim_head * heads 59 | project_out = not (heads == 1 and dim_head == dim) 60 | 61 | self.heads = heads 62 | self.scale = dim_head ** -0.5 63 | 64 | self.attend = nn.Softmax() 65 | self.to_qkv = nn.Dense(units=inner_dim * 3, use_bias=False) 66 | 67 | if project_out: 68 | self.to_out = [ 69 | nn.Dense(units=dim), 70 | nn.Dropout(rate=dropout) 71 | ] 72 | else: 73 | self.to_out = [] 74 | 75 | self.to_out = Sequential(self.to_out) 76 | 77 | def call(self, x, training=True): 78 | qkv = self.to_qkv(x) 79 | qkv = tf.split(qkv, num_or_size_splits=3, axis=-1) 80 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) 81 | 82 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 83 | attn = self.attend(dots) 84 | 85 | x = einsum('b h i j, b h j d -> b h i d', attn, v) 86 | x = rearrange(x, 'b h n d -> b n (h d)') 87 | x = self.to_out(x, training=training) 88 | 89 | return x 90 | 91 | class Transformer(Layer): 92 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0): 93 | super(Transformer, self).__init__() 94 | 95 | self.layers = [] 96 | 97 | for _ in range(depth): 98 | self.layers.append([ 99 | PreNorm(Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)), 100 | PreNorm(MLP(dim, mlp_dim, dropout=dropout)) 101 | ]) 102 | 103 | def call(self, x, training=True): 104 | for attn, mlp in self.layers: 105 | x = attn(x, training=training) + x 106 | x = mlp(x, training=training) + x 107 | 108 | return x 109 | 110 | class Unfold(Layer): 111 | def __init__(self, kernel_size, stride): 112 | super(Unfold, self).__init__() 113 | 114 | self.kernel_size = [1, kernel_size, kernel_size, 1] 115 | self.stride = [1, stride, stride, 1] 116 | self.rates = [1, 1, 1, 1] 117 | 118 | def call(self, x, training=True): 119 | x = tf.image.extract_patches(x, sizes=self.kernel_size, strides=self.stride, rates=self.rates, padding='VALID') 120 | x = rearrange(x, 'b h w c -> b (h w) c') 121 | 122 | return x 123 | 124 | # depthwise convolution, for pooling 125 | class DepthWiseConv2d(Layer): 126 | def __init__(self, dim_in, dim_out, kernel_size, stride, bias=True): 127 | super(DepthWiseConv2d, self).__init__() 128 | 129 | net = [] 130 | net += [nn.Conv2D(filters=dim_out, kernel_size=kernel_size, strides=stride, padding='SAME', groups=dim_in, use_bias=bias)] 131 | net += [nn.Conv2D(filters=dim_out, kernel_size=1, strides=1, use_bias=bias)] 132 | 133 | self.net = Sequential(net) 134 | 135 | def call(self, x, training=True): 136 | x = self.net(x) 137 | return x 138 | 139 | # pooling layer 140 | class Pool(Layer): 141 | def __init__(self, dim): 142 | super(Pool, self).__init__() 143 | self.downsample = DepthWiseConv2d(dim, dim*2, kernel_size=3, stride=2) 144 | self.cls_ff = nn.Dense(units=dim*2) 145 | 146 | def call(self, x, training=True): 147 | cls_token, tokens = x[:, :1], x[:, 1:] 148 | cls_token = self.cls_ff(cls_token) 149 | 150 | tokens = rearrange(tokens, 'b (h w) c -> b h w c', h=int(sqrt(tokens.shape[1]))) 151 | tokens = self.downsample(tokens) 152 | tokens = rearrange(tokens, 'b h w c -> b (h w) c') 153 | 154 | x = tf.concat([cls_token, tokens], axis=1) 155 | 156 | return x 157 | 158 | class PiT(Model): 159 | def __init__(self, 
160 | image_size, 161 | patch_size, 162 | num_classes, 163 | dim, 164 | depth, 165 | heads, 166 | mlp_dim, 167 | dim_head=64, 168 | dropout=0.0, 169 | emb_dropout=0.0, 170 | ): 171 | super(PiT, self).__init__() 172 | 173 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 174 | assert isinstance(depth, 175 | tuple), 'depth must be a tuple of integers, specifying the number of blocks before each downsizing' 176 | 177 | heads = cast_tuple(heads, len(depth)) 178 | 179 | self.patch_embedding = Sequential([ 180 | Unfold(kernel_size=patch_size, stride=patch_size // 2), 181 | nn.Dense(units=dim) 182 | ], name='patch_embedding') 183 | 184 | output_size = conv_output_size(image_size, patch_size, patch_size // 2) 185 | num_patches = output_size ** 2 186 | 187 | self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim])) 188 | self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim])) 189 | self.dropout = nn.Dropout(rate=emb_dropout) 190 | 191 | self.transformer_layers = Sequential() 192 | 193 | for ind, (layer_depth, layer_heads) in enumerate(zip(depth, heads)): 194 | not_last = ind < (len(depth) - 1) 195 | 196 | self.transformer_layers.add(Transformer(dim, layer_depth, layer_heads, dim_head, mlp_dim, dropout)) 197 | 198 | if not_last: 199 | self.transformer_layers.add(Pool(dim)) 200 | dim *= 2 201 | 202 | self.mlp_head = Sequential([ 203 | nn.LayerNormalization(), 204 | nn.Dense(units=num_classes) 205 | ], name='mlp_head') 206 | 207 | def call(self, img, training=True, **kwargs): 208 | x = self.patch_embedding(img) 209 | b, n, d = x.shape 210 | 211 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b) 212 | x = tf.concat([cls_tokens, x], axis=1) 213 | x += self.pos_embedding[:, :(n + 1)] 214 | x = self.dropout(x, training=training) 215 | 216 | x = self.transformer_layers(x, training=training) 217 | x = self.mlp_head(x[:, 0]) 218 | 219 | return x 220 | 221 | """ Usage 222 | v = PiT( 223 | image_size = 224, 224 | patch_size = 14, 225 | dim = 256, 226 | num_classes = 1000, 227 | depth = (3, 3, 3), # list of depths, indicating the number of rounds of each stage before a downsample 228 | heads = 16, 229 | mlp_dim = 2048, 230 | dropout = 0.1, 231 | emb_dropout = 0.1 232 | ) 233 | img = tf.random.normal(shape=[1, 224, 224, 3]) 234 | preds = v(img) # (1, 1000) 235 | """ 236 | 237 | 238 | 239 | -------------------------------------------------------------------------------- /vit_tensorflow/regionvit.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | from tensorflow import einsum 7 | 8 | from einops import rearrange 9 | from einops.layers.tensorflow import Rearrange, Reduce 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | def default(val, d): 15 | return val if exists(val) else d 16 | 17 | def cast_tuple(val, length = 1): 18 | return val if isinstance(val, tuple) else ((val,) * length) 19 | 20 | def divisible_by(val, d): 21 | return (val % d) == 0 22 | 23 | def gelu(x, approximate=False): 24 | if approximate: 25 | coeff = tf.cast(0.044715, x.dtype) 26 | return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3)))) 27 | else: 28 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype))) 29 | 30 | class GELU(Layer): 31 | def
__init__(self, approximate=False): 32 | super(GELU, self).__init__() 33 | self.approximate = approximate 34 | 35 | def call(self, x, training=True): 36 | return gelu(x, self.approximate) 37 | 38 | class Identity(Layer): 39 | def __init__(self): 40 | super(Identity, self).__init__() 41 | 42 | def call(self, x, training=True): 43 | return tf.identity(x) 44 | 45 | class Downsample(Layer): 46 | def __init__(self, dim): 47 | super(Downsample, self).__init__() 48 | self.conv = nn.Conv2D(filters=dim, kernel_size=3, strides=2, padding='SAME') 49 | 50 | def call(self, x, training=True): 51 | x = self.conv(x) 52 | return x 53 | 54 | class PEG(Layer): 55 | def __init__(self, dim, kernel_size=3): 56 | super(PEG, self).__init__() 57 | self.proj = nn.Conv2D(filters=dim, kernel_size=kernel_size, strides=1, padding='SAME', groups=dim) 58 | 59 | def call(self, x, training=True): 60 | x = self.proj(x) + x 61 | return x 62 | 63 | 64 | class MLP(Layer): 65 | def __init__(self, dim, mult=4, dropout=0.0): 66 | super(MLP, self).__init__() 67 | 68 | self.net = Sequential([ 69 | nn.LayerNormalization(), 70 | nn.Dense(units=dim * mult), 71 | GELU(), 72 | nn.Dropout(rate=dropout), 73 | nn.Dense(units=dim) 74 | ]) 75 | 76 | def call(self, x, training=True): 77 | return self.net(x, training=training) 78 | 79 | class Attention(Layer): 80 | def __init__(self, dim, heads=4, dim_head=32, dropout=0.0): 81 | super(Attention, self).__init__() 82 | 83 | inner_dim = dim_head * heads 84 | self.heads = heads 85 | self.scale = dim_head ** -0.5 86 | 87 | self.norm = nn.LayerNormalization() 88 | self.attend = nn.Softmax() 89 | self.to_qkv = nn.Dense(units=inner_dim * 3, use_bias=False) 90 | 91 | self.to_out = nn.Dense(units=dim) 92 | 93 | def call(self, x, rel_pos_bias=None, training=True): 94 | h = self.heads 95 | 96 | # prenorm 97 | x = self.norm(x) 98 | qkv = self.to_qkv(x) 99 | qkv = tf.split(qkv, num_or_size_splits=3, axis=-1) 100 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) 101 | q = q * self.scale 102 | 103 | sim = einsum('b h i d, b h j d -> b h i j', q, k) 104 | 105 | # add relative positional bias for local tokens 106 | if exists(rel_pos_bias): 107 | sim = sim + rel_pos_bias 108 | attn = self.attend(sim) 109 | 110 | # merge heads 111 | 112 | x = einsum('b h i j, b h j d -> b h i d', attn, v) 113 | x = rearrange(x, 'b h n d -> b n (h d)') 114 | x = self.to_out(x) 115 | 116 | return x 117 | 118 | class R2LTransformer(Layer): 119 | def __init__(self, dim, window_size, depth=4, heads=4, dim_head=32, attn_dropout=0.0, ff_dropout=0.0): 120 | super(R2LTransformer, self).__init__() 121 | 122 | self.layers = [] 123 | 124 | self.window_size = window_size 125 | rel_positions = 2 * window_size - 1 126 | self.local_rel_pos_bias = nn.Embedding(rel_positions ** 2, heads) 127 | 128 | for _ in range(depth): 129 | self.layers.append([ 130 | Attention(dim, heads=heads, dim_head=dim_head, dropout=attn_dropout), 131 | MLP(dim, dropout=ff_dropout) 132 | ]) 133 | 134 | 135 | def call(self, local_tokens, region_tokens=None, training=True): 136 | lh, lw = local_tokens.shape[1:3] 137 | rh, rw = region_tokens.shape[1:3] 138 | window_size_h, window_size_w = lh // rh, lw // rw 139 | 140 | local_tokens = rearrange(local_tokens, 'b h w c -> b (h w) c') 141 | region_tokens = rearrange(region_tokens, 'b h w c -> b (h w) c') 142 | 143 | # calculate local relative positional bias 144 | h_range = tf.range(window_size_h) 145 | w_range = tf.range(window_size_w) 146 | 147 | grid_x, grid_y = tf.meshgrid(h_range, w_range, 
indexing='ij') 148 | grid = tf.stack([grid_x, grid_y]) 149 | grid = rearrange(grid, 'c h w -> c (h w)') 150 | grid = (grid[:, :, None] - grid[:, None, :]) + (self.window_size - 1) 151 | 152 | bias_indices = tf.reduce_sum((grid * tf.convert_to_tensor([1, self.window_size * 2 - 1])[:, None, None]), axis=0) 153 | rel_pos_bias = self.local_rel_pos_bias(bias_indices) 154 | rel_pos_bias = rearrange(rel_pos_bias, 'i j h -> () h i j') 155 | rel_pos_bias = tf.pad(rel_pos_bias, paddings=[[0, 0], [0, 0], [1, 0], [1, 0]]) 156 | 157 | # go through r2l transformer layers 158 | for attn, ff in self.layers: 159 | region_tokens = attn(region_tokens) + region_tokens 160 | 161 | # concat region tokens to local tokens 162 | 163 | local_tokens = rearrange(local_tokens, 'b (h w) d -> b h w d', h=lh) 164 | local_tokens = rearrange(local_tokens, 'b (h p1) (w p2) d -> (b h w) (p1 p2) d', p1=window_size_h, p2=window_size_w) 165 | region_tokens = rearrange(region_tokens, 'b n d -> (b n) () d') 166 | 167 | # do self attention on local tokens, along with its regional token 168 | region_and_local_tokens = tf.concat([region_tokens, local_tokens], axis=1) 169 | region_and_local_tokens = attn(region_and_local_tokens, rel_pos_bias=rel_pos_bias) + region_and_local_tokens 170 | 171 | # feedforward 172 | region_and_local_tokens = ff(region_and_local_tokens, training=training) + region_and_local_tokens 173 | 174 | # split back local and regional tokens 175 | region_tokens, local_tokens = region_and_local_tokens[:, :1], region_and_local_tokens[:, 1:] 176 | local_tokens = rearrange(local_tokens, '(b h w) (p1 p2) d -> b (h p1 w p2) d', h=lh // window_size_h, w=lw // window_size_w, p1=window_size_h) 177 | region_tokens = rearrange(region_tokens, '(b n) () d -> b n d', n=rh * rw) 178 | 179 | local_tokens = rearrange(local_tokens, 'b (h w) c -> b h w c', h=lh, w=lw) 180 | region_tokens = rearrange(region_tokens, 'b (h w) c -> b h w c', h=rh, w=rw) 181 | 182 | return local_tokens, region_tokens 183 | 184 | class RegionViT(Model): 185 | def __init__(self, 186 | dim=(64, 128, 256, 512), 187 | depth=(2, 2, 8, 2), 188 | window_size=7, 189 | num_classes=1000, 190 | tokenize_local_3_conv=False, 191 | local_patch_size=4, 192 | use_peg=False, 193 | attn_dropout=0.0, 194 | ff_dropout=0.0, 195 | ): 196 | super(RegionViT, self).__init__() 197 | dim = cast_tuple(dim, 4) 198 | depth = cast_tuple(depth, 4) 199 | assert len(dim) == 4, 'dim needs to be a single value or a tuple of length 4' 200 | assert len(depth) == 4, 'depth needs to be a single value or a tuple of length 4' 201 | 202 | self.local_patch_size = local_patch_size 203 | 204 | region_patch_size = local_patch_size * window_size 205 | self.region_patch_size = local_patch_size * window_size 206 | 207 | init_dim, *_, last_dim = dim 208 | 209 | # local and region encoders 210 | if tokenize_local_3_conv: 211 | self.local_encoder = Sequential([ 212 | nn.Conv2D(filters=init_dim, kernel_size=3, strides=2, padding='SAME'), 213 | nn.LayerNormalization(), 214 | GELU(), 215 | nn.Conv2D(filters=init_dim, kernel_size=3, strides=2, padding='SAME'), 216 | nn.LayerNormalization(), 217 | GELU(), 218 | nn.Conv2D(filters=init_dim, kernel_size=3, strides=1, padding='SAME') 219 | ]) 220 | else: 221 | self.local_encoder = nn.Conv2D(filters=init_dim, kernel_size=8, strides=4, padding='SAME') 222 | 223 | self.region_encoder = Sequential([ 224 | Rearrange('b (h p1) (w p2) c -> b h w (c p1 p2) ', p1=region_patch_size, p2=region_patch_size), 225 | nn.Conv2D(filters=init_dim, kernel_size=1, strides=1) 226 | ]) 227 
| 228 | # layers 229 | self.region_layers = [] 230 | 231 | for ind, dim, num_layers in zip(range(4), dim, depth): 232 | not_first = ind != 0 233 | need_downsample = not_first 234 | need_peg = not_first and use_peg 235 | 236 | self.region_layers.append([ 237 | Downsample(dim) if need_downsample else Identity(), 238 | PEG(dim) if need_peg else Identity(), 239 | R2LTransformer(dim, depth=num_layers, window_size=window_size, attn_dropout=attn_dropout, ff_dropout=ff_dropout) 240 | ]) 241 | 242 | # final logits 243 | self.to_logits = Sequential([ 244 | Reduce('b h w c -> b c', 'mean'), 245 | nn.LayerNormalization(), 246 | nn.Dense(units=num_classes) 247 | ]) 248 | 249 | def call(self, x, training=True, **kwargs): 250 | _, h, w, _ = x.shape 251 | assert divisible_by(h, self.region_patch_size) and divisible_by(w, self.region_patch_size), 'height and width must be divisible by region patch size' 252 | assert divisible_by(h, self.local_patch_size) and divisible_by(w, self.local_patch_size), 'height and width must be divisible by local patch size' 253 | 254 | local_tokens = self.local_encoder(x) 255 | region_tokens = self.region_encoder(x) 256 | 257 | for down, peg, transformer in self.region_layers: 258 | local_tokens, region_tokens = down(local_tokens), down(region_tokens) 259 | local_tokens = peg(local_tokens) 260 | local_tokens, region_tokens = transformer(local_tokens, region_tokens, training=training) 261 | 262 | x = self.to_logits(region_tokens) 263 | return x 264 | 265 | """ Usage 266 | v = RegionViT( 267 | dim = (64, 128, 256, 512), # tuple of size 4, indicating dimension at each stage 268 | depth = (2, 2, 8, 2), # depth of the region to local transformer at each stage 269 | window_size = 7, # window size, which should be either 7 or 14 270 | num_classes = 1000, # number of output classes 271 | tokenize_local_3_conv = False, # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models 272 | use_peg = False, # whether to use positional generating module. 
they used this for object detection for a boost in performance 273 | ) 274 | 275 | img = tf.random.normal(shape=[1, 224, 224, 3]) 276 | preds = v(img) # (1, 1000) 277 | """ -------------------------------------------------------------------------------- /vit_tensorflow/scalable_vit.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | 7 | from einops import rearrange 8 | from einops.layers.tensorflow import Reduce 9 | 10 | from functools import partial 11 | 12 | def exists(val): 13 | return val is not None 14 | 15 | def default(val, d): 16 | return val if exists(val) else d 17 | 18 | def pair(t): 19 | return t if isinstance(t, tuple) else (t, t) 20 | 21 | def cast_tuple(val, length = 1): 22 | return val if isinstance(val, tuple) else ((val,) * length) 23 | 24 | def gelu(x, approximate=False): 25 | if approximate: 26 | coeff = tf.cast(0.044715, x.dtype) 27 | return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3)))) 28 | else: 29 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype))) 30 | 31 | class GELU(Layer): 32 | def __init__(self, approximate=False): 33 | super(GELU, self).__init__() 34 | self.approximate = approximate 35 | 36 | def call(self, x, training=True): 37 | return gelu(x, self.approximate) 38 | 39 | class Identity(Layer): 40 | def __init__(self): 41 | super(Identity, self).__init__() 42 | 43 | def call(self, x, training=True): 44 | return tf.identity(x) 45 | 46 | class LayerNorm(Layer): 47 | def __init__(self, dim, eps=1e-5): 48 | super(LayerNorm, self).__init__() 49 | self.eps = eps 50 | 51 | self.g = tf.Variable(tf.ones([1, 1, 1, dim])) 52 | self.b = tf.Variable(tf.zeros([1, 1, 1, dim])) 53 | 54 | def call(self, x, training=True): 55 | var = tf.math.reduce_variance(x, axis=-1, keepdims=True) 56 | mean = tf.reduce_mean(x, axis=-1, keepdims=True) 57 | 58 | x = (x - mean) / tf.sqrt((var + self.eps)) * self.g + self.b 59 | return x 60 | 61 | class PreNorm(Layer): 62 | def __init__(self, dim, fn): 63 | super(PreNorm, self).__init__() 64 | 65 | self.norm = LayerNorm(dim) 66 | self.fn = fn 67 | 68 | def call(self, x, training=True): 69 | return self.fn(self.norm(x), training=training) 70 | 71 | class Downsample(Layer): 72 | def __init__(self, dim): 73 | super(Downsample, self).__init__() 74 | self.conv = nn.Conv2D(filters=dim, kernel_size=3, strides=2, padding='SAME') 75 | 76 | def call(self, x, training=True): 77 | x = self.conv(x) 78 | return x 79 | 80 | class PEG(Layer): 81 | def __init__(self, dim, kernel_size=3): 82 | super(PEG, self).__init__() 83 | self.proj = nn.Conv2D(filters=dim, kernel_size=kernel_size, strides=1, padding='SAME', groups=dim) 84 | 85 | def call(self, x, training=True): 86 | x = self.proj(x) + x 87 | return x 88 | 89 | class MLP(Layer): 90 | def __init__(self, dim, expansion_factor=4, dropout=0.0): 91 | super(MLP, self).__init__() 92 | inner_dim = dim * expansion_factor 93 | self.net = Sequential([ 94 | nn.Conv2D(filters=inner_dim, kernel_size=1, strides=1), 95 | GELU(), 96 | nn.Dropout(rate=dropout), 97 | nn.Conv2D(filters=dim, kernel_size=1, strides=1), 98 | nn.Dropout(rate=dropout) 99 | ]) 100 | 101 | def call(self, x, training=True): 102 | return self.net(x, training=training) 103 | 104 | class ScalableSelfAttention(Layer): 105 | def __init__(self, dim, heads=8, dim_key=32, 
dim_value=32, dropout=0.0, reduction_factor=1): 106 | super(ScalableSelfAttention, self).__init__() 107 | 108 | self.heads = heads 109 | self.scale = dim_key ** -0.5 110 | self.attend = nn.Softmax() 111 | 112 | self.to_q = nn.Conv2D(filters=dim_key * heads, kernel_size=1, strides=1, use_bias=False) 113 | self.to_k = nn.Conv2D(filters=dim_key * heads, kernel_size=reduction_factor, strides=reduction_factor, use_bias=False) 114 | self.to_v = nn.Conv2D(filters=dim_value * heads, kernel_size=reduction_factor, strides=reduction_factor, use_bias=False) 115 | 116 | self.to_out = Sequential([ 117 | nn.Conv2D(filters=dim, kernel_size=1, strides=1), 118 | nn.Dropout(rate=dropout) 119 | ]) 120 | 121 | def call(self, x, training=True): 122 | _, height, width, _ = x.shape 123 | heads = self.heads 124 | 125 | q, k, v = self.to_q(x), self.to_k(x), self.to_v(x) 126 | 127 | # split out heads 128 | q, k, v = map(lambda t: rearrange(t, 'b ... (h d) -> b h (...) d', h=heads), (q, k, v)) 129 | 130 | # similarity 131 | dots = tf.matmul(q, tf.transpose(k, perm=[0, 1, 3, 2])) * self.scale 132 | 133 | # attention 134 | attn = self.attend(dots) 135 | 136 | # aggregate values 137 | out = tf.matmul(attn, v) 138 | 139 | # merge back heads 140 | out = rearrange(out, 'b h (x y) d -> b x y (h d)', x=height, y=width) 141 | out = self.to_out(out, training=training) 142 | 143 | return out 144 | 145 | class InteractiveWindowedSelfAttention(Layer): 146 | def __init__(self, dim, window_size, heads=8, dim_key=32, dim_value=32, dropout=0.0): 147 | super(InteractiveWindowedSelfAttention, self).__init__() 148 | 149 | self.heads = heads 150 | self.scale = dim_key ** -0.5 151 | self.window_size = window_size 152 | self.attend = nn.Softmax() 153 | 154 | self.local_interactive_module = nn.Conv2D(filters=dim_value * heads, kernel_size=3, strides=1, padding='SAME') 155 | 156 | self.to_q = nn.Conv2D(filters=dim_key * heads, kernel_size=1, strides=1, use_bias=False) 157 | self.to_k = nn.Conv2D(filters=dim_key * heads, kernel_size=1, strides=1, use_bias=False) 158 | self.to_v = nn.Conv2D(filters=dim_value * heads, kernel_size=1, strides=1, use_bias=False) 159 | 160 | self.to_out = Sequential([ 161 | nn.Conv2D(filters=dim, kernel_size=1, strides=1), 162 | nn.Dropout(rate=dropout) 163 | ]) 164 | 165 | def call(self, x, training=True): 166 | _, height, width, _ = x.shape 167 | heads = self.heads 168 | wsz = self.window_size 169 | 170 | wsz_h, wsz_w = default(wsz, height), default(wsz, width) 171 | assert (height % wsz_h) == 0 and (width % wsz_w) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz_h}, {wsz_w})' 172 | 173 | q, k, v = self.to_q(x), self.to_k(x), self.to_v(x) 174 | 175 | # get output of LIM 176 | local_out = self.local_interactive_module(v) 177 | 178 | # divide into window (and split out heads) for efficient self attention 179 | q, k, v = map(lambda t: rearrange(t, 'b (x w1) (y w2) (h d) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz_h, w2 = wsz_w), (q, k, v)) 180 | 181 | # similarity 182 | dots = tf.matmul(q, tf.transpose(k, perm=[0, 1, 3, 2])) * self.scale 183 | 184 | # attention 185 | attn = self.attend(dots) 186 | 187 | # aggregate values 188 | out = tf.matmul(attn, v) 189 | 190 | # reshape the windows back to full feature map (and merge heads) 191 | out = rearrange(out, '(b x y) h (w1 w2) d -> b (x w1) (y w2) (h d)', x = height // wsz_h, y = width // wsz_w, w1 = wsz_h, w2 = wsz_w) 192 | 193 | # add LIM output 194 | out = out + local_out 195 | 196 | out = self.to_out(out, 
training=training) 197 | 198 | return out 199 | 200 | class Transformer(Layer): 201 | def __init__(self, 202 | dim, 203 | depth, 204 | heads=8, 205 | ff_expansion_factor=4, 206 | dropout=0., 207 | ssa_dim_key=32, 208 | ssa_dim_value=32, 209 | ssa_reduction_factor=1, 210 | iwsa_dim_key=32, 211 | iwsa_dim_value=32, 212 | iwsa_window_size=None, 213 | norm_output=True 214 | ): 215 | super(Transformer, self).__init__() 216 | 217 | self.layers = [] 218 | 219 | for ind in range(depth): 220 | is_first = ind == 0 221 | 222 | self.layers.append([ 223 | PreNorm(dim, ScalableSelfAttention(dim, heads=heads, dim_key=ssa_dim_key, dim_value=ssa_dim_value, 224 | reduction_factor=ssa_reduction_factor, dropout=dropout)), 225 | PreNorm(dim, MLP(dim, expansion_factor=ff_expansion_factor, dropout=dropout)), 226 | PEG(dim) if is_first else None, 227 | PreNorm(dim, MLP(dim, expansion_factor=ff_expansion_factor, dropout=dropout)), 228 | PreNorm(dim, InteractiveWindowedSelfAttention(dim, heads=heads, dim_key=iwsa_dim_key, dim_value=iwsa_dim_value, 229 | window_size=iwsa_window_size, 230 | dropout=dropout)) 231 | ]) 232 | 233 | self.norm = LayerNorm(dim) if norm_output else Identity() 234 | 235 | def call(self, x, training=True): 236 | for ssa, ff1, peg, iwsa, ff2 in self.layers: 237 | x = ssa(x, training=training) + x 238 | x = ff1(x, training=training) + x 239 | 240 | if exists(peg): 241 | x = peg(x) 242 | 243 | x = iwsa(x, training=training) + x 244 | x = ff2(x, training=training) + x 245 | 246 | x = self.norm(x) 247 | 248 | return x 249 | 250 | class ScalableViT(Model): 251 | def __init__(self, 252 | num_classes, 253 | dim, 254 | depth, 255 | heads, 256 | reduction_factor, 257 | window_size=None, 258 | iwsa_dim_key=32, 259 | iwsa_dim_value=32, 260 | ssa_dim_key=32, 261 | ssa_dim_value=32, 262 | ff_expansion_factor=4, 263 | channels=3, 264 | dropout=0.0 265 | ): 266 | super(ScalableViT, self).__init__() 267 | 268 | self.to_patches = nn.Conv2D(filters=dim, kernel_size=7, strides=4, padding='SAME') 269 | 270 | assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage' 271 | 272 | num_stages = len(depth) 273 | dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages))) 274 | 275 | hyperparams_per_stage = [ 276 | heads, 277 | ssa_dim_key, 278 | ssa_dim_value, 279 | reduction_factor, 280 | iwsa_dim_key, 281 | iwsa_dim_value, 282 | window_size, 283 | ] 284 | 285 | hyperparams_per_stage = list(map(partial(cast_tuple, length=num_stages), hyperparams_per_stage)) 286 | assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage))) 287 | 288 | self.scalable_layers = [] 289 | 290 | for ind, (layer_dim, layer_depth, layer_heads, layer_ssa_dim_key, layer_ssa_dim_value, layer_ssa_reduction_factor, layer_iwsa_dim_key, layer_iwsa_dim_value, layer_window_size) in enumerate(zip(dims, depth, *hyperparams_per_stage)): 291 | is_last = ind == (num_stages - 1) 292 | 293 | self.scalable_layers.append([ 294 | Transformer(dim=layer_dim, depth=layer_depth, heads=layer_heads, 295 | ff_expansion_factor=ff_expansion_factor, dropout=dropout, ssa_dim_key=layer_ssa_dim_key, 296 | ssa_dim_value=layer_ssa_dim_value, ssa_reduction_factor=layer_ssa_reduction_factor, 297 | iwsa_dim_key=layer_iwsa_dim_key, iwsa_dim_value=layer_iwsa_dim_value, 298 | iwsa_window_size=layer_window_size), 299 | Downsample(layer_dim * 2) if not is_last else None 300 | ]) 301 | 302 | self.mlp_head = Sequential([ 303 | Reduce('b h w d-> b d', 'mean'), 304 | nn.LayerNormalization(), 305 | 
nn.Dense(units=num_classes) 306 | ]) 307 | 308 | def call(self, img, training=True, **kwargs): 309 | x = self.to_patches(img) 310 | 311 | for transformer, downsample in self.scalable_layers: 312 | x = transformer(x, training=training) 313 | 314 | if exists(downsample): 315 | x = downsample(x) 316 | 317 | x = self.mlp_head(x) 318 | 319 | return x 320 | 321 | """ Usage 322 | v = ScalableViT( 323 | num_classes = 1000, 324 | dim = 64, # starting model dimension. at every stage, dimension is doubled 325 | heads = (2, 4, 8, 16), # number of attention heads at each stage 326 | depth = (2, 2, 20, 2), # number of transformer blocks at each stage 327 | ssa_dim_key = (40, 40, 40, 32), # the dimension of the attention keys (and queries) for SSA. in the paper, they represented this as a scale factor on the base dimension per key (ssa_dim_key / dim_key) 328 | reduction_factor = (8, 4, 2, 1), # downsampling of the key / values in SSA. in the paper, this was represented as (reduction_factor ** -2) 329 | window_size = (64, 32, None, None), # window size of the IWSA at each stage. None means no windowing needed 330 | dropout = 0.1, # attention and feedforward dropout 331 | ) 332 | 333 | img = tf.random.normal(shape=[1, 256, 256, 3]) 334 | preds = v(img) # (1, 1000) 335 | """ 336 | -------------------------------------------------------------------------------- /vit_tensorflow/simmim.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | import tensorflow.keras.layers as nn 4 | 5 | from einops import repeat 6 | import numpy as np 7 | from vit import ViT 8 | 9 | def scatter_numpy(x, dim, index, src): 10 | """ 11 | Writes all values from the Tensor src into x at the indices specified in the index Tensor. 12 | 13 | :param dim: The axis along which to index 14 | :param index: The indices of elements to scatter 15 | :param src: The source element(s) to scatter 16 | :return: x 17 | """ 18 | 19 | if x.ndim != index.ndim: 20 | raise ValueError("Index should have the same number of dimensions as output") 21 | if dim >= x.ndim or dim < -x.ndim: 22 | raise IndexError("dim is out of range") 23 | if dim < 0: 24 | # Not sure why scatter should accept dim < 0, but that is the behavior in PyTorch's scatter 25 | dim = x.ndim + dim 26 | idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1:] 27 | self_xsection_shape = x.shape[:dim] + x.shape[dim + 1:] 28 | if idx_xsection_shape != self_xsection_shape: 29 | raise ValueError("Except for dimension " + str(dim) + 30 | ", all dimensions of index and output should be the same size") 31 | if (index >= x.shape[dim]).any() or (index < 0).any(): 32 | raise IndexError("The values of index must be between 0 and (self.shape[dim] -1)") 33 | 34 | def make_slice(arr, dim, i): 35 | slc = [slice(None)] * arr.ndim 36 | slc[dim] = i 37 | slc = tuple(slc) 38 | return slc 39 | 40 | # We use index and dim parameters to create idx 41 | # idx is in a form that can be used as a NumPy advanced index for scattering of src param. 
in self 42 | idx = [[*np.indices(idx_xsection_shape).reshape(index.ndim - 1, -1), 43 | index[make_slice(index, dim, i)].reshape(1, -1)[0]] for i in range(index.shape[dim])] 44 | idx = list(np.concatenate(idx, axis=1)) 45 | idx.insert(dim, idx.pop()) 46 | 47 | if not np.isscalar(src): 48 | if index.shape[dim] > src.shape[dim]: 49 | raise IndexError("Dimension " + str(dim) + "of index can not be bigger than that of src ") 50 | src_xsection_shape = src.shape[:dim] + src.shape[dim + 1:] 51 | if idx_xsection_shape != src_xsection_shape: 52 | raise ValueError("Except for dimension " + 53 | str(dim) + ", all dimensions of index and src should be the same size") 54 | # src_idx is a NumPy advanced index for indexing of elements in the src 55 | src_idx = list(idx) 56 | src_idx.pop(dim) 57 | src_idx.insert(dim, np.repeat(np.arange(index.shape[dim]), np.prod(idx_xsection_shape))) 58 | idx = tuple(idx) 59 | x[idx] = src[src_idx] 60 | 61 | else: 62 | idx = tuple(idx) 63 | x[idx] = src 64 | 65 | return x 66 | 67 | class SimMIM(Model): 68 | def __init__(self, image_size, encoder, masking_ratio=0.5): 69 | super(SimMIM, self).__init__() 70 | assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1' 71 | self.masking_ratio = masking_ratio 72 | 73 | # build 74 | encoder.build(input_shape=(1, image_size, image_size, 3)) 75 | 76 | # extract some hyperparameters and functions from encoder (vision transformer to be trained) 77 | self.encoder = encoder 78 | num_patches, encoder_dim = encoder.pos_embedding.shape[-2:] 79 | self.to_patch, self.patch_to_emb = encoder.patch_embedding.layers[:2] 80 | pixel_values_per_patch = self.patch_to_emb.weights[0].shape[0] 81 | 82 | # simpler linear head 83 | self.mask_token = tf.Variable(tf.random.normal([encoder_dim])) 84 | self.to_pixels = nn.Dense(units=pixel_values_per_patch) 85 | 86 | def call(self, img, training=True, **kwargs): 87 | # get patches 88 | patches = self.to_patch(img, training=training) 89 | batch, num_patches, *_ = patches.shape 90 | 91 | # for indexing purposes 92 | batch_range = tf.range(batch)[:, None] 93 | 94 | # get positions 95 | pos_emb = self.encoder.pos_embedding[:, 1:(num_patches + 1)] 96 | 97 | # patch to encoder tokens and add positions 98 | tokens = self.patch_to_emb(patches, training=training) 99 | tokens = tokens + pos_emb 100 | 101 | # prepare mask tokens 102 | mask_tokens = repeat(self.mask_token, 'd -> b n d', b=batch, n=num_patches) 103 | mask_tokens = mask_tokens + pos_emb 104 | 105 | # calculate of patches needed to be masked, and get positions (indices) to be masked 106 | num_masked = int(self.masking_ratio * num_patches) 107 | 108 | masked_indices = tf.math.top_k(tf.random.uniform(shape=[batch, num_patches]), k=num_masked).indices 109 | masked_bool_mask = scatter_numpy(np.zeros(shape=[batch, num_patches]), dim=-1, index=masked_indices.numpy(), src=1) 110 | masked_bool_mask = tf.cast(masked_bool_mask, tf.bool) 111 | 112 | # mask tokens 113 | tokens = tf.where(masked_bool_mask[..., None], mask_tokens, tokens) 114 | 115 | # attend with vision transformer 116 | encoded = self.encoder.transformer(tokens, training=training) 117 | 118 | # get the masked tokens 119 | encoded_mask_tokens = encoded.numpy()[batch_range, masked_indices] 120 | 121 | # small linear projection for predicted pixel values 122 | pred_pixel_values = self.to_pixels(encoded_mask_tokens, training=training) 123 | 124 | # get the masked patches for the final reconstruction loss 125 | masked_patches = patches.numpy()[batch_range, 
masked_indices] 126 | 127 | # calculate reconstruction loss 128 | recon_loss = tf.reduce_mean(tf.abs(pred_pixel_values - masked_patches)) / num_masked 129 | 130 | return recon_loss 131 | 132 | """ Usage 133 | v = ViT( 134 | image_size = 256, 135 | patch_size = 32, 136 | num_classes = 1000, 137 | dim = 1024, 138 | depth = 6, 139 | heads = 8, 140 | mlp_dim = 2048 141 | ) 142 | 143 | mim = SimMIM( 144 | image_size = 256, 145 | encoder = v, 146 | masking_ratio = 0.5 # they found 50% to yield the best results 147 | ) 148 | 149 | img = tf.random.normal(shape=[8, 256, 256, 3]) 150 | loss = mim(img) # (8, 1000) 151 | """ 152 | -------------------------------------------------------------------------------- /vit_tensorflow/t2t.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | from vit import Transformer 7 | 8 | from einops import rearrange, repeat 9 | import math 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | def conv_output_size(image_size, kernel_size, stride, padding): 15 | return int(((image_size - kernel_size + (2 * padding)) / stride) + 1) 16 | 17 | class RearrangeUnfoldTransformer(Layer): 18 | def __init__(self, is_first, is_last, kernel_size, stride, 19 | dim, heads, depth, dim_head, mlp_dim, dropout): 20 | super(RearrangeUnfoldTransformer, self).__init__() 21 | self.is_first = is_first 22 | self.is_last = is_last 23 | self.kernel_size = [1, kernel_size, kernel_size, 1] 24 | self.stride = [1, stride, stride, 1] 25 | self.rates = [1, 1, 1, 1] 26 | 27 | # transformer 28 | self.dim = dim 29 | self.heads = heads 30 | self.depth = depth 31 | self.dim_head = dim_head 32 | self.mlp_dim = mlp_dim 33 | self.dropout = dropout 34 | 35 | if not self.is_last: 36 | self.transformer_layer = Transformer(dim=self.dim, heads=self.heads, depth=self.depth, dim_head=self.dim_head, mlp_dim=self.mlp_dim, dropout=self.dropout) 37 | 38 | 39 | def call(self, x, training=True): 40 | if not self.is_first: 41 | x = rearrange(x, 'b (h w) c -> b h w c', h=int(math.sqrt(x.shape[1]))) 42 | x = tf.image.extract_patches(x, sizes=self.kernel_size, strides=self.stride, rates=self.rates, padding='SAME') 43 | x = rearrange(x, 'b h w c -> b (h w) c') 44 | if not self.is_last: 45 | x = self.transformer_layer(x, training=training) 46 | 47 | return x 48 | 49 | class T2TViT(Model): 50 | def __init__(self, image_size, num_classes, dim, 51 | depth=None, heads=None, mlp_dim=None, pool='cls', channels=3, dim_head=64, dropout=0.0, emb_dropout=0.0, 52 | transformer=None, t2t_layers=((7, 4), (3, 2), (3, 2))): 53 | super(T2TViT, self).__init__() 54 | 55 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 56 | 57 | layers = Sequential() 58 | layer_dim = channels 59 | output_image_size = image_size 60 | 61 | for i, (kernel_size, stride) in enumerate(t2t_layers): 62 | layer_dim *= kernel_size ** 2 63 | is_first = i == 0 64 | is_last = i == (len(t2t_layers) - 1) 65 | output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2) 66 | 67 | layers.add(RearrangeUnfoldTransformer(is_first, is_last, kernel_size, stride, 68 | dim=layer_dim, heads=1, depth=1, dim_head=layer_dim, mlp_dim=layer_dim, dropout=dropout) 69 | ) 70 | 71 | layers.add(nn.Dense(units=dim)) 72 | self.patch_embedding = layers 73 | 74 | self.pos_embedding = 
tf.Variable(initial_value=tf.random.normal([1, output_image_size ** 2 + 1, dim])) 75 | self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim])) 76 | self.dropout = nn.Dropout(rate=emb_dropout) 77 | 78 | if not exists(transformer): 79 | assert all([exists(depth), exists(heads), exists(mlp_dim)]), 'depth, heads, and mlp_dim must be supplied' 80 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 81 | else: 82 | self.transformer = transformer 83 | 84 | self.pool = pool 85 | 86 | self.mlp_head = Sequential([ 87 | nn.LayerNormalization(), 88 | nn.Dense(units=num_classes) 89 | ], name='mlp_head') 90 | 91 | def call(self, img, training=True, **kwargs): 92 | x = self.patch_embedding(img) 93 | b, n, d = x.shape 94 | 95 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b) 96 | x = tf.concat([cls_tokens, x], axis=1) 97 | x += self.pos_embedding[:, :(n + 1)] 98 | x = self.dropout(x, training=training) 99 | 100 | x = self.transformer(x, training=training) 101 | 102 | if self.pool == 'mean': 103 | x = tf.reduce_mean(x, axis=1) 104 | else: 105 | x = x[:, 0] 106 | 107 | x = self.mlp_head(x) 108 | 109 | return x 110 | """ Usage 111 | v = T2TViT( 112 | dim = 512, 113 | image_size = 224, 114 | depth = 5, 115 | heads = 8, 116 | mlp_dim = 512, 117 | num_classes = 1000, 118 | t2t_layers = ((7, 4), (3, 2), (3, 2)) # tuples of the kernel size and stride of each consecutive layers of the initial token to token module 119 | ) 120 | 121 | img = tf.random.normal(shape=[1, 224, 224, 3]) 122 | preds = v(img) # (1, 1000) 123 | """ -------------------------------------------------------------------------------- /vit_tensorflow/twins_svt.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import einsum 3 | from tensorflow.keras import Model 4 | from tensorflow.keras.layers import Layer 5 | from tensorflow.keras import Sequential 6 | import tensorflow.keras.layers as nn 7 | 8 | from einops import rearrange 9 | 10 | def group_dict_by_key(cond, d): 11 | return_val = [dict(), dict()] 12 | for key in d.keys(): 13 | match = bool(cond(key)) 14 | ind = int(not match) 15 | return_val[ind][key] = d[key] 16 | return (*return_val,) 17 | 18 | def group_by_key_prefix_and_remove_prefix(prefix, d): 19 | kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d) 20 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) 21 | return kwargs_without_prefix, kwargs 22 | 23 | def gelu(x, approximate=False): 24 | if approximate: 25 | coeff = tf.cast(0.044715, x.dtype) 26 | return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3)))) 27 | else: 28 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype))) 29 | 30 | class Identity(Layer): 31 | def __init__(self): 32 | super(Identity, self).__init__() 33 | 34 | def call(self, x, training=True): 35 | return tf.identity(x) 36 | 37 | class Residual(Layer): 38 | def __init__(self, fn): 39 | super(Residual, self).__init__() 40 | self.fn = fn 41 | 42 | def call(self, x, training=True): 43 | return self.fn(x, training=training) + x 44 | 45 | class LayerNorm(Layer): # layernorm, but done in the channel dimension #1 46 | def __init__(self, dim, eps=1e-5): 47 | super(LayerNorm, self).__init__() 48 | self.eps = eps 49 | 50 | self.g = tf.Variable(tf.ones([1, 1, 1, dim])) 51 | self.b = tf.Variable(tf.zeros([1, 1, 1, dim])) 52 | 53 | def call(self, x, 
training=True): 54 | var = tf.math.reduce_variance(x, axis=-1, keepdims=True) 55 | mean = tf.reduce_mean(x, axis=-1, keepdims=True) 56 | 57 | x = (x - mean) / tf.sqrt((var + self.eps)) * self.g + self.b 58 | return x 59 | 60 | class PreNorm(Layer): 61 | def __init__(self, dim, fn): 62 | super(PreNorm, self).__init__() 63 | 64 | self.norm = LayerNorm(dim) 65 | self.fn = fn 66 | 67 | def call(self, x, training=True): 68 | return self.fn(self.norm(x), training=training) 69 | 70 | class GELU(Layer): 71 | def __init__(self, approximate=False): 72 | super(GELU, self).__init__() 73 | self.approximate = approximate 74 | 75 | def call(self, x, training=True): 76 | return gelu(x, self.approximate) 77 | 78 | class MLP(Layer): 79 | def __init__(self, dim, mult=4, dropout=0.0): 80 | super(MLP, self).__init__() 81 | 82 | self.net = [ 83 | nn.Conv2D(filters=dim * mult, kernel_size=1, strides=1), 84 | GELU(), 85 | nn.Dropout(rate=dropout), 86 | nn.Conv2D(filters=dim, kernel_size=1, strides=1), 87 | nn.Dropout(rate=dropout) 88 | ] 89 | self.net = Sequential(self.net) 90 | 91 | def call(self, x, training=True): 92 | return self.net(x, training=training) 93 | 94 | class PatchEmbedding(Layer): 95 | def __init__(self, dim_out, patch_size): 96 | super(PatchEmbedding, self).__init__() 97 | self.dim_out = dim_out 98 | self.patch_size = patch_size 99 | self.proj = nn.Conv2D(filters=dim_out, kernel_size=1, strides=1) 100 | 101 | def call(self, fmap, training=True): 102 | p = self.patch_size 103 | fmap = rearrange(fmap, 'b (h p1) (w p2) c -> b h w (c p1 p2)', p1 = p, p2 = p) 104 | x = self.proj(fmap) 105 | 106 | return x 107 | 108 | class PEG(Layer): 109 | def __init__(self, dim, kernel_size=3): 110 | super(PEG, self).__init__() 111 | self.proj = Residual(nn.Conv2D(filters=dim, kernel_size=kernel_size, strides=1, padding='SAME', groups=dim)) 112 | 113 | def call(self, x, training=True): 114 | x = self.proj(x) 115 | return x 116 | 117 | class LocalAttention(Layer): 118 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, patch_size=7): 119 | super(LocalAttention, self).__init__() 120 | inner_dim = dim_head * heads 121 | self.patch_size = patch_size 122 | self.heads = heads 123 | self.scale = dim_head ** -0.5 124 | 125 | self.attend = nn.Softmax() 126 | 127 | self.to_q = nn.Conv2D(filters=inner_dim, kernel_size=1, strides=1, use_bias=False) 128 | self.to_kv = nn.Conv2D(filters=inner_dim * 2, kernel_size=1, strides=1, use_bias=False) 129 | 130 | self.to_out = Sequential([ 131 | nn.Conv2D(filters=dim, kernel_size=1, strides=1), 132 | nn.Dropout(rate=dropout) 133 | ]) 134 | 135 | def call(self, fmap, training=True): 136 | b, x, y, n = fmap.shape 137 | h = self.heads 138 | p = self.patch_size 139 | x, y = map(lambda t: t // p, (x, y)) 140 | 141 | fmap = rearrange(fmap, 'b (x p1) (y p2) c -> (b x y) p1 p2 c', p1=p, p2=p) 142 | q = self.to_q(fmap) 143 | kv = self.to_kv(fmap) 144 | k, v = tf.split(kv, num_or_size_splits=2, axis=-1) 145 | 146 | q, k, v = map(lambda t: rearrange(t, 'b p1 p2 (h d) -> (b h) (p1 p2) d', h=h), (q, k, v)) 147 | 148 | dots = einsum('b i d, b j d -> b i j', q, k) * self.scale 149 | 150 | attn = self.attend(dots) 151 | 152 | out = einsum('b i j, b j d -> b i d', attn, v) 153 | out = rearrange(out, '(b x y h) (p1 p2) d -> b (x p1) (y p2) (h d) ', h=h, x=x, y=y, p1=p, p2=p) 154 | out = self.to_out(out, training=training) 155 | 156 | return out 157 | 158 | class GlobalAttention(Layer): 159 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, k=7): 160 | super(GlobalAttention, 
self).__init__() 161 | inner_dim = dim_head * heads 162 | self.heads = heads 163 | self.scale = dim_head ** -0.5 164 | 165 | self.attend = nn.Softmax() 166 | 167 | self.to_q = nn.Conv2D(filters=inner_dim, kernel_size=1, use_bias=False) 168 | self.to_kv = nn.Conv2D(filters=inner_dim * 2, kernel_size=k, strides=k, use_bias=False) 169 | 170 | self.to_out = Sequential([ 171 | nn.Conv2D(filters=dim, kernel_size=1, strides=1), 172 | nn.Dropout(rate=dropout) 173 | ]) 174 | 175 | def call(self, x, training=True): 176 | b, _, y, n = x.shape 177 | h = self.heads 178 | 179 | q = self.to_q(x) 180 | kv = self.to_kv(x) 181 | k, v = tf.split(kv, num_or_size_splits=2, axis=-1) 182 | q, k, v = map(lambda t: rearrange(t, 'b x y (h d) -> (b h) (x y) d', h=h), (q, k, v)) 183 | 184 | dots = einsum('b i d, b j d -> b i j', q, k) * self.scale 185 | 186 | attn = self.attend(dots) 187 | out = einsum('b i j, b j d -> b i d', attn, v) 188 | out = rearrange(out, '(b h) (x y) d -> b x y (h d)', h=h, y=y) 189 | out = self.to_out(out, training=training) 190 | return out 191 | 192 | class Transformer(Layer): 193 | def __init__(self, dim, depth, heads=8, dim_head=64, mlp_mult=4, local_patch_size=7, global_k=7, dropout=0.0, has_local=True): 194 | super(Transformer, self).__init__() 195 | 196 | self.layers = [] 197 | 198 | for _ in range(depth): 199 | self.layers.append([ 200 | Residual(PreNorm(dim, LocalAttention(dim, heads=heads, dim_head=dim_head, dropout=dropout, patch_size=local_patch_size))) if has_local else Identity(), 201 | Residual(PreNorm(dim, MLP(dim, mlp_mult, dropout=dropout))) if has_local else Identity(), 202 | Residual(PreNorm(dim, GlobalAttention(dim, heads=heads, dim_head=dim_head, dropout=dropout, k=global_k))), 203 | Residual(PreNorm(dim, MLP(dim, mlp_mult, dropout=dropout))) 204 | ]) 205 | 206 | def call(self, x, training=True): 207 | for local_attn, ff1, global_attn, ff2 in self.layers: 208 | x = local_attn(x, training=training) 209 | x = ff1(x, training=training) 210 | x = global_attn(x, training=training) 211 | x = ff2(x, training=training) 212 | 213 | return x 214 | 215 | class TwinsSVT(Model): 216 | def __init__(self, 217 | num_classes, 218 | s1_emb_dim=64, 219 | s1_patch_size=4, 220 | s1_local_patch_size=7, 221 | s1_global_k=7, 222 | s1_depth=1, 223 | s2_emb_dim=128, 224 | s2_patch_size=2, 225 | s2_local_patch_size=7, 226 | s2_global_k=7, 227 | s2_depth=1, 228 | s3_emb_dim=256, 229 | s3_patch_size=2, 230 | s3_local_patch_size=7, 231 | s3_global_k=7, 232 | s3_depth=5, 233 | s4_emb_dim=512, 234 | s4_patch_size=2, 235 | s4_local_patch_size=7, 236 | s4_global_k=7, 237 | s4_depth=4, 238 | peg_kernel_size=3, 239 | dropout=0.0 240 | ): 241 | super(TwinsSVT, self).__init__() 242 | kwargs = dict(locals()) 243 | 244 | self.svt_layers = Sequential() 245 | 246 | for prefix in ('s1', 's2', 's3', 's4'): 247 | config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs) 248 | is_last = prefix == 's4' 249 | 250 | dim_next = config['emb_dim'] 251 | 252 | self.svt_layers.add(Sequential([ 253 | PatchEmbedding(dim_out=dim_next, patch_size=config['patch_size']), 254 | Transformer(dim=dim_next, depth=1, local_patch_size=config['local_patch_size'], 255 | global_k=config['global_k'], dropout=dropout, has_local=not is_last), 256 | PEG(dim=dim_next, kernel_size=peg_kernel_size), 257 | Transformer(dim=dim_next, depth=config['depth'], local_patch_size=config['local_patch_size'], 258 | global_k=config['global_k'], dropout=dropout, has_local=not is_last) 259 | ])) 260 | 261 | self.svt_layers.add(Sequential([ 
262 | nn.GlobalAvgPool2D(), 263 | nn.Dense(units=num_classes) 264 | ])) 265 | 266 | def call(self, x, training=True, **kwargs): 267 | x = self.svt_layers(x, training=training) 268 | return x 269 | 270 | """ Usage 271 | v = TwinsSVT( 272 | num_classes = 1000, # number of output classes 273 | s1_emb_dim = 64, # stage 1 - patch embedding projected dimension 274 | s1_patch_size = 4, # stage 1 - patch size for patch embedding 275 | s1_local_patch_size = 7, # stage 1 - patch size for local attention 276 | s1_global_k = 7, # stage 1 - global attention key / value reduction factor, defaults to 7 as specified in paper 277 | s1_depth = 1, # stage 1 - number of transformer blocks (local attn -> ff -> global attn -> ff) 278 | s2_emb_dim = 128, # stage 2 (same as above) 279 | s2_patch_size = 2, 280 | s2_local_patch_size = 7, 281 | s2_global_k = 7, 282 | s2_depth = 1, 283 | s3_emb_dim = 256, # stage 3 (same as above) 284 | s3_patch_size = 2, 285 | s3_local_patch_size = 7, 286 | s3_global_k = 7, 287 | s3_depth = 5, 288 | s4_emb_dim = 512, # stage 4 (same as above) 289 | s4_patch_size = 2, 290 | s4_local_patch_size = 7, 291 | s4_global_k = 7, 292 | s4_depth = 4, 293 | peg_kernel_size = 3, # positional encoding generator kernel size 294 | dropout = 0. # dropout 295 | ) 296 | 297 | img = tf.random.normal(shape=[1, 224, 224, 3]) 298 | preds = v(img) # (1, 1000) 299 | """ -------------------------------------------------------------------------------- /vit_tensorflow/vit.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | 7 | from tensorflow import einsum 8 | from einops import rearrange, repeat 9 | from einops.layers.tensorflow import Rearrange 10 | 11 | def pair(t): 12 | return t if isinstance(t, tuple) else (t, t) 13 | 14 | class PreNorm(Layer): 15 | def __init__(self, fn): 16 | super(PreNorm, self).__init__() 17 | 18 | self.norm = nn.LayerNormalization() 19 | self.fn = fn 20 | 21 | def call(self, x, training=True): 22 | return self.fn(self.norm(x), training=training) 23 | 24 | class MLP(Layer): 25 | def __init__(self, dim, hidden_dim, dropout=0.0): 26 | super(MLP, self).__init__() 27 | 28 | def GELU(): 29 | def gelu(x, approximate=False): 30 | if approximate: 31 | coeff = tf.cast(0.044715, x.dtype) 32 | return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3)))) 33 | else: 34 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype))) 35 | 36 | return nn.Activation(gelu) 37 | 38 | self.net = Sequential([ 39 | nn.Dense(units=hidden_dim), 40 | GELU(), 41 | nn.Dropout(rate=dropout), 42 | nn.Dense(units=dim), 43 | nn.Dropout(rate=dropout) 44 | ]) 45 | 46 | def call(self, x, training=True): 47 | return self.net(x, training=training) 48 | 49 | class Attention(Layer): 50 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): 51 | super(Attention, self).__init__() 52 | inner_dim = dim_head * heads 53 | project_out = not (heads == 1 and dim_head == dim) 54 | 55 | self.heads = heads 56 | self.scale = dim_head ** -0.5 57 | 58 | self.attend = nn.Softmax() 59 | self.to_qkv = nn.Dense(units=inner_dim * 3, use_bias=False) 60 | 61 | if project_out: 62 | self.to_out = [ 63 | nn.Dense(units=dim), 64 | nn.Dropout(rate=dropout) 65 | ] 66 | else: 67 | self.to_out = [] 68 | 69 | self.to_out = Sequential(self.to_out) 70 | 71 | def call(self, x, 
training=True): 72 | qkv = self.to_qkv(x) 73 | qkv = tf.split(qkv, num_or_size_splits=3, axis=-1) 74 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) 75 | 76 | # dots = tf.matmul(q, tf.transpose(k, perm=[0, 1, 3, 2])) * self.scale 77 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 78 | attn = self.attend(dots) 79 | 80 | # x = tf.matmul(attn, v) 81 | x = einsum('b h i j, b h j d -> b h i d', attn, v) 82 | x = rearrange(x, 'b h n d -> b n (h d)') 83 | x = self.to_out(x, training=training) 84 | 85 | return x 86 | 87 | class Transformer(Layer): 88 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0): 89 | super(Transformer, self).__init__() 90 | 91 | self.layers = [] 92 | 93 | for _ in range(depth): 94 | self.layers.append([ 95 | PreNorm(Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)), 96 | PreNorm(MLP(dim, mlp_dim, dropout=dropout)) 97 | ]) 98 | 99 | def call(self, x, training=True): 100 | for attn, mlp in self.layers: 101 | x = attn(x, training=training) + x 102 | x = mlp(x, training=training) + x 103 | 104 | return x 105 | 106 | class ViT(Model): 107 | def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, 108 | pool='cls', dim_head=64, dropout=0.0, emb_dropout=0.0): 109 | """ 110 | image_size: int. 111 | -> Image size. If you have rectangular images, make sure your image size is the maximum of the width and height 112 | patch_size: int. 113 | -> Size of patches. image_size must be divisible by patch_size. 114 | -> The number of patches is: n = (image_size // patch_size) ** 2 and n must be greater than 16. 115 | num_classes: int. 116 | -> Number of classes to classify. 117 | dim: int. 118 | -> Last dimension of output tensor after linear transformation nn.Dense(..., dim). 119 | depth: int. 120 | -> Number of Transformer blocks. 121 | heads: int. 122 | -> Number of heads in Multi-head Attention layer. 123 | mlp_dim: int. 124 | -> Dimension of the MLP (FeedForward) layer. 125 | dropout: float between [0, 1], default 0. 126 | -> Dropout rate. 127 | emb_dropout: float between [0, 1], default 0. 128 | -> Embedding dropout rate. 129 | pool: string, either 'cls' (cls token pooling) or 'mean' (mean pooling) 130 | """ 131 | super(ViT, self).__init__() 132 | 133 | image_height, image_width = pair(image_size) 134 | patch_height, patch_width = pair(patch_size) 135 | 136 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
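        # Worked example (illustrative numbers taken from the Usage block at the end of this file):
        # image_size = 256 and patch_size = 32 give num_patches = (256 // 32) * (256 // 32) = 64
        # patch tokens; prepending the cls token makes the sequence length 65, which is why
        # pos_embedding below is allocated with num_patches + 1 rows of size dim.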
137 | 138 | num_patches = (image_height // patch_height) * (image_width // patch_width) 139 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 140 | 141 | self.patch_embedding = Sequential([ 142 | Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width), 143 | nn.Dense(units=dim) 144 | ], name='patch_embedding') 145 | 146 | self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim])) 147 | self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim])) 148 | self.dropout = nn.Dropout(rate=emb_dropout) 149 | 150 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 151 | 152 | self.pool = pool 153 | 154 | self.mlp_head = Sequential([ 155 | nn.LayerNormalization(), 156 | nn.Dense(units=num_classes) 157 | ], name='mlp_head') 158 | 159 | def call(self, img, training=True, **kwargs): 160 | x = self.patch_embedding(img) 161 | b, n, d = x.shape 162 | 163 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b) 164 | x = tf.concat([cls_tokens, x], axis=1) 165 | x += self.pos_embedding[:, :(n + 1)] 166 | x = self.dropout(x, training=training) 167 | 168 | x = self.transformer(x, training=training) 169 | 170 | if self.pool == 'mean': 171 | x = tf.reduce_mean(x, axis=1) 172 | else: 173 | x = x[:, 0] 174 | 175 | x = self.mlp_head(x) 176 | 177 | return x 178 | 179 | """ Usage 180 | 181 | v = ViT( 182 | image_size = 256, 183 | patch_size = 32, 184 | num_classes = 1000, 185 | dim = 1024, 186 | depth = 6, 187 | heads = 16, 188 | mlp_dim = 2048, 189 | dropout = 0.1, 190 | emb_dropout = 0.1 191 | ) 192 | 193 | img = tf.random.normal(shape=[1, 256, 256, 3]) 194 | preds = v(img) # (1, 1000) 195 | 196 | """ -------------------------------------------------------------------------------- /vit_tensorflow/vit_for_small_dataset.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | 7 | from einops import rearrange, repeat 8 | from einops.layers.tensorflow import Rearrange 9 | 10 | import numpy as np 11 | 12 | def pair(t): 13 | return t if isinstance(t, tuple) else (t, t) 14 | 15 | def shift(x): 16 | b, h, w, c = x.shape 17 | shifted_x = [] 18 | 19 | shifts = [1, -1] # [shift, axis] 20 | 21 | # width 22 | z = tf.zeros([b, h, 1, c], dtype=tf.float32) 23 | for idx, shift in enumerate(shifts): 24 | if idx == 0: 25 | s = tf.roll(x, shift, axis=2)[:, :, shift:, :] 26 | concat = tf.concat([z, s], axis=2) 27 | 28 | 29 | else: 30 | s = tf.roll(x, shift, axis=2)[:, :, :shift, :] 31 | concat = tf.concat([s, z], axis=2) 32 | 33 | shifted_x.append(concat) 34 | 35 | # height 36 | z = tf.zeros([b, 1, w, c], dtype=tf.float32) 37 | for idx, shift in enumerate(shifts): 38 | if idx == 0: 39 | s = tf.roll(x, shift, axis=1)[:, shift:, :, :] 40 | concat = tf.concat([z, s], axis=1) 41 | else: 42 | s = tf.roll(x, shift, axis=1)[:, :shift, :, :] 43 | concat = tf.concat([s, z], axis=1) 44 | 45 | shifted_x.append(concat) 46 | 47 | return shifted_x 48 | 49 | def gelu(x, approximate=False): 50 | if approximate: 51 | coeff = tf.cast(0.044715, x.dtype) 52 | return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3)))) 53 | else: 54 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype))) 55 | 56 | class GELU(Layer): 57 | def 
__init__(self, approximate=False): 58 | super(GELU, self).__init__() 59 | self.approximate = approximate 60 | 61 | def call(self, x, training=True): 62 | return gelu(x, self.approximate) 63 | 64 | class PreNorm(Layer): 65 | def __init__(self, fn): 66 | super(PreNorm, self).__init__() 67 | 68 | self.norm = nn.LayerNormalization() 69 | self.fn = fn 70 | 71 | def call(self, x, **kwargs): 72 | return self.fn(self.norm(x), **kwargs) 73 | 74 | class MLP(Layer): 75 | def __init__(self, dim, hidden_dim, dropout=0.0): 76 | super(MLP, self).__init__() 77 | self.net = Sequential([ 78 | nn.Dense(units=hidden_dim), 79 | GELU(), 80 | nn.Dropout(rate=dropout), 81 | nn.Dense(units=dim), 82 | nn.Dropout(rate=dropout) 83 | ]) 84 | 85 | def call(self, x, training=True): 86 | return self.net(x, training=training) 87 | 88 | class LSA(Layer): 89 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): 90 | super(LSA, self).__init__() 91 | 92 | inner_dim = dim_head * heads 93 | self.heads = heads 94 | self.temperature = tf.Variable(tf.math.log(dim_head ** -0.5)) 95 | 96 | self.attend = nn.Softmax() 97 | self.to_qkv = nn.Dense(units=inner_dim * 3, use_bias=False) 98 | 99 | self.to_out = Sequential([ 100 | nn.Dense(units=dim), 101 | nn.Dropout(rate=dropout) 102 | ]) 103 | 104 | def call(self, x, training=True): 105 | qkv = self.to_qkv(x) 106 | qkv = tf.split(qkv, num_or_size_splits=3, axis=-1) 107 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) 108 | 109 | dots = tf.matmul(q, tf.transpose(k, perm=[0, 1, 3, 2])) * tf.math.exp(self.temperature) 110 | 111 | mask = tf.eye(dots.shape[-1], dtype=tf.bool) 112 | mask_value = -np.finfo(dots.dtype.as_numpy_dtype).max 113 | dots = tf.where(mask, mask_value, dots) 114 | 115 | attn = self.attend(dots) 116 | 117 | out = tf.matmul(attn, v) 118 | out = rearrange(out, 'b h n d -> b n (h d)') 119 | out = self.to_out(out, training=training) 120 | 121 | return out 122 | 123 | class Transformer(Layer): 124 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0): 125 | super(Transformer, self).__init__() 126 | 127 | self.layers = [] 128 | 129 | for _ in range(depth): 130 | self.layers.append([ 131 | PreNorm(LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 132 | PreNorm(MLP(dim, mlp_dim, dropout=dropout)) 133 | ]) 134 | 135 | def call(self, x, training=True): 136 | for attn, ff in self.layers: 137 | x = attn(x, training=training) + x 138 | x = ff(x, training=training) + x 139 | 140 | return x 141 | 142 | class SPT(Layer): 143 | def __init__(self, dim, patch_size): 144 | super(SPT, self).__init__() 145 | 146 | self.to_patch_tokens = Sequential([ 147 | Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size), 148 | nn.LayerNormalization(), 149 | nn.Dense(units=dim) 150 | ]) 151 | 152 | def call(self, x, training=True): 153 | shifted_x = shift(x) 154 | x_with_shifts = tf.concat([x, *shifted_x], axis=-1) 155 | x = self.to_patch_tokens(x_with_shifts) 156 | 157 | return x 158 | 159 | class ViT(Model): 160 | def __init__(self, 161 | image_size, 162 | patch_size, 163 | num_classes, 164 | dim, 165 | depth, 166 | heads, 167 | mlp_dim, 168 | pool='cls', 169 | dim_head=64, 170 | dropout=0.0, 171 | emb_dropout=0.0 172 | ): 173 | super(ViT, self).__init__() 174 | image_height, image_width = pair(image_size) 175 | patch_height, patch_width = pair(patch_size) 176 | 177 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
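        # How the token width comes about (derived from the SPT layer above, illustrative numbers
        # from the Usage block below): shift() returns four zero-padded copies of the input
        # (shifted along width and height), and SPT concatenates them with the original image
        # along the channel axis, so each patch token carries patch_size * patch_size * 5 * c
        # features before the Dense projection to dim; with patch_size = 16 and c = 3 that is
        # 16 * 16 * 15 = 3840 input features per token.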
178 | 179 | num_patches = (image_height // patch_height) * (image_width // patch_width) 180 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 181 | 182 | self.patch_embedding = SPT(dim=dim, patch_size=patch_size) 183 | 184 | self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim])) 185 | self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim])) 186 | self.dropout = nn.Dropout(rate=emb_dropout) 187 | 188 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 189 | 190 | self.pool = pool 191 | 192 | self.mlp_head = Sequential([ 193 | nn.LayerNormalization(), 194 | nn.Dense(units=num_classes) 195 | ]) 196 | 197 | def call(self, img, training=True, **kwargs): 198 | x = self.patch_embedding(img) 199 | b, n, d = x.shape 200 | 201 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b) 202 | x = tf.concat([cls_tokens, x], axis=1) 203 | x += self.pos_embedding[:, :(n + 1)] 204 | x = self.dropout(x, training=training) 205 | 206 | x = self.transformer(x, training=training) 207 | 208 | if self.pool == 'mean': 209 | x = tf.reduce_mean(x, axis=1) 210 | else: 211 | x = x[:, 0] 212 | 213 | x = self.mlp_head(x) 214 | 215 | return x 216 | 217 | """ Usage 218 | v = ViT( 219 | image_size = 256, 220 | patch_size = 16, 221 | num_classes = 1000, 222 | dim = 1024, 223 | depth = 6, 224 | heads = 16, 225 | mlp_dim = 2048, 226 | dropout = 0.1, 227 | emb_dropout = 0.1 228 | ) 229 | 230 | img = tf.random.normal(shape=[4, 256, 256, 3]) 231 | preds = v(img) # (4, 1000) 232 | 233 | spt = SPT( 234 | dim = 1024, 235 | patch_size = 16 236 | ) 237 | 238 | tokens = spt(img) # (4, 256, 1024) 239 | """ 240 | -------------------------------------------------------------------------------- /vit_tensorflow/vit_with_patch_merger.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras import Sequential 5 | import tensorflow.keras.layers as nn 6 | 7 | from einops import rearrange 8 | from einops.layers.tensorflow import Rearrange, Reduce 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | def default(val ,d): 14 | return val if exists(val) else d 15 | 16 | def pair(t): 17 | return t if isinstance(t, tuple) else (t, t) 18 | 19 | def gelu(x, approximate=False): 20 | if approximate: 21 | coeff = tf.cast(0.044715, x.dtype) 22 | return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3)))) 23 | else: 24 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype))) 25 | 26 | class GELU(Layer): 27 | def __init__(self, approximate=False): 28 | super(GELU, self).__init__() 29 | self.approximate = approximate 30 | 31 | def call(self, x, training=True): 32 | return gelu(x, self.approximate) 33 | 34 | 35 | class Identity(Layer): 36 | def __init__(self): 37 | super(Identity, self).__init__() 38 | 39 | def call(self, x, training=True): 40 | return tf.identity(x) 41 | 42 | class PatchMerger(Layer): 43 | def __init__(self, dim, num_tokens_out): 44 | super(PatchMerger, self).__init__() 45 | self.scale = dim ** -0.5 46 | self.norm = nn.LayerNormalization() 47 | self.queries = tf.Variable(tf.random.normal([num_tokens_out, dim])) 48 | 49 | def call(self, x, training=True): 50 | x = self.norm(x) 51 | sim = tf.matmul(self.queries, tf.transpose(x, perm=[0, 2, 1]) * self.scale) 52 | attn = tf.nn.softmax(sim, 
axis=-1) 53 | x = tf.matmul(attn, x) 54 | 55 | return x 56 | 57 | class PreNorm(Layer): 58 | def __init__(self, fn): 59 | super(PreNorm, self).__init__() 60 | 61 | self.norm = nn.LayerNormalization() 62 | self.fn = fn 63 | 64 | def call(self, x, **kwargs): 65 | return self.fn(self.norm(x), **kwargs) 66 | 67 | class MLP(Layer): 68 | def __init__(self, dim, hidden_dim, dropout=0.0): 69 | super(MLP, self).__init__() 70 | self.net = Sequential([ 71 | nn.Dense(units=hidden_dim), 72 | GELU(), 73 | nn.Dropout(rate=dropout), 74 | nn.Dense(units=dim), 75 | nn.Dropout(rate=dropout) 76 | ]) 77 | 78 | def call(self, x, training=True): 79 | return self.net(x, training=training) 80 | 81 | class Attention(Layer): 82 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): 83 | super(Attention, self).__init__() 84 | inner_dim = dim_head * heads 85 | project_out = not (heads == 1 and dim_head == dim) 86 | 87 | self.heads = heads 88 | self.scale = dim_head ** -0.5 89 | 90 | self.attend = nn.Softmax() 91 | self.to_qkv = nn.Dense(units=inner_dim * 3, use_bias=False) 92 | 93 | self.to_out = Sequential([ 94 | nn.Dense(units=dim), 95 | nn.Dropout(rate=dropout) 96 | ]) if project_out else Identity() 97 | 98 | def call(self, x, training=True): 99 | qkv = self.to_qkv(x) 100 | qkv = tf.split(qkv, num_or_size_splits=3, axis=-1) 101 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) 102 | 103 | dots = tf.matmul(q, tf.transpose(k, perm=[0, 1, 3, 2])) * self.scale 104 | attn = self.attend(dots) 105 | 106 | x = tf.matmul(attn, v) 107 | x = rearrange(x, 'b h n d -> b n (h d)') 108 | x = self.to_out(x, training=training) 109 | 110 | return x 111 | 112 | class Transformer(Layer): 113 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0, patch_merge_layer=None, patch_merge_num_tokens=8): 114 | super(Transformer, self).__init__() 115 | 116 | self.layers = [] 117 | self.patch_merge_layer_index = default(patch_merge_layer, depth // 2) - 1 # default to mid-way through transformer, as shown in paper 118 | self.patch_merger = PatchMerger(dim=dim, num_tokens_out=patch_merge_num_tokens) 119 | 120 | for _ in range(depth): 121 | self.layers.append([ 122 | PreNorm(Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)), 123 | PreNorm(MLP(dim, mlp_dim, dropout=dropout)) 124 | ]) 125 | 126 | def call(self, x, training=True): 127 | for index, (attn, ff) in enumerate(self.layers): 128 | x = attn(x, training=training) + x 129 | x = ff(x, training=training) + x 130 | 131 | if index == self.patch_merge_layer_index: 132 | x = self.patch_merger(x) 133 | 134 | return x 135 | 136 | class ViT(Model): 137 | def __init__(self, image_size, 138 | patch_size, 139 | num_classes, 140 | dim, 141 | depth, 142 | heads, 143 | mlp_dim, 144 | patch_merge_layer=None, 145 | patch_merge_num_tokens=8, 146 | dim_head=64, 147 | dropout=0.0, 148 | emb_dropout=0.0 149 | ): 150 | super(ViT, self).__init__() 151 | 152 | image_height, image_width = pair(image_size) 153 | patch_height, patch_width = pair(patch_size) 154 | 155 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
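        # Sequence-length bookkeeping (derived from the Transformer and PatchMerger above): the
        # merger runs after layer default(patch_merge_layer, depth // 2), where patch_merge_num_tokens
        # learned queries (8 by default) attend over the current tokens, so every later layer
        # operates on only patch_merge_num_tokens tokens. Unlike vit.py, no cls token is prepended
        # here; pooling is the Reduce('b n d -> b d', 'mean') in mlp_head, so only the first n
        # rows of pos_embedding are actually used.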
156 | 157 | num_patches = (image_height // patch_height) * (image_width // patch_width) 158 | self.patch_embedding = Sequential([ 159 | Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width), 160 | nn.Dense(units=dim) 161 | ]) 162 | 163 | self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim])) 164 | self.dropout = nn.Dropout(rate=emb_dropout) 165 | 166 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, patch_merge_layer, patch_merge_num_tokens) 167 | 168 | self.mlp_head = Sequential([ 169 | Reduce('b n d -> b d', 'mean'), 170 | nn.LayerNormalization(), 171 | nn.Dense(units=num_classes) 172 | ]) 173 | 174 | def call(self, img, training=True, **kwargs): 175 | x = self.patch_embedding(img) 176 | b, n, _ = x.shape 177 | 178 | x += self.pos_embedding[:, :n] 179 | x = self.dropout(x, training=training) 180 | 181 | x = self.transformer(x, training=training) 182 | x = self.mlp_head(x) 183 | 184 | return x 185 | 186 | """ Usage 187 | v = ViT( 188 | image_size = 256, 189 | patch_size = 16, 190 | num_classes = 1000, 191 | dim = 1024, 192 | depth = 12, 193 | heads = 8, 194 | patch_merge_layer = 6, # at which transformer layer to do patch merging 195 | patch_merge_num_tokens = 8, # the output number of tokens from the patch merge 196 | mlp_dim = 2048, 197 | dropout = 0.1, 198 | emb_dropout = 0.1 199 | ) 200 | 201 | img = tf.random.normal(shape=[4, 256, 256, 3]) 202 | preds = v(img) # (4, 1000) 203 | 204 | merger = PatchMerger( 205 | dim = 1024, 206 | num_tokens_out = 8 # output number of tokens 207 | ) 208 | 209 | features = tf.random.normal(shape=[4, 256, 1024]) 210 | x = merger(features) 211 | """ --------------------------------------------------------------------------------
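The per-file Usage blocks above only exercise a single forward pass. As a minimal training sketch (assuming the vit_tensorflow package layout shown above is importable as-is, and using random tensors as a hypothetical stand-in for a real dataset), any of these models can be optimized with a standard GradientTape loop; note that every classification head ends in a plain Dense layer, so losses should be built with from_logits=True:

import tensorflow as tf
from vit_tensorflow.vit import ViT   # import path assumed from the repository layout above

# hypothetical stand-in batch -- replace with a real input pipeline
images = tf.random.normal([4, 256, 256, 3])
labels = tf.random.uniform([4], maxval=1000, dtype=tf.int32)

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

optimizer = tf.keras.optimizers.Adam(1e-4)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)  # mlp_head emits raw logits

with tf.GradientTape() as tape:
    logits = v(images, training=True)   # (4, 1000)
    loss = loss_fn(labels, logits)

grads = tape.gradient(loss, v.trainable_variables)
optimizer.apply_gradients(zip(grads, v.trainable_variables))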