├── .gitignore
├── README.md
├── mlp.py
├── LICENSE.md
├── patch_embed.py
├── attention.py
├── block.py
├── vit.py
└── dino_weights.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.pytest_cache
__pycache__
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DINOv2 JAX

This repository contains a port of FAIR's [DINOv2](https://dinov2.metademolab.com/) to JAX, intended for running inference against the pretrained DINO weights.

Use `dino_weights.py` to load the pretrained weights into a ViT-S JAX model (with the same modifications made in the DINOv2 paper).

> **Warning**: There are currently some minor discrepancies between the output of the JAX model and the original model. The results are mostly identical, and the difference is likely due to numerical differences between the JAX and PyTorch implementations, but there are no guarantees of correctness.
--------------------------------------------------------------------------------
/mlp.py:
--------------------------------------------------------------------------------
import jax
import jax.numpy as jnp
from flax import linen as nn

from typing import Callable


class Mlp(nn.Module):
    """Two-layer feed-forward block used inside each transformer block."""

    hidden_features: int = 1536
    out_features: int = 384
    act_layer: Callable = nn.gelu
    dropout_rate: float = 0.0
    bias: bool = True

    @nn.compact
    def __call__(self, x, training: bool = False):
        x = nn.Dense(features=self.hidden_features, use_bias=self.bias, name="fc1")(x)
        x = self.act_layer(x)
        x = nn.Dropout(rate=self.dropout_rate, name="drop1")(
            x, deterministic=not training
        )
        x = nn.Dense(features=self.out_features, use_bias=self.bias, name="fc2")(x)
        x = nn.Dropout(rate=self.dropout_rate, name="drop2")(
            x, deterministic=not training
        )
        return x
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Kyle Stachowicz

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
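
Before the remaining modules, a minimal usage sketch of the `Mlp` block from `mlp.py` above. This example is not part of the repository; the shapes are illustrative and match the ViT-S defaults (embed_dim 384, hidden 1536).

import jax
import jax.numpy as jnp

from mlp import Mlp

mlp_def = Mlp()  # defaults: 384 -> 1536 -> 384
x = jnp.ones((1, 16, 384))  # (batch, tokens, embed_dim)
params = mlp_def.init(jax.random.PRNGKey(0), x)["params"]
y = mlp_def.apply({"params": params}, x)  # shape (1, 16, 384)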
--------------------------------------------------------------------------------
/patch_embed.py:
--------------------------------------------------------------------------------
import jax
import jax.numpy as jnp
from flax import linen as nn

from typing import Optional, Type


class PatchEmbed(nn.Module):
    """Splits an image into non-overlapping patches and linearly embeds them."""

    img_size: int = 224
    patch_size: int = 14
    in_channels: int = 3
    embed_dim: int = 384
    norm_layer: Optional[Type[nn.Module]] = None
    flatten_embedding: bool = True

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        _, H, W, C = x.shape
        patch_H, patch_W = self.patch_size, self.patch_size
        assert (
            H % patch_H == 0 and W % patch_W == 0
        ), f"Image size ({H}*{W}) cannot be evenly divided by patch size ({patch_H}*{patch_W})."

        # Patchify via a strided convolution: (B, H, W, C) -> (B, H/p, W/p, embed_dim)
        x = nn.Conv(
            features=self.embed_dim,
            kernel_size=(patch_H, patch_W),
            strides=(patch_H, patch_W),
            name="proj",
            padding="VALID",
        )(x)

        _, H, W, _ = x.shape
        x = jnp.reshape(x, (x.shape[0], -1, x.shape[-1]))

        if self.norm_layer is not None:
            x = self.norm_layer(name="norm")(x)

        if not self.flatten_embedding:
            x = jnp.reshape(x, (-1, H, W, self.embed_dim))

        return x
--------------------------------------------------------------------------------
/attention.py:
--------------------------------------------------------------------------------
import jax
import jax.numpy as jnp
from flax import linen as nn


class Attention(nn.Module):
    num_heads: int = 8
    attn_bias: bool = True
    attn_drop_rate: float = 0.0
    proj_bias: bool = True
    proj_drop_rate: float = 0.0
    embed_dim: int = 384

    @nn.compact
    def __call__(self, x, training: bool = False):
        B, N, C = x.shape
        assert (
            C == self.embed_dim
        ), f"Input embedding dimension ({C}) should match layer embedding dimension ({self.embed_dim})."
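
        # Shape walkthrough (comments added for clarity; H = num_heads, D = C // H):
        #   qkv dense:  (B, N, 3*C)
        #   reshape:    (B, N, 3, H, D)
        #   transpose:  (3, B, H, N, D), then unpacked into q, k, v below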
        qkv = nn.Dense(features=3 * C, use_bias=self.attn_bias, name="qkv")(x)
        qkv = jnp.reshape(qkv, (B, N, 3, self.num_heads, C // self.num_heads))
        qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4))

        q, k, v = tuple(qkv)

        # Attention matrix: (B, H, N, N)
        attn = q @ k.transpose((0, 1, 3, 2)) / jnp.sqrt(C // self.num_heads)
        attn = nn.softmax(attn, axis=-1)
        attn = nn.Dropout(rate=self.attn_drop_rate, name="attn_drop")(
            attn, deterministic=not training
        )

        # Output: (B, N, H, C // H)
        x = (attn @ v).transpose(0, 2, 1, 3).reshape(B, N, C)

        x = nn.Dense(features=C, use_bias=self.proj_bias, name="proj")(x)
        x = nn.Dropout(rate=self.proj_drop_rate, name="proj_drop")(
            x, deterministic=not training
        )

        return x
--------------------------------------------------------------------------------
/block.py:
--------------------------------------------------------------------------------
import jax
import jax.numpy as jnp
from flax import linen as nn

from typing import Type
from attention import Attention
from mlp import Mlp


class LayerScale(nn.Module):
    initial_value: float = 1.0

    @nn.compact
    def __call__(self, x):
        # Learnable per-channel scale applied to each residual branch.
        gamma = self.param(
            "gamma",
            lambda _, shape: self.initial_value * jnp.ones(shape),
            (x.shape[-1],),
        )
        return x * gamma


class DropPath(nn.Module):
    rate: float = 0.0

    @nn.compact
    def __call__(self, x, deterministic: bool = False):
        if self.rate > 0.0 and not deterministic:
            # Stochastic depth: drop the whole residual branch per sample.
            keep_prob = 1.0 - self.rate
            # Broadcastable mask over everything but the batch dimension.
            shape = (x.shape[0],) + (1,) * (x.ndim - 1)
            random_tensor = jax.random.bernoulli(
                self.make_rng("dropout"), keep_prob, shape=shape
            )
            return x / keep_prob * random_tensor
        else:
            return x


class Block(nn.Module):
    num_heads: int = 6
    embed_dim: int = 384
    mlp_ratio: float = 4.0
    drop_path_rate: float = 0.0

    AttentionClass: Type[nn.Module] = Attention
    FfnClass: Type[nn.Module] = Mlp

    @nn.compact
    def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray:
        def attn_residual_func(x: jnp.ndarray) -> jnp.ndarray:
            x = nn.LayerNorm(name="norm1")(x)
            x = self.AttentionClass(
                num_heads=self.num_heads, embed_dim=self.embed_dim, name="attn"
            )(x, training=training)
            x = LayerScale(name="ls1")(x)
            return x

        def ffn_residual_func(x: jnp.ndarray) -> jnp.ndarray:
            x = nn.LayerNorm(name="norm2")(x)
            x = self.FfnClass(
                hidden_features=int(self.mlp_ratio * self.embed_dim),
                out_features=self.embed_dim,
                name="mlp",
            )(x, training=training)
            x = LayerScale(name="ls2")(x)
            return x

        if training:
            # `deterministic` is a call argument of DropPath, not a module attribute.
            x = x + DropPath(rate=self.drop_path_rate, name="drop_path1")(
                attn_residual_func(x), deterministic=not training
            )
            x = x + DropPath(rate=self.drop_path_rate, name="drop_path2")(
                ffn_residual_func(x), deterministic=not training
            )
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)

        return x


if __name__ == "__main__":
    import functools

    attn_cls = functools.partial(
        Attention,
        num_heads=6,
        attn_bias=True,
        attn_drop_rate=0.0,
        proj_bias=True,
        proj_drop_rate=0.0,
    )
    block_cls = functools.partial(
        Block,
        AttentionClass=attn_cls,
        drop_path_rate=0.0,
    )
    block_def = block_cls()
    block_params = block_def.init(jax.random.PRNGKey(0), jnp.ones((32, 16, 384)))[
        "params"
    ]

    def print_param(path, param):
        print(".".join([p.key for p in path]), param.shape)

    params = jax.tree_util.tree_flatten_with_path(block_params)[0]
    for path, param in params:
        print_param(path, param)
--------------------------------------------------------------------------------
/vit.py:
--------------------------------------------------------------------------------
import jax
import jax.numpy as jnp
from flax import linen as nn
import numpy as onp

from typing import Type
from functools import partial

from mlp import Mlp
from attention import Attention
from block import Block
from patch_embed import PatchEmbed


class DinoViT(nn.Module):
    img_size: int = 224
    in_channels: int = 3

    patch_size: int = 14
    embed_dim: int = 384

    depth: int = 12

    num_heads: int = 6
    mlp_ratio: float = 4.0
    drop_path_rate: float = 0.0

    BlockClass: Type[nn.Module] = Block
    AttentionClass: Type[nn.Module] = Attention
    FfnClass: Type[nn.Module] = Mlp
    EmbedLayer: Type[nn.Module] = PatchEmbed

    def _interpolate_pos_encoding(
        self, x: jnp.ndarray, w: int, h: int, pos_embed: jnp.ndarray
    ):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return pos_embed

        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        # jax.image.resize expects an integer output shape.
        w0 = w // self.patch_size
        h0 = h // self.patch_size

        patch_pos_embed = jax.image.resize(
            patch_pos_embed.reshape(1, int(N**0.5), int(N**0.5), dim),
            (1, w0, h0, dim),
            method="bicubic",
        )
        patch_pos_embed = jnp.reshape(patch_pos_embed, (1, -1, dim))

        return jnp.concatenate((class_pos_embed[None], patch_pos_embed), axis=1).astype(
            previous_dtype
        )

    @nn.compact
    def __call__(self, x, training: bool = False):
        B, H, W, C = x.shape
        assert H == W == self.img_size, "x size must be (B, {}, {}, {})".format(
            self.img_size, self.img_size, C
        )

        x = self.EmbedLayer(
            patch_size=self.patch_size,
            in_channels=self.in_channels,
            embed_dim=self.embed_dim,
            name="patch_embed",
        )(x)
        cls_token = self.param(
            "cls_token", nn.initializers.zeros, (1, 1, self.embed_dim)
        )
        cls_token = jnp.broadcast_to(cls_token, (x.shape[0], *cls_token.shape[1:]))
        x = jnp.concatenate((cls_token, x), axis=1)

        num_patches = (self.img_size // self.patch_size) ** 2
        num_tokens = 1

        pos_embed = self.param(
            "pos_embed",
            nn.initializers.zeros,
            (1, num_patches + num_tokens, self.embed_dim),
        )
        x = x + self._interpolate_pos_encoding(
            x, self.img_size, self.img_size, pos_embed
        )

        for i in range(self.depth):
            x = self.BlockClass(
                num_heads=self.num_heads,
                embed_dim=self.embed_dim,
                mlp_ratio=self.mlp_ratio,
                drop_path_rate=self.drop_path_rate,
                AttentionClass=self.AttentionClass,
                FfnClass=self.FfnClass,
                name=f"blocks.{i}",
            )(x, training=training)

        x_norm = nn.LayerNorm(name="norm")(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_patchtokens": x_norm[:, 1:],
            "x_prenorm": x,
        }
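
A minimal forward-pass sketch for `DinoViT` with randomly initialized parameters (not part of the repository; the 518x518 resolution matches the ViT-S/14 checkpoint used in `dino_weights.py`):

import jax
import jax.numpy as jnp

from vit import DinoViT

vit_def = DinoViT(img_size=518)  # ViT-S/14 defaults: embed_dim=384, depth=12, num_heads=6
image = jnp.ones((1, 518, 518, 3))  # NHWC input
params = vit_def.init(jax.random.PRNGKey(0), image)["params"]
out = vit_def.apply({"params": params}, image, training=False)
print(out["x_norm_clstoken"].shape)     # (1, 384)
print(out["x_norm_patchtokens"].shape)  # (1, 37 * 37, 384) = (1, 1369, 384)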
--------------------------------------------------------------------------------
/dino_weights.py:
--------------------------------------------------------------------------------
import jax
import jax.numpy as jnp
import torch
import re
import functools

from vit import DinoViT


def load_vit_params(params_jax: dict, vit_pt: torch.nn.Module):
    jax_params_flat, jax_param_pytree = jax.tree_util.tree_flatten_with_path(params_jax)
    dinov2_params = {path: param for path, param in vit_pt.named_parameters()}

    # Parameters that keep their PyTorch layout (no transpose needed).
    no_transpose = {
        "cls_token",
        "pos_embed",
        "mask_token",
    }
    dinov2_params_flat = []
    for path, param in jax_params_flat:
        shape = param.shape
        path = ".".join([p.key for p in path])
        # Map Flax parameter names onto the PyTorch naming convention.
        path = re.sub(r"\.scale|\.kernel", ".weight", path)
        if path in dinov2_params:
            dinov2_param = dinov2_params[path]
            if path not in no_transpose:
                if len(shape) == 4:
                    # Conv kernels: (out, in, kH, kW) -> (kH, kW, in, out)
                    dinov2_param = torch.permute(dinov2_param, (2, 3, 1, 0))
                else:
                    dinov2_param = torch.permute(
                        dinov2_param, tuple(reversed(range(len(shape))))
                    )
            if shape != dinov2_param.shape:
                print(path, shape, dinov2_param.shape)
            dinov2_params_flat.append(jnp.asarray(dinov2_param.detach().numpy()))
            dinov2_params.pop(path)
        else:
            print(path, shape, None)
            dinov2_params_flat.append(None)
    # Any leftover PyTorch parameters were not matched to a JAX parameter.
    for path, param in dinov2_params.items():
        print(path, None, param.shape)

    return jax.tree_util.tree_unflatten(jax_param_pytree, dinov2_params_flat)


def load_dino_vits():
    num_heads = 6
    embed_dim = 384
    mlp_ratio = 4

    vit_cls = functools.partial(
        DinoViT,
        num_heads=num_heads,
        embed_dim=embed_dim,
        mlp_ratio=mlp_ratio,
        depth=12,
        img_size=518,
    )
    vit_def = vit_cls()
    vit_params = vit_def.init(jax.random.PRNGKey(0), jnp.ones((1, 518, 518, 3)))[
        "params"
    ]

    dinov2_vits14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")

    params = load_vit_params(vit_params, dinov2_vits14)

    return (vit_def, params)


def test_dino_vits():
    import numpy as onp

    # JAX: forward pass
    image = jax.random.uniform(jax.random.PRNGKey(0), (1, 518, 518, 3))
    jax_vit_def, jax_params = load_dino_vits()
    embed_jax = jax_vit_def.apply({"params": jax_params}, image, training=False)
    embed_jax = onp.asarray(embed_jax["x_norm_patchtokens"])

    # Torch: forward pass (requires a CUDA device)
    image_torch = torch.from_numpy(onp.asarray(image.transpose((0, 3, 1, 2)))).cuda()
    dinov2_vits14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14").cuda()
    dinov2_vits14.eval()
    embed_torch = (
        dinov2_vits14.forward_features(image_torch)["x_norm_patchtokens"]
        .detach()
        .cpu()
        .numpy()
    )
    embed_torch2 = (
        dinov2_vits14.forward_features(torch.rand((1, 3, 518, 518), device="cuda"))[
            "x_norm_patchtokens"
        ]
        .detach()
        .cpu()
        .numpy()
    )

    # Note: these are cosine similarities, not distances.
    cosine_similarity = (
        onp.sum(embed_torch * embed_jax)
        / onp.linalg.norm(embed_torch)
        / onp.linalg.norm(embed_jax)
    )
    cosine_similarity2 = (
        onp.sum(embed_torch2 * embed_jax)
        / onp.linalg.norm(embed_torch2)
        / onp.linalg.norm(embed_jax)
    )

    # Cosine similarity for the first pair (same image) should be close to 1.
    assert cosine_similarity > 0.999, cosine_similarity
    # Cosine similarity for the second pair (different images) should be further away.
    # It might still be close to 1, because random noise is semantically similar.
    assert cosine_similarity2 < 0.95, cosine_similarity2
--------------------------------------------------------------------------------
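
Finally, a usage sketch tying the pieces together: loading the pretrained ViT-S/14 weights via `load_dino_vits()` and extracting features for an image, as described in the README. This is a sketch rather than part of the repository; it assumes `torch.hub` can download the DINOv2 checkpoint, and image preprocessing/normalization is not shown.

import jax
import jax.numpy as jnp

from dino_weights import load_dino_vits

# Download the DINOv2 ViT-S/14 weights via torch.hub and copy them into the JAX model.
vit_def, params = load_dino_vits()

# DinoViT expects NHWC images at the resolution it was initialized with (518x518 here).
image = jnp.zeros((1, 518, 518, 3))  # placeholder; preprocessing not shown
features = vit_def.apply({"params": params}, image, training=False)
cls_token = features["x_norm_clstoken"]        # (1, 384)
patch_tokens = features["x_norm_patchtokens"]  # (1, 1369, 384)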