├── README.md
└── dino.py

/README.md:
--------------------------------------------------------------------------------
# Clean, self-contained implementation of DINOv2

...That's literally it. I've stripped the original [DINOv2](https://github.com/facebookresearch/dinov2) repo down to its core: a single `dino.py`. I'm going to add examples and stuff soon.
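
Until then, here's a minimal usage sketch. It mirrors the self-test at the bottom of `dino.py`: pull the official weights from torch hub, load them into this re-implementation, and read out features (ViT-S/14 shown here).

```python
import torch

from dino import vit_small

# Official weights; dino.py pairs vit_small with "dinov2_vits14",
# vit_base with "dinov2_vitb14", vit_large with "dinov2_vitl14",
# and vit_giant2 with "dinov2_vitg14".
reference = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")

model = vit_small()
model.load_state_dict(reference.state_dict())
model.eval()

# Any (B, 3, H, W) tensor whose sides are divisible by the patch size (14)
# works; for real images, apply the usual ImageNet normalization first.
image = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = model.forward_features(image)

print(feats["x_norm_clstoken"].shape)     # (1, 384)      global image embedding
print(feats["x_norm_patchtokens"].shape)  # (1, 256, 384) one token per 14x14 patch
```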

--------------------------------------------------------------------------------
/dino.py:
--------------------------------------------------------------------------------
from functools import partial
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def to_2tuple(x):
    return (x, x) if not isinstance(x, tuple) else x


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        bias=True,
    ):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


class SwiGLUFFN(nn.Module):
    def __init__(
        self, in_features, hidden_features=None, out_features=None, bias=True, **kwargs
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x):
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # (B, C, H, W) -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)
        x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)

        return x


class LayerScale(nn.Module):
    def __init__(
        self,
        dim,
        init_values=1e-5,
        inplace=False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        # init_values=None keeps the identity initialization; for pretrained
        # checkpoints gamma is overwritten by load_state_dict anyway.
        init_values = 1.0 if init_values is None else init_values
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        ffn_bias=True,
        attn_drop=0.0,
        init_values=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        attn_class=Attention,
        ffn_layer=Mlp,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
        )
        self.ls1 = LayerScale(dim, init_values=init_values)

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values)

    def forward(self, x):
        x = x + self.ls1(self.attn(self.norm1(x)))
        x = x + self.ls2(self.mlp(self.norm2(x)))
        return x


class DinoVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        init_values=None,
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer=Mlp,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.num_features = self.embed_dim = embed_dim
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # 1370 = 1 CLS token + 37 * 37 patches (the 518x518 / patch-14 grid used
        # for pretraining); other resolutions go through interpolate_pos_encoding.
        self.pos_embed = nn.Parameter(torch.zeros(1, 1370, embed_dim))
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
            if num_register_tokens
            else None
        )

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                ffn_bias=ffn_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for _ in range(depth)
        ]

        self.blocks = nn.ModuleList(blocks_list)
        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

    def interpolate_pos_encoding(self, x, w, h):
        # Resize the pretraining position grid (37x37) to the current patch grid.
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0, h0 = w // self.patch_size, h // self.patch_size
        M = int(math.sqrt(N))
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            sx, sy = (
                float(w0 + self.interpolate_offset) / M,
                float(h0 + self.interpolate_offset) / M,
            )
            kwargs["scale_factor"] = (sx, sy)
        else:
            kwargs["size"] = (w0, h0)
        patch_pos_embed = (
            nn.functional.interpolate(
                patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
                mode="bicubic",
                align_corners=False,
                antialias=self.interpolate_antialias,
                **kwargs,
            )
            .permute(0, 2, 3, 1)
            .view(1, -1, dim)
        )
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
            x.dtype
        )

    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            # Replace masked patch embeddings with the learned mask token.
            x = torch.where(
                masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
            )
        x = torch.cat((self.cls_token.expand(x.size(0), -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)
        if self.register_tokens is not None:
            x = torch.cat(
                (x[:, :1], self.register_tokens.expand(x.size(0), -1, -1), x[:, 1:]),
                dim=1,
            )
        return x

    def forward_features(self, x, masks=None):
        x = self.prepare_tokens_with_masks(x, masks)
        for blk in self.blocks:
            x = blk(x)
        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def forward(self, x, is_training=False, masks=None):
        ret = self.forward_features(x, masks)
        return ret if is_training else self.head(ret["x_norm_clstoken"])


def vit_small(patch_size=14, num_register_tokens=0, **kwargs):
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )


def vit_base(patch_size=14, num_register_tokens=0, **kwargs):
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )


def vit_large(patch_size=14, num_register_tokens=0, **kwargs):
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )


def vit_giant2(patch_size=14, num_register_tokens=0, **kwargs):
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=8 / 3,
        num_register_tokens=num_register_tokens,
        ffn_layer=SwiGLUFFN,
        **kwargs,
    )


if __name__ == "__main__":

    # Check cosine similarity and L1 distance between the outputs of the
    # official DINOv2 models and this reimplementation.

    model_list = [
        ("dinov2_vits14", vit_small),
        ("dinov2_vitb14", vit_base),
        ("dinov2_vitl14", vit_large),
        ("dinov2_vitg14", vit_giant2),
    ]
    # Only the giant variant is checked by default; drop the [3:] slice to
    # test all four pairings.
    for model_name, model_fn in model_list[3:]:
        dino = torch.hub.load("facebookresearch/dinov2", model_name).cuda()
        model = model_fn().cuda()
        model.load_state_dict(dino.state_dict())

        for h, w in [(224, 224), (140, 140), (448, 448)]:
            image = torch.randn(1, 3, h, w).cuda()
            output_dino = dino(image)
            out_ours = model(image)

            cos_sim = F.cosine_similarity(output_dino, out_ours).item()
            l1_loss = F.l1_loss(output_dino, out_ours).item()
            print(f"Similarity between output_dino and out_ours: {cos_sim}")
            print(f"L1 distance between output_dino and out_ours: {l1_loss}")

            assert cos_sim > 0.99
            assert l1_loss < 0.01
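
    # Extra sanity check (not part of the original script): the positional
    # embedding is resized in interpolate_pos_encoding, so non-square inputs
    # work too, as long as both sides are divisible by the patch size.
    # For a 224x308 image with patch size 14 the patch grid is 16x22.
    feats = model.forward_features(torch.randn(1, 3, 224, 308).cuda())
    assert feats["x_norm_patchtokens"].shape[1] == (224 // 14) * (308 // 14)
    print("patch tokens for 224x308 input:", feats["x_norm_patchtokens"].shape)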
{cos_sim}") 356 | print(f"L1 distance between output_dino and out_ours: {l1_loss}") 357 | 358 | assert cos_sim > 0.99 359 | assert l1_loss < 0.01 360 | --------------------------------------------------------------------------------