├── model ├── __init__.py ├── common │ ├── __init__.py │ ├── layer_scale.py │ ├── drop_path.py │ ├── mlp.py │ ├── swiglu_ffn.py │ ├── attention.py │ ├── patch_embed.py │ └── block.py ├── encoder │ ├── __init__.py │ └── vision_transformer.py ├── projector │ ├── __init__.py │ └── tp.py ├── model_utils.py ├── options.py ├── losses.py ├── teacher_norm.py ├── teacher_dropping.py └── dune.py ├── teachers ├── dinov2 │ ├── __init__.py │ ├── README.md │ ├── models │ │ ├── __init__.py │ │ └── vision_transformer.py │ └── layers │ │ ├── __init__.py │ │ ├── layer_scale.py │ │ ├── drop_path.py │ │ ├── mlp.py │ │ ├── swiglu_ffn.py │ │ ├── patch_embed.py │ │ ├── attention.py │ │ └── block.py ├── __init__.py ├── config.py ├── forward.py ├── vit_master.py └── builder.py ├── .gitignore ├── assets ├── dune.png ├── test_image.png └── test_image_patch_pca_dune_vitbase14_448_paper.png ├── scripts ├── teachers │ ├── README.md │ ├── prepare_all.sh │ ├── utils.sh │ ├── multihmr.sh │ ├── mast3r.sh │ └── dinov2.sh ├── train_dune.sh ├── pca_vis.py └── setup_env.sh ├── Makefile ├── utils ├── vis.py ├── optim.py ├── distributed.py ├── metrics.py └── exp.py ├── data ├── paths.py ├── dino2.py ├── utils.py ├── mast3r.py ├── transform.py ├── imagenet.py ├── multihmr.py ├── sampler.py ├── dataset.py └── __init__.py ├── hubconf.py ├── ACKNOWLEDGEMENTS - NLE DUNE.txt ├── Project NLE DUNE LICENSE.txt └── README.md /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /teachers/dinov2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ -------------------------------------------------------------------------------- /assets/dune.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/dune/main/assets/dune.png -------------------------------------------------------------------------------- /assets/test_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/dune/main/assets/test_image.png -------------------------------------------------------------------------------- /teachers/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import TEACHER_CFG 2 | from .builder import build_teachers 3 | from .forward import get_teacher_outputs 4 | -------------------------------------------------------------------------------- /assets/test_image_patch_pca_dune_vitbase14_448_paper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/dune/main/assets/test_image_patch_pca_dune_vitbase14_448_paper.png -------------------------------------------------------------------------------- /teachers/dinov2/README.md: 
-------------------------------------------------------------------------------- 1 | Modeling files for DINOv2-based models. 2 | Adapted from [the official repository](https://github.com/facebookresearch/dinov2). -------------------------------------------------------------------------------- /teachers/dinov2/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /model/projector/__init__.py: -------------------------------------------------------------------------------- 1 | from .tp import TransformerProjector 2 | 3 | 4 | def get_projector(input_dim: int, output_dim: int): 5 | assert input_dim > 0, input_dim 6 | assert output_dim > 0, output_dim 7 | return TransformerProjector(input_dim=input_dim, output_dim=output_dim) 8 | -------------------------------------------------------------------------------- /scripts/teachers/README.md: -------------------------------------------------------------------------------- 1 | This folder contains bash scripts to download pretrained teacher models. 2 | 3 | To download all teacher models, use the [_prepare_all.sh_](prepare_all.sh) script. 4 | Make sure to pass `MODELS_ROOT_DIR` as the first argument to set the directory where the models will be downloaded. -------------------------------------------------------------------------------- /scripts/teachers/prepare_all.sh: -------------------------------------------------------------------------------- 1 | if [ "$#" -ne 1 ]; then 2 | echo "Usage: bash prepare_all.sh MODELS_ROOT_DIR" 3 | exit 1 4 | fi 5 | 6 | MODELS_ROOT_DIR=${1} 7 | 8 | umask 000 9 | 10 | bash dinov2.sh ${MODELS_ROOT_DIR} 11 | bash mast3r.sh ${MODELS_ROOT_DIR} # gcc-toolset-9 needed 12 | bash multihmr.sh ${MODELS_ROOT_DIR} -------------------------------------------------------------------------------- /model/model_utils.py: -------------------------------------------------------------------------------- 1 | def extra_repr(module): 2 | """ 3 | Returns a string representation of the module's attributes. 4 | """ 5 | info = "" 6 | for name, value in module.__dict__.items(): 7 | if isinstance(value, (str, int, float)): 8 | info += f"{name}={value}, " 9 | 10 | if info: 11 | info = info[:-2] # Remove the last comma and space 12 | 13 | return info 14 | -------------------------------------------------------------------------------- /teachers/dinov2/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree.
5 | 6 | from .mlp import Mlp 7 | from .patch_embed import PatchEmbed 8 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 9 | from .block import NestedTensorBlock 10 | from .attention import MemEffAttention 11 | -------------------------------------------------------------------------------- /scripts/train_dune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | output_dir="${1}" 4 | shift 1; args=$(echo "$*") 5 | echo "Args passed to bash ${0##*/}:" 6 | echo "=> ${args}" 7 | 8 | # initialize the conda environment 9 | source ./scripts/setup_env.sh 10 | 11 | umask 002 12 | mkdir -p ${output_dir} 13 | 14 | torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:0 --nnodes=1 --nproc_per_node=${N_GPUS} main_dune.py \ 15 | --output_dir=${output_dir} \ 16 | --seed=${RANDOM} \ 17 | ${args} -------------------------------------------------------------------------------- /scripts/teachers/utils.sh: -------------------------------------------------------------------------------- 1 | download_model () { 2 | url=${1} 3 | dir=${2} 4 | 5 | mkdir -p ${dir} 6 | echo "==> Model directory: ${dir}" 7 | 8 | cd ${dir} 9 | wget -q -O model.pth ${url} 10 | model_ckpt="${dir}/model.pth" 11 | if [[ ! -f "${model_ckpt}" ]]; then 12 | echo "==> Couldn't download model checkpoint, please see the bash script for possible further instructions." 13 | else 14 | echo "==> Model checkpoint: ${model_ckpt}" 15 | fi 16 | } -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | export YAMLFIX_INDENT_MAPPING := 4 2 | export YAMLFIX_INDENT_OFFSET := 4 3 | export YAMLFIX_INDENT_SEQUENCE := 4 4 | 5 | 6 | format: 7 | find . -type f -name "*.py" | xargs black 8 | 9 | find . -type f -name "*.py" | xargs isort \ 10 | --profile black \ 11 | --atomic \ 12 | --star-first \ 13 | --only-sections \ 14 | --order-by-type \ 15 | --use-parentheses \ 16 | --lines-after-imports=2 \ 17 | --known-local-folder=model,data,teachers,utils 18 | 19 | 20 | clean: 21 | # Remove __pycache__ folders 22 | find . -type d -name "__pycache__" -print0 -exec rm -rf {} \; 23 | 24 | 25 | test: 26 | python -m unittest discover -s tests -p "test_*.py" 27 | 28 | 29 | conda: 30 | conda env export > environment.yaml -------------------------------------------------------------------------------- /utils/vis.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from torchvision.utils import save_image 6 | 7 | 8 | def save_batched_image(image, save_path, nrow=8, padding=3, normalize=True): 9 | """ 10 | Save a batch of images to a grid. 
11 | """ 12 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 13 | 14 | kwargs = { 15 | "nrow": nrow, 16 | "padding": padding, 17 | "normalize": normalize, 18 | "scale_each": normalize, 19 | "pad_value": 255, 20 | } 21 | 22 | save_image(image, save_path, **kwargs) 23 | 24 | 25 | def plot_arr(arr: np.ndarray, save_path: str, dpi: int = 300): 26 | plt.close() 27 | plt.figure() 28 | plt.plot(arr) 29 | plt.grid() 30 | plt.savefig(save_path, dpi=dpi, bbox_inches="tight") 31 | np.save(save_path.replace(".png", ".npy"), arr) 32 | -------------------------------------------------------------------------------- /model/options.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class EncoderOptions: 6 | """This class encompasses the most important options for specifying an 7 | encoder module. Most likely, this should contain the details to create 8 | a ViT-based encoder. 9 | """ 10 | 11 | arch: str = "vit_base" 12 | image_size: int = 336 13 | patch_size: int = 14 14 | num_register_tokens: int = 4 15 | layerscale_init: float = 0.0001 16 | qkv_bias: bool = True 17 | ln_affine: bool = True 18 | 19 | 20 | @dataclass 21 | class ProjectorOptions: 22 | """This class encompasses the most important options for specifying a 23 | projector module. 24 | """ 25 | 26 | input_dim: int 27 | output_dim: int 28 | num_blocks: int = 1 29 | num_heads: int = 12 30 | layerscale_init: float = 0.0001 31 | qkv_bias: bool = True 32 | ln_affine: bool = True 33 | scale: float = 1.0 34 | -------------------------------------------------------------------------------- /scripts/teachers/multihmr.sh: -------------------------------------------------------------------------------- 1 | # Publication title : Multi-HMR: Multi-Person Whole-Body Human Mesh Recovery in a Single Shot 2 | # Publication URL : https://arxiv.org/abs/2402.14654 3 | # Official Github repo : https://github.com/naver/multi-hmr 4 | 5 | ################################################## 6 | # Code for preparing model(s): 7 | ################################################## 8 | # Arguments: 9 | # root_dir: path where all models are saved 10 | root_dir=${1} 11 | 12 | source ./utils.sh 13 | 14 | for arch in "vitlarge14_672"; do 15 | 16 | if [[ ${arch} == "vitlarge14_672" ]]; then 17 | model_url="https://download.europe.naverlabs.com/ComputerVision/MultiHMR/multiHMR_672_L.pt" 18 | else 19 | echo "==> Unknown architecture: ${arch}" 20 | exit 21 | fi 22 | 23 | echo "Preparing Multi-HMR - ${arch}" 24 | model_dir=${root_dir}/multihmr/${arch} 25 | download_model ${model_url} ${model_dir} 26 | 27 | done -------------------------------------------------------------------------------- /model/common/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 7 | 8 | from typing import Union 9 | 10 | import torch 11 | from torch import Tensor, nn 12 | 13 | 14 | class LayerScale(nn.Module): 15 | def __init__( 16 | self, 17 | dim: int, 18 | init_values: Union[float, Tensor] = 1e-5, 19 | inplace: bool = False, 20 | ) -> None: 21 | super().__init__() 22 | self.init_values = init_values 23 | self.inplace = inplace 24 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 25 | 26 | def extra_repr(self) -> str: 27 | return "init_values={}, inplace={}".format(self.init_values, self.inplace) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 31 | -------------------------------------------------------------------------------- /teachers/dinov2/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 7 | 8 | from typing import Union 9 | 10 | import torch 11 | from torch import Tensor, nn 12 | 13 | 14 | class LayerScale(nn.Module): 15 | def __init__( 16 | self, 17 | dim: int, 18 | init_values: Union[float, Tensor] = 1e-5, 19 | inplace: bool = False, 20 | ) -> None: 21 | super().__init__() 22 | self.init_values = init_values 23 | self.inplace = inplace 24 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 25 | 26 | def extra_repr(self) -> str: 27 | return "init_values={}, inplace={}".format(self.init_values, self.inplace) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 31 | -------------------------------------------------------------------------------- /model/common/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 9 | 10 | 11 | from torch import nn 12 | 13 | 14 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 15 | if drop_prob == 0.0 or not training: 16 | return x 17 | keep_prob = 1 - drop_prob 18 | shape = (x.shape[0],) + (1,) * ( 19 | x.ndim - 1 20 | ) # work with diff dim tensors, not just 2D ConvNets 21 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 22 | if keep_prob > 0.0: 23 | random_tensor.div_(keep_prob) 24 | output = x * random_tensor 25 | return output 26 | 27 | 28 | class DropPath(nn.Module): 29 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 30 | 31 | def __init__(self, drop_prob=None): 32 | super(DropPath, self).__init__() 33 | self.drop_prob = drop_prob 34 | 35 | def forward(self, x): 36 | return drop_path(x, self.drop_prob, self.training) 37 | -------------------------------------------------------------------------------- /teachers/dinov2/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 9 | 10 | 11 | from torch import nn 12 | 13 | 14 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 15 | if drop_prob == 0.0 or not training: 16 | return x 17 | keep_prob = 1 - drop_prob 18 | shape = (x.shape[0],) + (1,) * ( 19 | x.ndim - 1 20 | ) # work with diff dim tensors, not just 2D ConvNets 21 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 22 | if keep_prob > 0.0: 23 | random_tensor.div_(keep_prob) 24 | output = x * random_tensor 25 | return output 26 | 27 | 28 | class DropPath(nn.Module): 29 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 30 | 31 | def __init__(self, drop_prob=None): 32 | super(DropPath, self).__init__() 33 | self.drop_prob = drop_prob 34 | 35 | def forward(self, x): 36 | return drop_path(x, self.drop_prob, self.training) 37 | -------------------------------------------------------------------------------- /model/common/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 9 | 10 | 11 | from typing import Callable, Optional 12 | 13 | from torch import Tensor, nn 14 | 15 | 16 | class Mlp(nn.Module): 17 | def __init__( 18 | self, 19 | in_features: int, 20 | hidden_features: Optional[int] = None, 21 | out_features: Optional[int] = None, 22 | act_layer: Callable[..., nn.Module] = nn.GELU, 23 | drop: float = 0.0, 24 | bias: bool = True, 25 | ) -> None: 26 | super().__init__() 27 | out_features = out_features or in_features 28 | hidden_features = hidden_features or in_features 29 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 30 | self.act = act_layer() 31 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 32 | self.drop = nn.Dropout(drop) 33 | 34 | def forward(self, x: Tensor) -> Tensor: 35 | x = self.fc1(x) 36 | x = self.act(x) 37 | x = self.drop(x) 38 | x = self.fc2(x) 39 | x = self.drop(x) 40 | return x 41 | -------------------------------------------------------------------------------- /teachers/dinov2/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 9 | 10 | 11 | from typing import Callable, Optional 12 | 13 | from torch import Tensor, nn 14 | 15 | 16 | class Mlp(nn.Module): 17 | def __init__( 18 | self, 19 | in_features: int, 20 | hidden_features: Optional[int] = None, 21 | out_features: Optional[int] = None, 22 | act_layer: Callable[..., nn.Module] = nn.GELU, 23 | drop: float = 0.0, 24 | bias: bool = True, 25 | ) -> None: 26 | super().__init__() 27 | out_features = out_features or in_features 28 | hidden_features = hidden_features or in_features 29 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 30 | self.act = act_layer() 31 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 32 | self.drop = nn.Dropout(drop) 33 | 34 | def forward(self, x: Tensor) -> Tensor: 35 | x = self.fc1(x) 36 | x = self.act(x) 37 | x = self.drop(x) 38 | x = self.fc2(x) 39 | x = self.drop(x) 40 | return x 41 | -------------------------------------------------------------------------------- /utils/optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_params_groups(model, save_file_path=None): 5 | """ 6 | Returns two parameters group, one for regularized parameters with weight decay, 7 | and another for unregularized parameters. 
8 | """ 9 | regularized = [] 10 | not_regularized = [] 11 | 12 | fp = None 13 | if save_file_path is not None: 14 | fp = open(save_file_path, "w") 15 | 16 | for name, param in model.named_parameters(): 17 | if not param.requires_grad: 18 | continue 19 | 20 | if name.endswith(".bias") or len(param.shape) == 1: 21 | regstat = "Not Regularized" 22 | not_regularized.append(param) 23 | else: 24 | regstat = "Regularized" 25 | regularized.append(param) 26 | 27 | if fp is not None: 28 | fp.write("{} - {} - {}\n".format(name, list(param.shape), regstat)) 29 | 30 | if fp is not None: 31 | fp.flush() 32 | fp.close() 33 | 34 | return [{"params": regularized}, {"params": not_regularized, "weight_decay": 0.0}] 35 | 36 | 37 | def clip_gradients(model, clip): 38 | norms = [] 39 | for n, p in model.named_parameters(): 40 | if p.grad is None: 41 | continue 42 | 43 | param_norm = p.grad.data.norm(p=2) 44 | norms.append(param_norm) 45 | clip_coef = clip / (param_norm + 1e-6) 46 | if clip_coef < 1: 47 | p.grad.data.mul_(clip_coef) 48 | 49 | return torch.stack(norms) 50 | -------------------------------------------------------------------------------- /teachers/config.py: -------------------------------------------------------------------------------- 1 | from .dinov2.models.vision_transformer import vit_large as dinov2_vitlarge 2 | from .vit_master import mast3r 3 | 4 | 5 | TEACHER_CFG = { 6 | "dino2reg_vitlarge_14": { 7 | "loader": dinov2_vitlarge, 8 | "ckpt_path": "/path/to/dinov2reg/vitlarge14/checkpoint.pth", 9 | "ckpt_key": "", 10 | "num_features": 1024, 11 | "image_size": 518, 12 | "patch_size": 14, 13 | "num_register_tokens": 4, 14 | "init_values": 1, 15 | "token_types": ["cls", "patch"], 16 | }, 17 | "mast3r_vitlarge_16": { 18 | "loader": mast3r, 19 | "ckpt_path": "/path/to/mast3r/vitlarge16/checkpoint.pth", 20 | "ckpt_key": None, 21 | "code_dir": "/path/to/mast3r/code", # Path to the MAST3R code, downloaded by scripts/teachers/mast3r.sh 22 | "num_features": 1024, 23 | "image_size": 512, # shortest side 24 | "patch_size": 16, 25 | "token_types": ["patch"], # ignore the cls token 26 | }, 27 | "multihmr_vitlarge_14_672": { 28 | "loader": dinov2_vitlarge, 29 | "ckpt_path": "/path/to/multihmr/vitlarge14_672/checkpoint.pth", 30 | "ckpt_key": "model_state_dict", 31 | "num_features": 1024, 32 | "image_size": 518, # model has 37 ** 2 + 1 positional embeddings, but fine-tuned at resolution 672 33 | "patch_size": 14, 34 | "num_register_tokens": 0, 35 | "init_values": 1, 36 | "token_types": ["patch"], # ignore the cls token 37 | }, 38 | } 39 | -------------------------------------------------------------------------------- /scripts/teachers/mast3r.sh: -------------------------------------------------------------------------------- 1 | # Publication title : Grounding Image Matching in 3D with MASt3R 2 | # Publication URL : https://arxiv.org/abs/2406.09756 3 | # Official Github repo : https://github.com/naver/mast3r 4 | 5 | ################################################## 6 | # Code for preparing model(s): 7 | ################################################## 8 | # Arguments: 9 | # root_dir: path where all models are saved 10 | root_dir=${1} 11 | 12 | pwd=${PWD} 13 | source ./utils.sh 14 | 15 | for arch in "vitlarge16"; do 16 | 17 | if [[ ${arch} == "vitlarge16" ]]; then 18 | # link for the pretrained model 19 | model_url="https://download.europe.naverlabs.com/ComputerVision/MASt3R/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth" 20 | else 21 | echo "==> Unknown architecture: ${arch}" 22 | exit 23 | fi 24 | 
25 | echo "Preparing Mast3r - ${arch}" 26 | model_dir=${root_dir}/mast3r/${arch} 27 | download_model ${model_url} ${model_dir} 28 | done 29 | 30 | # download the code, will be used in teachers/vit_master.py 31 | cd ${root_dir}/mast3r 32 | git clone --recursive https://github.com/naver/mast3r code 33 | 34 | # compile the cuda kernels for RoPE 35 | # DUST3R relies on RoPE positional embeddings for which you can compile some cuda kernels for faster runtime. 36 | cd code/dust3r/croco/models/curope/ 37 | # do not forget to call before: 38 | # - scl enable gcc-toolset-9 bash 39 | echo "==> Compiling cuda kernels for RoPE" 40 | python setup.py build_ext --inplace 41 | 42 | cd ${pwd} -------------------------------------------------------------------------------- /scripts/teachers/dinov2.sh: -------------------------------------------------------------------------------- 1 | # Publication title : DINOv2: Learning Robust Visual Features without Supervision 2 | # Publication URL : https://arxiv.org/abs/2304.07193 3 | # Official Github repo : https://github.com/facebookresearch/dinov2 4 | 5 | ################################################## 6 | # Code for preparing model(s): 7 | ################################################## 8 | # Arguments: 9 | # root_dir: path where all models are saved 10 | root_dir=${1} 11 | 12 | source ./utils.sh 13 | 14 | for arch in "vitlarge14_reg"; do 15 | 16 | if [[ ${arch} == "vitsmall14" ]]; then 17 | model_url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth" 18 | elif [[ ${arch} == "vitbase14" ]]; then 19 | model_url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth" 20 | elif [[ ${arch} == "vitgiant14" ]]; then 21 | model_url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth" 22 | elif [[ ${arch} == "vitlarge14_reg" ]]; then 23 | model_url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth" 24 | elif [[ ${arch} == "vitgiant14_reg" ]]; then 25 | model_url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth" 26 | else 27 | echo "==> Unknown architecture: ${arch}" 28 | exit 29 | fi 30 | 31 | echo "Preparing DINO-v2 - ${arch}" 32 | model_dir=${root_dir}/dinov2/${arch} 33 | download_model ${model_url} ${model_dir} 34 | 35 | done -------------------------------------------------------------------------------- /data/paths.py: -------------------------------------------------------------------------------- 1 | # For logging problematic images 2 | PROBLEMATIC_IMAGES_LOG_FILE = "/path/to/some/json/file.json" 3 | 4 | # List of directories where the ImageNet-1K dataset can be found. 5 | # We monitor distillation loss on this dataset for debugging 6 | # it is not necessary for training DUNE models. 7 | IN1K_DIRS = [ 8 | "/path/to/ilsvrc2012", 9 | ] 10 | 11 | # Path to a pickle file which contains a list of ImageNet-19K images. 12 | # Traversing the entire dataset is slow, so we use a precomputed list. 13 | # This file is not included in the repository, but can be generated 14 | # by first loading up the dataset using torchvision then saving the list of images to a pickle file. 
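# A minimal sketch of how such a pickle could be generated (an assumption for
# illustration, not part of the original file: the raw ImageNet-19K images are in a
# torchvision ImageFolder layout, and ImageNetSubset in data/imagenet.py expects a
# list of (image_path, target) tuples):
#
#   import pickle
#   from torchvision.datasets import ImageFolder
#
#   dataset = ImageFolder("/path/to/imagenet19k/raw")
#   with open("/path/to/imagenet19k/images.pkl", "wb") as fid:
#       pickle.dump(dataset.samples, fid)  # samples is a list of (path, class_index)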
15 | IN19K_PKL_PATH = "/path/to/imagenet19k/images.pkl" 16 | DINOV2_DATASET_PATHS = { 17 | "in19k": IN19K_PKL_PATH, 18 | "gldv2": "/path/to/google-landmarks-dataset-v2", 19 | "mapillarystreet": "/path/to/mapillary-street", 20 | } 21 | 22 | # To prepare the Mast3r datasets, see here: 23 | # https://github.com/naver/mast3r?tab=readme-ov-file#datasets 24 | MAST3R_DATASET_PATHS = { 25 | "ARKitScenesV2": "", 26 | "BlendedMVS": "", 27 | "Co3d_v3": "", 28 | "DL3DV": "", 29 | "Habitat512": "", 30 | "MegaDepthDense": "", 31 | "NLK_MVS": "", 32 | "Niantic": "", 33 | "ScanNetppV2": "", 34 | "TartanAir": "", 35 | "Unreal4K": "", 36 | "VirtualKitti": "", 37 | "WildRgb": "", 38 | } 39 | MAST3R_CACHE_DIR = "/path/to/mast3r/cache/dir" 40 | 41 | HMR_DATASET_PATHS = { 42 | "bedlam": "/path/to/BEDLAM", 43 | "agora": "/path/to/agora/images", 44 | "cuffs": "/path/to/CUFFS", 45 | "ubody": "/path/to/UBody", 46 | "ubody_pkl": "/path/to/ubody.pkl", 47 | } 48 | -------------------------------------------------------------------------------- /teachers/forward.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import defaultdict 3 | from typing import Dict, Optional 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .config import TEACHER_CFG 9 | 10 | 11 | logger = logging.getLogger() 12 | 13 | 14 | def get_teacher_outputs( 15 | image: torch.Tensor, 16 | teachers: Dict[str, torch.nn.Module], 17 | student_patch_size: int, 18 | tnorms: Optional[nn.ModuleDict], 19 | tnorm_ema_mom: float = 0.0, 20 | ) -> Dict[str, Dict[str, torch.Tensor]]: 21 | 22 | all_tout_dict = defaultdict(dict) 23 | 24 | for tname, tmodel in teachers.items(): 25 | 26 | with torch.inference_mode(): 27 | _image = image 28 | if student_patch_size != tmodel.patch_size: 29 | # resize image for the teacher model 30 | # such that spatial size of its output matches that of student 31 | _image = torch.nn.functional.interpolate( 32 | image, 33 | scale_factor=tmodel.patch_size / student_patch_size, 34 | mode="bicubic", 35 | align_corners=True, 36 | ) 37 | 38 | tout_dict = teachers[tname].forward_features(_image) 39 | 40 | with torch.no_grad(): 41 | for ttype in TEACHER_CFG[tname]["token_types"]: 42 | key = "x_norm_{}{}".format( 43 | ttype, "token" if ttype == "cls" else "tokens" 44 | ) 45 | tout = tout_dict[key] 46 | if tnorms is not None: 47 | tout = tnorms[tname](tout, ttype, tnorm_ema_mom) 48 | all_tout_dict[tname][ttype] = tout 49 | 50 | return all_tout_dict 51 | -------------------------------------------------------------------------------- /model/projector/tp.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ..common.block import Block 7 | 8 | 9 | class TransformerProjector(nn.Module): 10 | 11 | def __init__( 12 | self, 13 | input_dim: int, 14 | output_dim: int, 15 | num_blocks: int = 1, 16 | num_heads: int = 12, 17 | mlp_ratio: float = 4.0, 18 | layerscale_init: float = 1e-4, 19 | ln_affine: bool = True, 20 | scale: float = 1.0, 21 | **kwargs, 22 | ): 23 | super().__init__() 24 | assert input_dim > 0, input_dim 25 | assert output_dim > 0, output_dim 26 | self.input_dim = input_dim 27 | self.output_dim = output_dim 28 | self.num_heads = num_heads 29 | self.scale = nn.Parameter(torch.ones(1).float()) if scale == 0.0 else scale 30 | norm_layer = partial(nn.LayerNorm, eps=1e-6, elementwise_affine=ln_affine) 31 | 32 | self.blocks = nn.ModuleList( 33 
| [ 34 | Block( 35 | dim=input_dim, 36 | num_heads=num_heads, 37 | mlp_ratio=mlp_ratio, 38 | layerscale_init=layerscale_init, 39 | norm_layer=norm_layer, 40 | **kwargs, 41 | ) 42 | for _ in range(num_blocks) 43 | ] 44 | ) 45 | self.linear = nn.Linear(input_dim, output_dim) 46 | 47 | def extra_repr(self): 48 | repr = "num_heads={}, scale={}".format(self.num_heads, self.scale) 49 | return repr 50 | 51 | def forward(self, x): 52 | for blk in self.blocks: 53 | x = blk(x) 54 | 55 | x = self.linear(x) 56 | 57 | return self.scale * x 58 | -------------------------------------------------------------------------------- /data/dino2.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .dataset import ImageFolderV2, ImageOneFolderV2 4 | from .imagenet import get_imagenet 5 | from .paths import DINOV2_DATASET_PATHS 6 | 7 | 8 | ################################################## 9 | # Dataset getters 10 | 11 | 12 | def get_gldv2(dataset_name, split, transform, **kwargs): 13 | assert split in ["train", "val"], "split must be 'train' or 'val', got: {split}" 14 | 15 | target_transform = lambda t: -1 16 | 17 | dataset_path = DINOV2_DATASET_PATHS[dataset_name] 18 | data_dir = os.path.join(dataset_path, "train" if split == "train" else "test") 19 | assert os.path.isdir(data_dir), data_dir 20 | 21 | dataset = ImageFolderV2( 22 | dataset_name, 23 | data_dir, 24 | transform=transform, 25 | target_transform=target_transform, 26 | ) 27 | 28 | return dataset 29 | 30 | 31 | def get_mapillarystreet(dataset_name, split, transform, **kwargs): 32 | assert split in ["train", "val"], "split must be 'train' or 'val', got: {split}" 33 | 34 | dataset_path = DINOV2_DATASET_PATHS[dataset_name] 35 | data_dir = os.path.join(dataset_path, "train_val" if split == "train" else "test") 36 | assert os.path.isdir(data_dir), data_dir 37 | 38 | dataset = ImageOneFolderV2( 39 | dataset_name, 40 | data_dir, 41 | transform=transform, 42 | ) 43 | 44 | return dataset 45 | 46 | 47 | AVAILABLE_DATASETS = { 48 | "in19k": { 49 | "train": 13_153_480, 50 | "val": 0, 51 | "getter": get_imagenet, 52 | }, # 20 problematic images removed 53 | "gldv2": { 54 | "train": 4_132_914, 55 | "val": 117_577, 56 | "getter": get_gldv2, 57 | }, 58 | "mapillarystreet": { 59 | "train": 1_205_907, 60 | "val": 23_943, 61 | "getter": get_mapillarystreet, 62 | }, 63 | } 64 | ################################################## 65 | -------------------------------------------------------------------------------- /model/common/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | from typing import Callable, Optional 7 | 8 | import torch.nn.functional as F 9 | from torch import Tensor, nn 10 | 11 | 12 | class SwiGLUFFN(nn.Module): 13 | def __init__( 14 | self, 15 | in_features: int, 16 | hidden_features: Optional[int] = None, 17 | out_features: Optional[int] = None, 18 | act_layer: Callable[..., nn.Module] = None, 19 | drop: float = 0.0, 20 | bias: bool = True, 21 | ) -> None: 22 | super().__init__() 23 | out_features = out_features or in_features 24 | hidden_features = hidden_features or in_features 25 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 26 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 27 | 28 | def forward(self, x: Tensor) -> Tensor: 29 | x12 = self.w12(x) 30 | x1, x2 = x12.chunk(2, dim=-1) 31 | hidden = F.silu(x1) * x2 32 | return self.w3(hidden) 33 | 34 | 35 | class SwiGLUFFNFused(SwiGLUFFN): 36 | def __init__( 37 | self, 38 | in_features: int, 39 | hidden_features: Optional[int] = None, 40 | out_features: Optional[int] = None, 41 | act_layer: Callable[..., nn.Module] = None, 42 | drop: float = 0.0, 43 | bias: bool = True, 44 | ) -> None: 45 | out_features = out_features or in_features 46 | hidden_features = hidden_features or in_features 47 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 48 | super().__init__( 49 | in_features=in_features, 50 | hidden_features=hidden_features, 51 | out_features=out_features, 52 | bias=bias, 53 | ) 54 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import json 4 | from typing import Sequence 5 | 6 | from torch.utils.data.dataloader import default_collate 7 | 8 | 9 | def get_first_available_dir(dir_list: Sequence[str], strict: bool = True) -> str: 10 | # looks for the first available dir in a given list 11 | for d in dir_list: 12 | if os.path.isdir(d): 13 | return d 14 | 15 | if strict: 16 | raise Exception(f"No dir exists in the list: {dir_list}") 17 | else: 18 | return "" 19 | 20 | 21 | def save_pickle(obj, save_path): 22 | with open(save_path, "wb") as fid: 23 | pickle.dump(obj, fid) 24 | 25 | 26 | def load_pickle(save_path): 27 | with open(save_path, "rb") as fid: 28 | obj = pickle.load(fid) 29 | return obj 30 | 31 | 32 | def my_collate(batch): 33 | """ 34 | Extend the default collate function to ignore None samples. 
35 | """ 36 | batch = list(filter(lambda x: x is not None, batch)) 37 | return default_collate(batch) 38 | 39 | 40 | def load_json(file_path: str): 41 | with open(file_path, "r", encoding="utf-8") as file: 42 | return json.load(file) 43 | 44 | 45 | def save_json(file_path: str, data): 46 | with open(file_path, "w", encoding="utf-8") as file: 47 | json.dump(data, file, indent=4) 48 | 49 | 50 | def add_str_to_jsonfile(json_file_path: str, item: str): 51 | items = set() 52 | 53 | if os.path.exists(json_file_path): 54 | try: 55 | data = load_json(json_file_path) 56 | items = set(data) 57 | 58 | except json.JSONDecodeError: 59 | # In case the file is empty or not a valid JSON, we'll start with an empty set 60 | pass 61 | 62 | items.add(item) 63 | 64 | save_json(json_file_path, sorted(items)) 65 | 66 | 67 | def normalize_min_max(tensor): 68 | tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) 69 | tensor = tensor.clamp(0, 1) 70 | return tensor 71 | -------------------------------------------------------------------------------- /data/mast3r.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from .dataset import ImageListV2 5 | from .paths import MAST3R_DATASET_PATHS, MAST3R_CACHE_DIR 6 | 7 | 8 | def get_mast3r_dataset( 9 | dataset_name, 10 | split, 11 | transform, 12 | ): 13 | return Mast3rDataset(dataset_name, split=split, transform=transform) 14 | 15 | 16 | AVAILABLE_DATASETS = { 17 | "ARKitScenesV2": {"train": 456_108, "val": 5_307, "getter": get_mast3r_dataset}, 18 | "BlendedMVS": {"train": 98_937, "val": 5_094, "getter": get_mast3r_dataset}, 19 | "Co3d_v3": {"train": 185_100, "val": 5_000, "getter": get_mast3r_dataset}, 20 | "DL3DV": {"train": 208_800, "val": 5_000, "getter": get_mast3r_dataset}, 21 | "Habitat512": {"train": 284_965, "val": 5_035, "getter": get_mast3r_dataset}, 22 | "MegaDepthDense": {"train": 36_949, "val": 3_682, "getter": get_mast3r_dataset}, 23 | "Niantic": {"train": 41_300, "val": 4_600, "getter": get_mast3r_dataset}, 24 | "ScanNetppV2": {"train": 60_188, "val": 5_031, "getter": get_mast3r_dataset}, 25 | "TartanAir": {"train": 136_225, "val": 10_000, "getter": get_mast3r_dataset}, 26 | "Unreal4K": {"train": 14_386, "val": 1_988, "getter": get_mast3r_dataset}, 27 | "VirtualKitti": {"train": 1_200, "val": 300, "getter": get_mast3r_dataset}, 28 | "WildRgb": {"train": 224_400, "val": 5_000, "getter": get_mast3r_dataset}, 29 | } 30 | 31 | 32 | class Mast3rDataset(ImageListV2): 33 | 34 | def __init__(self, dataset_name, split="train", transform=None): 35 | super().__init__(dataset_name, "", [], transform=transform) 36 | 37 | with open( 38 | os.path.join(MAST3R_CACHE_DIR, dataset_name + "_" + split + "_impaths.pkl"), 39 | "rb", 40 | ) as fid: 41 | self.root, self.imlist = pickle.load(fid) 42 | 43 | if not os.path.isdir(self.root): 44 | self.root = MAST3R_DATASET_PATHS[dataset_name] 45 | assert os.path.isdir(self.root), "{} root not found ({})".format( 46 | dataset_name, self.root 47 | ) 48 | -------------------------------------------------------------------------------- /data/transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.transforms import v2 as T 3 | 4 | 5 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 6 | IMAGENET_STD = [0.229, 0.224, 0.225] 7 | NORMALIZE = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) 8 | 9 | 10 | def get_train_transform(image_size=224, rrc_scale=(0.08, 1.0), color_aug=True): 11 | 
transforms = [ 12 | T.ToImage(), 13 | T.RandomResizedCrop( 14 | image_size, 15 | scale=rrc_scale, 16 | interpolation=T.InterpolationMode.BICUBIC, 17 | antialias=True, 18 | ), 19 | T.RandomHorizontalFlip(p=0.5), 20 | ] 21 | 22 | if color_aug: 23 | transforms.extend( 24 | [ 25 | T.RandomApply( 26 | [ 27 | T.ColorJitter( 28 | brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1 29 | ) 30 | ], 31 | p=0.8, 32 | ), 33 | T.RandomApply([T.Grayscale(num_output_channels=3)], p=0.2), 34 | T.ToDtype(torch.float32, scale=True), 35 | T.RandomApply([T.GaussianBlur(kernel_size=9, sigma=(0.1, 5.0))], p=0.2), 36 | T.RandomSolarize(threshold=0.5, p=0.2), 37 | ] 38 | ) 39 | else: 40 | transforms.extend( 41 | [ 42 | T.ToDtype(torch.float32, scale=True), 43 | ] 44 | ) 45 | 46 | transforms.append(NORMALIZE) 47 | return T.Compose(transforms) 48 | 49 | 50 | def get_test_transform(image_size, normalize=NORMALIZE, center_crop_size=None): 51 | if center_crop_size is None: 52 | center_crop_size = image_size 53 | 54 | transforms = [ 55 | T.ToImage(), 56 | T.Resize( 57 | image_size, 58 | interpolation=T.InterpolationMode.BICUBIC, 59 | antialias=True, 60 | ), 61 | T.CenterCrop(center_crop_size), 62 | T.ToDtype(torch.float32, scale=True), 63 | normalize, 64 | ] 65 | 66 | return T.Compose(transforms) 67 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pathlib import Path 3 | 4 | from model.dune import load_dune_encoder_from_checkpoint, load_dune_from_checkpoint 5 | 6 | 7 | URL_DICT = { 8 | "vitbase_14_448_paper": "https://download.europe.naverlabs.com/dune/dune_vitbase14_448_paper.pth", 9 | "vitbase_14_448": "https://download.europe.naverlabs.com/dune/dune_vitbase14_448.pth", 10 | "vitbase_14_336": "https://download.europe.naverlabs.com/dune/dune_vitbase14_336.pth", 11 | "vitsmall_14_448": "https://download.europe.naverlabs.com/dune/dune_vitsmall14_448.pth", 12 | "vitsmall_14_336": "https://download.europe.naverlabs.com/dune/dune_vitsmall14_336.pth", 13 | } 14 | 15 | 16 | def _load_dune_model_from_url(model_name, encoder_only=False): 17 | if model_name not in URL_DICT: 18 | raise ValueError("Model name '{}' is not recognized.".format(model_name)) 19 | 20 | cache_dir = Path(torch.hub.get_dir()) 21 | cache_dir = cache_dir / "checkpoints" / "dune" 22 | cache_dir.mkdir(parents=True, exist_ok=True) 23 | 24 | url = URL_DICT[model_name] 25 | ckpt_fname = cache_dir / Path(url).name 26 | if not ckpt_fname.exists(): 27 | torch.hub.download_url_to_file(url, ckpt_fname) 28 | 29 | if encoder_only: 30 | return load_dune_encoder_from_checkpoint(ckpt_fname)[0] 31 | else: 32 | return load_dune_from_checkpoint(ckpt_fname)[0] 33 | 34 | 35 | def dune_vitbase_14_448_paper_encoder(): 36 | return _load_dune_model_from_url("vitbase_14_448_paper", encoder_only=True) 37 | 38 | 39 | def dune_vitbase_14_448_paper(): 40 | return _load_dune_model_from_url("vitbase_14_448_paper", encoder_only=False) 41 | 42 | 43 | def dune_vitbase_14_448_encoder(): 44 | return _load_dune_model_from_url("vitbase_14_448", encoder_only=True) 45 | 46 | 47 | def dune_vitbase_14_448(): 48 | return _load_dune_model_from_url("vitbase_14_448", encoder_only=False) 49 | 50 | 51 | def dune_vitbase_14_336_encoder(): 52 | return _load_dune_model_from_url("vitbase_14_336", encoder_only=True) 53 | 54 | 55 | def dune_vitbase_14_336(): 56 | return _load_dune_model_from_url("vitbase_14_336", encoder_only=False) 57 | 58 | 59 | def 
dune_vitsmall_14_448_encoder(): 60 | return _load_dune_model_from_url("vitsmall_14_448", encoder_only=True) 61 | 62 | 63 | def dune_vitsmall_14_448(): 64 | return _load_dune_model_from_url("vitsmall_14_448", encoder_only=False) 65 | -------------------------------------------------------------------------------- /scripts/pca_vis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pathlib import Path 3 | from PIL import Image 4 | from sklearn.decomposition import PCA 5 | from matplotlib import pyplot as plt 6 | 7 | from model.dune import load_dune_encoder_from_checkpoint 8 | from data.transform import get_test_transform 9 | from data.utils import normalize_min_max 10 | 11 | 12 | architecture = "vitsmall14" 13 | image_size = 448 14 | checkpoint_path = Path("dune_{}_{}_paper.pth".format(architecture, image_size)) 15 | device = torch.device("cpu") 16 | print("Loading DUNE model from checkpoint:", checkpoint_path) 17 | model = load_dune_encoder_from_checkpoint(checkpoint_path)[0] 18 | model = model.eval() 19 | if torch.cuda.is_available(): 20 | print(" - Using GPU") 21 | device = torch.device("cuda") 22 | model = model.to(device) 23 | 24 | print("Loading the test image") 25 | transform = get_test_transform(image_size) 26 | image = Image.open("./assets/test_image.png").convert("RGB") 27 | image = transform(image) 28 | 29 | print("Making a forward pass through the model") 30 | with torch.inference_mode(): 31 | output: dict = model(image.unsqueeze(0).to(device)) 32 | # output is compatibile to that of DINOv2 33 | # output["x_norm_clstoken"].shape --> [1, 768] 34 | # output["x_norm_patchtokens"].shape --> [1, num_patches, 768] 35 | patch_emb = output["x_norm_patchtokens"].detach().cpu().squeeze() 36 | 37 | print("Reducing the dimension of patch embeddings to 3 via PCA") 38 | num_patches_side = int(patch_emb.shape[0] ** 0.5) # assume a square image 39 | pca = PCA(n_components=3, random_state=22) 40 | patch_pca = torch.from_numpy(pca.fit_transform(patch_emb.numpy())) # [num_patches, 3] 41 | patch_pca = patch_pca.reshape([num_patches_side, num_patches_side, 3]).permute(2, 0, 1) 42 | patch_pca = ( 43 | torch.nn.functional.interpolate(patch_pca.unsqueeze(0), image_size, mode="nearest") 44 | .squeeze(0) 45 | .permute(1, 2, 0) 46 | ) # [image_size, image_size, 3] 47 | patch_pca = normalize_min_max(patch_pca) 48 | 49 | print("Visualizing the original image and the PCA-reduced patch embeddings") 50 | plt.close() 51 | fig, axs = plt.subplots(1, 2, dpi=200, constrained_layout=True) 52 | axs[0].imshow(normalize_min_max(image.permute(1, 2, 0))) 53 | axs[1].imshow(patch_pca) 54 | for ax in axs: 55 | ax.axis("off") 56 | plt.savefig( 57 | "./assets/test_image_patch_pca_{}.png".format(checkpoint_path.stem), 58 | bbox_inches="tight", 59 | ) 60 | -------------------------------------------------------------------------------- /teachers/dinov2/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import os 7 | import warnings 8 | from typing import Callable, Optional 9 | 10 | import torch.nn.functional as F 11 | from torch import Tensor, nn 12 | 13 | 14 | class SwiGLUFFN(nn.Module): 15 | def __init__( 16 | self, 17 | in_features: int, 18 | hidden_features: Optional[int] = None, 19 | out_features: Optional[int] = None, 20 | act_layer: Callable[..., nn.Module] = None, 21 | drop: float = 0.0, 22 | bias: bool = True, 23 | ) -> None: 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 28 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 29 | 30 | def forward(self, x: Tensor) -> Tensor: 31 | x12 = self.w12(x) 32 | x1, x2 = x12.chunk(2, dim=-1) 33 | hidden = F.silu(x1) * x2 34 | return self.w3(hidden) 35 | 36 | 37 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 38 | try: 39 | if XFORMERS_ENABLED: 40 | from xformers.ops import SwiGLU 41 | 42 | XFORMERS_AVAILABLE = True 43 | warnings.warn("xFormers is available (SwiGLU)") 44 | else: 45 | warnings.warn("xFormers is disabled (SwiGLU)") 46 | raise ImportError 47 | except ImportError: 48 | SwiGLU = SwiGLUFFN 49 | XFORMERS_AVAILABLE = False 50 | 51 | warnings.warn("xFormers is not available (SwiGLU)") 52 | 53 | 54 | class SwiGLUFFNFused(SwiGLU): 55 | def __init__( 56 | self, 57 | in_features: int, 58 | hidden_features: Optional[int] = None, 59 | out_features: Optional[int] = None, 60 | act_layer: Callable[..., nn.Module] = None, 61 | drop: float = 0.0, 62 | bias: bool = True, 63 | ) -> None: 64 | out_features = out_features or in_features 65 | hidden_features = hidden_features or in_features 66 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 67 | super().__init__( 68 | in_features=in_features, 69 | hidden_features=hidden_features, 70 | out_features=out_features, 71 | bias=bias, 72 | ) 73 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | def is_enabled() -> bool: 9 | return dist.is_available() and dist.is_initialized() 10 | 11 | 12 | def get_global_size() -> int: 13 | return dist.get_world_size() if is_enabled() else 1 14 | 15 | 16 | def get_global_rank() -> int: 17 | return dist.get_rank() if is_enabled() else 0 18 | 19 | 20 | def is_main_process() -> bool: 21 | return get_global_rank() == 0 22 | 23 | 24 | def init_distributed_mode(args): 25 | # launched with torchrun 26 | if "WORLD_SIZE" in os.environ: 27 | args.world_size = int(os.environ["WORLD_SIZE"]) 28 | 29 | if "RANK" in os.environ: 30 | args.rank = int(os.environ["RANK"]) 31 | elif "SLURM_PROCID" in os.environ: 32 | args.rank = int(os.environ["SLURM_PROCID"]) 33 | else: 34 | print("Cannot find rank in environment variables") 35 | sys.exit(-1) 36 | 37 | n_gpus_per_node = torch.cuda.device_count() 38 | assert n_gpus_per_node > 0, "No GPU device detected" 39 | 40 | args.gpu = args.rank - n_gpus_per_node * (args.rank // n_gpus_per_node) 41 | 42 | # launched naively with python 43 | elif torch.cuda.is_available(): 44 | print("==> Will run the code on one GPU.") 45 | args.rank, args.gpu, args.world_size = 0, 0, 1 46 | os.environ["MASTER_ADDR"] = "127.0.0.1" 47 | os.environ["MASTER_PORT"] = "12345" 48 | 49 | else: 50 | print("==> Does not support training without 
GPU.") 51 | sys.exit(1) 52 | 53 | print( 54 | "=> WORLD_SIZE={}, RANK={}, GPU={}, MASTER_ADDR={}, MASTER_PORT={}, INIT_METHOD={}".format( 55 | args.world_size, 56 | args.rank, 57 | args.gpu, 58 | os.environ["MASTER_ADDR"], 59 | os.environ["MASTER_PORT"], 60 | args.dist_url, 61 | ), 62 | flush=True, 63 | ) 64 | 65 | dist.init_process_group( 66 | backend="nccl", 67 | init_method=args.dist_url, 68 | world_size=args.world_size, 69 | rank=args.rank, 70 | ) 71 | dist.barrier() 72 | torch.cuda.set_device(args.gpu) 73 | setup_for_distributed(args.rank == 0) 74 | 75 | 76 | def setup_for_distributed(is_master): 77 | """ 78 | This function disables printing when not in master process 79 | """ 80 | import builtins as __builtin__ 81 | 82 | builtin_print = __builtin__.print 83 | 84 | def print(*args, **kwargs): 85 | force = kwargs.pop("force", False) 86 | if is_master or force: 87 | builtin_print(*args, **kwargs) 88 | 89 | __builtin__.print = print 90 | -------------------------------------------------------------------------------- /model/common/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import logging 11 | from math import sqrt 12 | 13 | import torch.nn.functional as F 14 | from torch import Tensor, nn, softmax 15 | 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | class Attention(nn.Module): 21 | def __init__( 22 | self, 23 | dim: int, 24 | num_heads: int = 8, 25 | qkv_bias: bool = False, 26 | proj_bias: bool = True, 27 | attn_drop: float = 0.0, 28 | proj_drop: float = 0.0, 29 | qk_norm: bool = False, 30 | norm_layer: nn.Module = nn.LayerNorm, 31 | ) -> None: 32 | super().__init__() 33 | self.num_heads = num_heads 34 | head_dim = dim // num_heads 35 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 36 | self.attn_drop = attn_drop 37 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 38 | self.proj_drop = nn.Dropout(proj_drop) 39 | self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity() 40 | self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity() 41 | 42 | def forward(self, x: Tensor, return_attention=False) -> Tensor: 43 | B, N, C = x.shape 44 | qkv = ( 45 | self.qkv(x) 46 | .reshape(B, N, 3, self.num_heads, C // self.num_heads) 47 | .permute(2, 0, 3, 1, 4) 48 | ) 49 | 50 | q, k, v = qkv[0], qkv[1], qkv[2] 51 | q, k = self.q_norm(q), self.k_norm(k) 52 | 53 | if return_attention: 54 | scale_factor = 1 / sqrt(q.size(-1)) 55 | attn_weight = q @ k.transpose(-2, -1) * scale_factor 56 | return softmax(attn_weight, dim=-1) 57 | 58 | x = ( 59 | F.scaled_dot_product_attention( 60 | q, k, v, attn_mask=None, dropout_p=self.attn_drop, is_causal=False 61 | ) 62 | .transpose(1, 2) 63 | .reshape(B, N, C) 64 | ) 65 | 66 | x = self.proj(x) 67 | x = self.proj_drop(x) 68 | return x 69 | 70 | 71 | class MemEffAttention(Attention): 72 | def forward(self, x: Tensor, attn_bias=None, return_attention=False) -> Tensor: 73 | if attn_bias is not None: 74 | raise AssertionError("xFormers is required for using nested tensors") 75 | return super().forward(x, return_attention=return_attention) 76 | 
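# A minimal usage sketch for the attention module above (illustrative only, not part
# of the original file; the dim/num_heads values are assumptions matching a ViT-Base
# configuration): MemEffAttention maps token embeddings of shape
# [batch, num_tokens, dim] to a tensor of the same shape, or returns the attention
# weights when return_attention=True.
#
#   import torch
#
#   attn = MemEffAttention(dim=768, num_heads=12, qkv_bias=True)
#   tokens = torch.randn(2, 1029, 768)  # e.g. 1 cls + 4 register + 32*32 patch tokens
#   out = attn(tokens)                  # -> [2, 1029, 768]
#   weights = attn(tokens, return_attention=True)  # -> [2, 12, 1029, 1029]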
-------------------------------------------------------------------------------- /scripts/setup_env.sh: -------------------------------------------------------------------------------- 1 | conda_dir="/path/to/conda" # Change this to your conda installation path 2 | conda_env_name="dune" # Change this to your conda environment name 3 | 4 | echo "----------------------------------------------------------------------------------------------------" 5 | hostnamectl 6 | echo "----------------------------------------------------------------------------------------------------" 7 | nvidia-smi 8 | echo "----------------------------------------------------------------------------------------------------" 9 | free -g 10 | echo "----------------------------------------------------------------------------------------------------" 11 | df -h /dev/shm 12 | echo "----------------------------------------------------------------------------------------------------" 13 | df -h /local 14 | echo "----------------------------------------------------------------------------------------------------" 15 | 16 | check_python_pkg () { 17 | pkg=${1} 18 | if python -c "import ${pkg}" &> /dev/null; then 19 | ver=$(python -c "import ${1}; print(${1}.__version__)") 20 | printf "%-30s : %s\n" "${1} version" "${ver}" 21 | fi 22 | } 23 | 24 | check_git () { 25 | if [ -d .git ]; then 26 | printf "%-30s : %s\n" "git commit SHA-1" "$(git rev-parse HEAD)" 27 | else 28 | printf "This code is not in a git repo\n" 29 | fi; 30 | } 31 | 32 | echo "--------------------------------------------------" 33 | source /etc/proxyrc 34 | source ${conda_dir}/bin/activate base 35 | conda activate ${conda_env_name} 36 | export LD_LIBRARY_PATH=${conda_dir}/envs/${conda_env_name}/lib/:${LD_LIBRARY_PATH} 37 | export PYTHONPATH="${PWD}:${PYTHONPATH}" 38 | 39 | printf "%-30s : %s\n" "conda environment name" "${CONDA_DEFAULT_ENV}" 40 | printf "%-30s : %s\n" "conda environment path" "${CONDA_PREFIX}" 41 | check_python_pkg "torch" 42 | check_python_pkg "torchvision" 43 | check_python_pkg "timm" 44 | check_python_pkg "PIL" 45 | check_python_pkg "mmcv" 46 | printf "%-30s : %s\n" "libjpeg-turbo support" "$(python -c "from PIL import features; print(features.check_feature('libjpeg_turbo'))")" 47 | num_cores="$(python -c "import os; print(len(os.sched_getaffinity(0)))")" 48 | printf "%-30s : %s\n" "Number of processors:" "${num_cores}" 49 | check_git 50 | 51 | export ONEDAL_NUM_THREADS=${num_cores} 52 | export OMP_NUM_THREADS=${num_cores} 53 | export MKL_NUM_THREADS=${num_cores} 54 | 55 | echo "PATH: ${PATH}" 56 | echo "LD_LIBRARY_PATH: ${LD_LIBRARY_PATH}" 57 | 58 | echo "--------------------------------------------------" 59 | echo "${SLURM_JOB_NUM_NODES} nodes available" 60 | IFS=','; gpus=($CUDA_VISIBLE_DEVICES); unset IFS; 61 | export N_GPUS=${#gpus[@]} 62 | echo "${N_GPUS} GPU(s) available on this node" 63 | 64 | export MASTER_ADDR=$(hostname) 65 | export MASTER_PORT=$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1]); s.close()') 66 | echo "MASTER_ADDR=${MASTER_ADDR}, MASTER_PORT=${MASTER_PORT}" -------------------------------------------------------------------------------- /teachers/vit_master.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | 5 | 6 | class Mast3rEncoder(torch.nn.Module): 7 | """ 8 | PyTorch module to keep only the vision encoder of a Mast3r model. 
9 | The original class definition is here: 10 | https://github.com/naver/dust3r/blob/9869e71f9165aa53c53ec0979cea1122a569ade4/dust3r/model.py#L46 11 | If renorm_images is True, it is assumed that the input images are already normalized 12 | by the ImageNet normalization, and they will be re-normalized by the Mast3r normalization. 13 | """ 14 | 15 | # default values for Mast3r normalization follow Dust3r, here: 16 | # https://github.com/naver/dust3r/blob/9869e71f9165aa53c53ec0979cea1122a569ade4/dust3r/utils/image.py#L23 17 | _mast3r_image_mean = (0.5, 0.5, 0.5) 18 | _mast3r_image_std = (0.5, 0.5, 0.5) 19 | _imagenet_image_mean = (0.485, 0.456, 0.406) 20 | _imagenet_image_std = (0.229, 0.224, 0.225) 21 | 22 | def __init__(self, code_dir, ckpt_path, renorm_images=True): 23 | super().__init__() 24 | 25 | sys.path.insert(0, code_dir) 26 | from mast3r.model import AsymmetricMASt3R 27 | 28 | self.model = AsymmetricMASt3R.from_pretrained(ckpt_path) 29 | self.renorm_images = renorm_images 30 | self.num_features = self.model.enc_embed_dim 31 | self.patch_size = self.model.patch_embed.patch_size[0] 32 | 33 | self.register_buffer( 34 | "master_image_mean", 35 | torch.tensor(self._mast3r_image_mean).view(1, 3, 1, 1), 36 | ) 37 | self.register_buffer( 38 | "master_image_std", 39 | torch.tensor(self._mast3r_image_std).view(1, 3, 1, 1), 40 | ) 41 | self.register_buffer( 42 | "imagenet_image_mean", 43 | torch.tensor(self._imagenet_image_mean).view(1, 3, 1, 1), 44 | ) 45 | self.register_buffer( 46 | "imagenet_image_std", 47 | torch.tensor(self._imagenet_image_std).view(1, 3, 1, 1), 48 | ) 49 | 50 | def extra_repr(self) -> str: 51 | return "renorm_images={}".format(self.renorm_images) 52 | 53 | def forward_features(self, x): 54 | if self.renorm_images: 55 | # revert already-applied ImageNet normalization 56 | x = x * self.imagenet_image_std + self.imagenet_image_mean 57 | # apply Mast3r normalization 58 | x = (x - self.master_image_mean) / self.master_image_std 59 | 60 | x_patchtokens, _, _ = self.model._encode_image(x, true_shape=None) 61 | 62 | x_clstoken = x_patchtokens.mean(dim=1) 63 | 64 | return { 65 | "x_norm_clstoken": x_clstoken, 66 | "x_norm_regtokens": None, 67 | "x_norm_patchtokens": x_patchtokens, 68 | "x_prenorm": None, 69 | "masks": None, 70 | } 71 | 72 | 73 | def mast3r(code_dir, ckpt_path, **kwargs): 74 | return Mast3rEncoder(code_dir, ckpt_path) 75 | -------------------------------------------------------------------------------- /data/imagenet.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | 5 | from torchvision.datasets import VisionDataset 6 | 7 | from .dataset import EmptyDataset, ImageFolderV2, my_pil_loader 8 | from .paths import IN1K_DIRS, IN19K_PKL_PATH, PROBLEMATIC_IMAGES_LOG_FILE 9 | from .utils import get_first_available_dir, add_str_to_jsonfile 10 | 11 | 12 | logger = logging.getLogger() 13 | 14 | 15 | ################################################## 16 | # Dataset getters 17 | 18 | 19 | def get_imagenet(dataset_name, split, transform, **kwargs): 20 | assert ( 21 | dataset_name in AVAILABLE_DATASETS 22 | ), "Unknown ImageNet dataset '{}', expected: '{}'".format( 23 | dataset_name, list(AVAILABLE_DATASETS.keys()) 24 | ) 25 | assert split in ["train", "val"], "split must be 'train' or 'val', got: {split}" 26 | 27 | if dataset_name.startswith("in19k") and split == "val": 28 | logging.info( 29 | "No validation set for {}, returning empty dataset".format(dataset_name) 30 | ) 31 | return 
EmptyDataset(dataset_name) 32 | 33 | if dataset_name == "in1k": 34 | dataset_class = ImageFolderV2 35 | data_path = get_first_available_dir(IN1K_DIRS) 36 | data_path = os.path.join(data_path, split) 37 | 38 | elif dataset_name == "in19k": 39 | dataset_class = ImageNetSubset 40 | data_path = IN19K_PKL_PATH 41 | 42 | else: 43 | raise ValueError("Unknown ImageNet dataset: {}".format(dataset_name)) 44 | 45 | assert os.path.exists(data_path), data_path 46 | dataset = dataset_class(dataset_name, data_path, transform=transform) 47 | 48 | return dataset 49 | 50 | 51 | AVAILABLE_DATASETS = { 52 | "in1k": {"train": 1_281_167, "val": 50_000, "getter": get_imagenet}, 53 | "in19k": { 54 | "train": 13_153_480, 55 | "val": 0, 56 | "getter": get_imagenet, 57 | }, # 20 problematic images removed 58 | } 59 | 60 | ################################################## 61 | 62 | 63 | class ImageNetSubset(VisionDataset): 64 | def __init__(self, dataset_name, subset_path, transform=None, imroot=""): 65 | self.dataset_name = dataset_name 66 | self.samples = pickle.load(open(subset_path, "rb")) 67 | self.loader = my_pil_loader 68 | self.transform = transform 69 | self.imroot = imroot 70 | 71 | # for compatibility with VisionDataset.__repr__ 72 | self.root = imroot 73 | self.transforms = transform 74 | 75 | def __len__(self): 76 | return len(self.samples) 77 | 78 | def __getitem__(self, index): 79 | image_path, target = self.samples[index] 80 | image_path = self.imroot + image_path 81 | 82 | try: 83 | image = self.loader(image_path) 84 | except Exception as e: 85 | logger.error("ERROR while loading image {}".format(image_path)) 86 | logger.error("{}".format(e)) 87 | add_str_to_jsonfile(PROBLEMATIC_IMAGES_LOG_FILE, image_path) 88 | return None 89 | 90 | image = self.transform(image) if self.transform else image 91 | 92 | return image, target, self.dataset_name 93 | -------------------------------------------------------------------------------- /teachers/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from collections import OrderedDict 4 | from typing import List, Dict 5 | 6 | import torch 7 | 8 | from .config import TEACHER_CFG 9 | 10 | 11 | logger = logging.getLogger() 12 | 13 | 14 | def build_teachers(teacher_names: List[str]) -> Dict[str, torch.nn.Module]: 15 | teachers = OrderedDict() 16 | 17 | for tname in teacher_names: 18 | logger.info("Loading teacher '{}'".format(tname)) 19 | teachers[tname] = _build_teacher(tname) 20 | 21 | return teachers 22 | 23 | 24 | def _build_teacher(name): 25 | if name not in TEACHER_CFG.keys(): 26 | raise ValueError( 27 | "Unsupported teacher name: {} (supported ones: {})".format( 28 | name, TEACHER_CFG.keys() 29 | ) 30 | ) 31 | 32 | ckpt_path = TEACHER_CFG[name]["ckpt_path"] 33 | ckpt_key = TEACHER_CFG[name]["ckpt_key"] 34 | 35 | if not os.path.exists(ckpt_path): 36 | raise ValueError("Invalid teacher model path/directory: {}".format(ckpt_path)) 37 | 38 | if name.startswith("mast3r"): 39 | code_dir = TEACHER_CFG[name]["code_dir"] 40 | model = TEACHER_CFG[name]["loader"](code_dir, ckpt_path) 41 | 42 | else: 43 | # Teacher models which are loaded from the checkpoint files 44 | state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False) 45 | if ckpt_key != "" and ckpt_key in state_dict.keys(): 46 | state_dict = state_dict[ckpt_key] 47 | 48 | # dinov2 models require some modifications to the state_dict 49 | state_dict = _update_state_dict_for_dinov2_models(name, state_dict) 50 | 51 | 
model_args = { 52 | "img_size": TEACHER_CFG[name]["image_size"], 53 | "patch_size": TEACHER_CFG[name]["patch_size"], 54 | } 55 | for key in ["init_values", "num_register_tokens"]: 56 | if key in TEACHER_CFG[name]: 57 | model_args[key] = TEACHER_CFG[name][key] 58 | model = TEACHER_CFG[name]["loader"](**model_args) 59 | model.load_state_dict(state_dict, strict=True) 60 | 61 | model = model.cuda() 62 | model = model.eval() 63 | for param in model.parameters(): 64 | param.requires_grad = False 65 | 66 | return model 67 | 68 | 69 | def _update_state_dict_for_dinov2_models(tname, state_dict): 70 | 71 | if tname.startswith("multihmr"): 72 | state_dict = { 73 | k.replace("backbone.encoder.", ""): v 74 | for k, v in state_dict.items() 75 | if k.startswith("backbone.encoder.") 76 | } 77 | 78 | # Add the "blocks.0" prefix to the transformer block keys 79 | state_dict = {k.replace("blocks.", "blocks.0."): v for k, v in state_dict.items()} 80 | 81 | return state_dict 82 | 83 | 84 | def _test_teachers(): 85 | """ 86 | Load all teachers and test if they can be loaded successfully. 87 | """ 88 | 89 | logging.basicConfig(level=logging.INFO) 90 | 91 | for tname in TEACHER_CFG.keys(): 92 | logger.info("Testing teacher '{}'".format(tname)) 93 | _ = _build_teacher(tname) 94 | logger.info(" - Teacher '{}' loaded successfully".format(tname)) 95 | 96 | 97 | if __name__ == "__main__": 98 | _test_teachers() 99 | -------------------------------------------------------------------------------- /model/losses.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from .teacher_dropping import TeacherDropping 7 | 8 | 9 | def unic_loss( 10 | student_output: Dict[str, Dict[str, torch.Tensor]], 11 | teacher_output: Dict[str, Dict[str, torch.Tensor]], 12 | dset_name: List[str], 13 | tdrop: TeacherDropping = TeacherDropping(method="none", p=0.0), 14 | lam_lcos: float = 0.5, 15 | lam_lsl1: float = 0.5, 16 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 17 | 18 | metric_dict = {} 19 | 20 | loss_pt = loss_per_teacher( 21 | student_output, 22 | teacher_output, 23 | lam_lcos, 24 | lam_lsl1, 25 | metric_dict=metric_dict, 26 | ) 27 | 28 | loss, tcoeffs = tdrop(loss_pt, dset_name) 29 | 30 | for tname in teacher_output.keys(): 31 | metric_dict["t_coeff_{}".format(tname)] = tcoeffs[tname] 32 | 33 | metric_dict["loss"] = loss.item() 34 | 35 | return loss, metric_dict 36 | 37 | 38 | def loss_per_teacher( 39 | student_output: Dict[str, Dict[str, torch.Tensor]], 40 | teacher_output: Dict[str, Dict[str, torch.Tensor]], 41 | lam_cos: float, 42 | lam_sl1: float, 43 | metric_dict={}, 44 | ) -> Dict[str, torch.Tensor]: 45 | loss_pt = {} 46 | 47 | for tname in teacher_output.keys(): 48 | 49 | tout_dict = teacher_output[tname] 50 | sout_dict = student_output[tname] 51 | 52 | losses = [] 53 | 54 | for ttype in tout_dict.keys(): 55 | tout = tout_dict[ttype] 56 | sout = sout_dict[ttype] 57 | 58 | with torch.autocast(device_type=sout.device.type, dtype=torch.float32): 59 | loss_cos = cosine_loss(sout, tout, avg=False) 60 | loss_sl1 = smooth_l1_loss(sout, tout, avg=False) 61 | 62 | loss = lam_cos * loss_cos + lam_sl1 * loss_sl1 63 | 64 | # for patch tokens 65 | if len(loss.shape) == 2: 66 | loss = loss.mean(dim=1) 67 | 68 | losses.append(loss) 69 | 70 | # fmt:off 71 | metric_dict.update( 72 | { 73 | "loss/dist_{}_cos_{}".format(ttype, tname): loss_cos.mean().item(), 74 | "loss/dist_{}_sl1_{}".format(ttype, 
tname): loss_sl1.mean().item(), 75 | "loss/dist_{}_{}".format(ttype, tname): loss.mean().item(), 76 | } 77 | ) 78 | # fmt:on 79 | 80 | losses = torch.stack(losses, dim=1).mean(dim=1) 81 | loss_pt[tname] = losses 82 | metric_dict.update( 83 | { 84 | "loss/dist_{}".format(tname): losses.mean().item(), 85 | } 86 | ) 87 | 88 | return loss_pt 89 | 90 | 91 | def cosine_loss(pred, target, avg=False): 92 | sim = F.cosine_similarity(pred, target, dim=-1) 93 | loss = 1 - sim 94 | 95 | if avg: 96 | loss = loss.mean() 97 | 98 | return loss 99 | 100 | 101 | def smooth_l1_loss(pred, target, beta=1.0, avg=False): 102 | loss = F.smooth_l1_loss(pred, target, reduction="none", beta=beta).mean(dim=-1) 103 | 104 | if avg: 105 | loss = loss.mean() 106 | 107 | return loss 108 | -------------------------------------------------------------------------------- /model/common/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | from typing import Callable, Optional, Tuple, Union 11 | 12 | import torch.nn as nn 13 | from torch import Tensor 14 | 15 | 16 | def make_2tuple(x): 17 | if isinstance(x, tuple): 18 | assert len(x) == 2 19 | return x 20 | 21 | assert isinstance(x, int) 22 | return (x, x) 23 | 24 | 25 | class PatchEmbed(nn.Module): 26 | """ 27 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 28 | 29 | Args: 30 | img_size: Image size. 31 | patch_size: Patch token size. 32 | in_chans: Number of input image channels. 33 | embed_dim: Number of linear projection output channels. 34 | norm_layer: Normalization layer. 
35 | """ 36 | 37 | def __init__( 38 | self, 39 | img_size: Union[int, Tuple[int, int]] = 224, 40 | patch_size: Union[int, Tuple[int, int]] = 16, 41 | in_chans: int = 3, 42 | embed_dim: int = 768, 43 | norm_layer: Optional[Callable] = None, 44 | flatten_embedding: bool = True, 45 | ) -> None: 46 | super().__init__() 47 | 48 | image_HW = make_2tuple(img_size) 49 | patch_HW = make_2tuple(patch_size) 50 | patch_grid_size = ( 51 | image_HW[0] // patch_HW[0], 52 | image_HW[1] // patch_HW[1], 53 | ) 54 | 55 | self.img_size = image_HW 56 | self.patch_size = patch_HW 57 | self.patches_resolution = patch_grid_size 58 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 59 | 60 | self.in_chans = in_chans 61 | self.embed_dim = embed_dim 62 | 63 | self.flatten_embedding = flatten_embedding 64 | 65 | self.proj = nn.Conv2d( 66 | in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW 67 | ) 68 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 69 | 70 | def forward(self, x: Tensor) -> Tensor: 71 | _, _, H, W = x.shape 72 | patch_H, patch_W = self.patch_size 73 | 74 | assert ( 75 | H % patch_H == 0 76 | ), f"Input image height {H} is not a multiple of patch height {patch_H}" 77 | assert ( 78 | W % patch_W == 0 79 | ), f"Input image width {W} is not a multiple of patch width: {patch_W}" 80 | 81 | x = self.proj(x) # B C H W 82 | H, W = x.size(2), x.size(3) 83 | x = x.flatten(2).transpose(1, 2) # B HW C 84 | x = self.norm(x) 85 | if not self.flatten_embedding: 86 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 87 | return x 88 | 89 | def flops(self) -> float: 90 | Ho, Wo = self.patches_resolution 91 | flops = ( 92 | Ho 93 | * Wo 94 | * self.embed_dim 95 | * self.in_chans 96 | * (self.patch_size[0] * self.patch_size[1]) 97 | ) 98 | if self.norm is not None: 99 | flops += Ho * Wo * self.embed_dim 100 | return flops 101 | -------------------------------------------------------------------------------- /teachers/dinov2/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | from typing import Callable, Optional, Tuple, Union 11 | 12 | import torch.nn as nn 13 | from torch import Tensor 14 | 15 | 16 | def make_2tuple(x): 17 | if isinstance(x, tuple): 18 | assert len(x) == 2 19 | return x 20 | 21 | assert isinstance(x, int) 22 | return (x, x) 23 | 24 | 25 | class PatchEmbed(nn.Module): 26 | """ 27 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 28 | 29 | Args: 30 | img_size: Image size. 31 | patch_size: Patch token size. 32 | in_chans: Number of input image channels. 33 | embed_dim: Number of linear projection output channels. 34 | norm_layer: Normalization layer. 
35 | """ 36 | 37 | def __init__( 38 | self, 39 | img_size: Union[int, Tuple[int, int]] = 224, 40 | patch_size: Union[int, Tuple[int, int]] = 16, 41 | in_chans: int = 3, 42 | embed_dim: int = 768, 43 | norm_layer: Optional[Callable] = None, 44 | flatten_embedding: bool = True, 45 | ) -> None: 46 | super().__init__() 47 | 48 | image_HW = make_2tuple(img_size) 49 | patch_HW = make_2tuple(patch_size) 50 | patch_grid_size = ( 51 | image_HW[0] // patch_HW[0], 52 | image_HW[1] // patch_HW[1], 53 | ) 54 | 55 | self.img_size = image_HW 56 | self.patch_size = patch_HW 57 | self.patches_resolution = patch_grid_size 58 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 59 | 60 | self.in_chans = in_chans 61 | self.embed_dim = embed_dim 62 | 63 | self.flatten_embedding = flatten_embedding 64 | 65 | self.proj = nn.Conv2d( 66 | in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW 67 | ) 68 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 69 | 70 | def forward(self, x: Tensor) -> Tensor: 71 | _, _, H, W = x.shape 72 | patch_H, patch_W = self.patch_size 73 | 74 | assert ( 75 | H % patch_H == 0 76 | ), f"Input image height {H} is not a multiple of patch height {patch_H}" 77 | assert ( 78 | W % patch_W == 0 79 | ), f"Input image width {W} is not a multiple of patch width: {patch_W}" 80 | 81 | x = self.proj(x) # B C H W 82 | H, W = x.size(2), x.size(3) 83 | x = x.flatten(2).transpose(1, 2) # B HW C 84 | x = self.norm(x) 85 | if not self.flatten_embedding: 86 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 87 | return x 88 | 89 | def flops(self) -> float: 90 | Ho, Wo = self.patches_resolution 91 | flops = ( 92 | Ho 93 | * Wo 94 | * self.embed_dim 95 | * self.in_chans 96 | * (self.patch_size[0] * self.patch_size[1]) 97 | ) 98 | if self.norm is not None: 99 | flops += Ho * Wo * self.embed_dim 100 | return flops 101 | -------------------------------------------------------------------------------- /model/teacher_norm.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.distributed as dist 6 | 7 | from utils import distributed 8 | 9 | 10 | class TeacherNorm(nn.Module): 11 | def __init__( 12 | self, token_types: list[str], dim, ema_momentum: float = 0.1, eps: float = 1e-3 13 | ): 14 | super().__init__() 15 | normalizers = {} 16 | for ttype in token_types: 17 | assert ttype in ["cls", "patch"], f"Invalid token type: {ttype}" 18 | agg_dims = [0] if ttype == "cls" else [0, 1] 19 | normalizers[ttype] = StandardNormalizer( 20 | dim, agg_dims, ema_momentum=ema_momentum, eps=eps 21 | ) 22 | self.normalizers = nn.ModuleDict(normalizers) 23 | 24 | def forward(self, x, ttype: str, ema_momentum: Optional[float] = None): 25 | return self.normalizers[ttype](x, ema_momentum) 26 | 27 | 28 | class StandardNormalizer(nn.Module): 29 | def __init__(self, dim, agg_dims, ema_momentum: float = 0.1, eps: float = 1e-3): 30 | super().__init__() 31 | self.agg_dims = agg_dims # which dimensions to aggregate over 32 | self.ema_momentum = ema_momentum 33 | self.ema_momentum_last = 0.0 # set automatically, just for logging 34 | self.eps = eps 35 | 36 | self.register_buffer("mean", torch.zeros(dim)) 37 | self.register_buffer("std", torch.ones(dim)) 38 | 39 | def extra_repr(self) -> str: 40 | repr_str = "eps={}, ema_momentum={:.3f}, ema_momentum_last={:0.3f},\n\tmean={},\n\tstd={}".format( 41 | self.eps, 42 | self.ema_momentum, 43 | self.ema_momentum_last, 44 | self.mean.data, 45 | 
self.std.data, 46 | ) 47 | return repr_str 48 | 49 | def forward(self, x, ema_momentum: Optional[float] = None): 50 | assert ( 51 | len(x.shape) == self.agg_dims[-1] + 2 52 | ), "Data is not compatible with aggregation dims" 53 | 54 | if ema_momentum is None: 55 | ema_momentum = self.ema_momentum 56 | 57 | self.ema_momentum_last = ema_momentum 58 | 59 | if not self.training or ema_momentum == 0: 60 | # At inference time, or when not updating the statistics 61 | 62 | with torch.autocast(device_type=x.device.type, dtype=torch.float32): 63 | x = (x - self.mean) / torch.clamp(self.std, min=self.eps) 64 | 65 | else: 66 | # During training, update the statistics using EMA. 67 | 68 | # Gather data across all GPUs. 69 | x_all = concat_all_gather(x.contiguous()) 70 | 71 | with torch.autocast(device_type=x.device.type, dtype=torch.float32): 72 | mean = x_all.mean(dim=self.agg_dims, keepdim=False) 73 | std = x_all.std(dim=self.agg_dims, keepdim=False) 74 | x = (x - mean) / torch.clamp(std, min=self.eps) 75 | 76 | self.mean.copy_(self.mean * (1 - ema_momentum) + mean * ema_momentum) 77 | self.std.copy_(self.std * (1 - ema_momentum) + std * ema_momentum) 78 | 79 | return x 80 | 81 | 82 | @torch.no_grad() 83 | def concat_all_gather(tensor): 84 | """ 85 | Performs all_gather operation on the provided tensors. 86 | *** Warning ***: torch.distributed.all_gather has no gradient. 87 | """ 88 | if not distributed.is_enabled(): 89 | return tensor 90 | 91 | tensors_gather = [ 92 | torch.ones_like(tensor) for _ in range(distributed.get_global_size()) 93 | ] 94 | dist.all_gather(tensors_gather, tensor, async_op=False) 95 | 96 | output = torch.cat(tensors_gather, dim=0) 97 | return output 98 | -------------------------------------------------------------------------------- /ACKNOWLEDGEMENTS - NLE DUNE.txt: -------------------------------------------------------------------------------- 1 | Acknowledgements: 2 | 3 | Portions of the Materials may utilize the following subcomponents or dependencies with separate copyright notices and license terms, the use of which is hereby acknowledged: 4 | 5 | --------------------------------------------------------- 6 | PART 1: SEE NOTICES BELOW WITH RESPECT TO SOFTWARE FILES: 7 | --------------------------------------------------------- 8 | 9 | The following training software was made available under the following terms, which as applicable apply to the applicable training software (i.e., any use of this software is subject, as applicable, to any additional terms and conditions of the corresponding software license in addition to the terms and conditions of the License to the Materials): 10 | Name Link 11 | A. dinov2 [https://github.com/facebookresearch/dinov2] 12 | B. dino [https://github.com/facebookresearch/dino] 13 | C. moco [https://github.com/facebookresearch/moco] 14 | D. PyTorch/pytorch [https://github.com/pytorch/pytorch] 15 | E. PyTorch/examples [https://github.com/pytorch/examples] 16 | 17 | ------------------------------------------------- 18 | PART 2: SEE NOTICES BELOW WITH RESPECT TO MODELS: 19 | ------------------------------------------------- 20 | 21 | The following models were derrived from the following models were made available under the following terms, which as applicable apply to the applicable model (i.e., any use of models derived from such models are subject, as applicable, to any additional terms and conditions of the corresponding model license in addition to the terms and conditions of the License to the Materials): 22 | 23 | A. 
Semantic segmentation decoder is derrived from DINOv2 [https://github.com/facebookresearch/dinov2] 24 | B. Depth estimation decoder is derrived from DINOv2 [https://github.com/facebookresearch/dinov2] 25 | 26 | --------------------------------------------------- 27 | PART 3: SEE NOTICES BELOW WITH RESPECT TO DATASETS: 28 | --------------------------------------------------- 29 | 30 | The following datasets, which are not being distributed herewith, were made available under the following terms, which as applicable apply to the use of checkpoints trained using such datasets (i.e., any use of checkpoints derived from such datasets are subject, as applicable, to any additional terms and conditions of the corresponding dataset license in addition to the terms and conditions of the License to the Materials): 31 | 32 | A. ImageNet-19K [https://image-net.org/download.php] 33 | B. Mapillary [https://www.mapillary.com/dataset/assets/mapillary-object-dataset-research-use-license-2019.pdf] 34 | C. Google Landmarks v2 [https://github.com/cvdfoundation/google-landmark] 35 | D. Habitat [https://matterport.com/legal/matterport-end-user-license-agreement-academic-use-model-data] 36 | E. ARKitScenes [https://github.com/apple/ARKitScenes] 37 | F. Blended MVS [https://github.com/YoYo000/BlendedMVS] 38 | G. MegaDepth [https://github.com/zhengqili/MegaDepth/blob/master/LICENSE] 39 | H. ScanNet++ [http://www.scan-net.org/] 40 | I. CO3D-v2 [https://ai.meta.com/datasets/co3d-downloads/] 41 | J. Map-free [https://research.nianticlabs.com/mapfree-reloc-benchmark/dataset] 42 | K. WildRgb [https://github.com/wildrgbd/wildrgbd] 43 | L. VirtualKitti [https://europe.naverlabs.com/proxy-virtual-worlds-vkitti-2/] 44 | M. Unreal4K [https://github.com/CVLAB-Unibo/neural-disparity-refinement/blob/main/LICENSE] 45 | N. TartanAir [https://theairlab.org/tartanair-dataset/] 46 | O. DL3DV [https://github.com/DL3DV-10K/Dataset] 47 | P. BEDLAM [https://bedlam.is.tue.mpg.de/] 48 | Q. AGORA [https://agora.is.tue.mpg.de/] 49 | R. CUFFS [https://download.europe.naverlabs.com/ComputerVision/MultiHMR/CUFFS/] 50 | S. UBody [https://osx-ubody.github.io/] 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /Project NLE DUNE LICENSE.txt: -------------------------------------------------------------------------------- 1 | PROJECT NAME: NLE DUNE, Copyright (C) 2025 NAVER Corporation. All Rights Reserved. 2 | 3 | You must agree to the terms of this license in order to install and use the “Materials” associated with this project, which may include source code, executable code, models, model checkpoints, and data, together with any documentation and any updates provided at Naver’s discretion. By exercising any rights to the Materials, you accept and agree to be bound by the terms of this license. If you are entering into this license on behalf of a company or other entity, you represent that you are the employee or agent of such company (or other entity) and you have the authority to enter into this license on behalf of such company (or other entity). The Materials are protected by copyright and other intellectual property laws and is licensed, not sold. 4 | 5 | Non-Commercial License 6 | 7 | Subject to any LICENSE EXCEPTIONS, NAVER Corporation (“NAVER”) hereby grants you a non-exclusive, non-sublicensable, non-transferable license to use the Materials, subject to the following conditions: 8 | 9 | (1) SCOPE OF USE: The Materials are used solely for non-commercial purposes (“Purpose”). 
You may not use the Materials or derivatives thereof for any commercial purpose (i.e., primarily intended for or directed towards commercial advantage or monetary compensation). You may not distribute the Materials or derivatives thereof under different terms and conditions as this License. 10 | 11 | (2) COPYRIGHT: The above copyright notice and this License along with the disclaimer below shall be retained in all copies and derivatives. 12 | 13 | (3) TERM: The License automatically terminates without notice if you fail to comply with its terms or the Purpose no longer exists. You may terminate this License at any time by ceasing to use the Materials. Upon termination you agree to delete any and all copies of the Materials and derivatives. The license to any of your Contributions under (4) will survive termination. 14 | 15 | (4) CONTRIBUTIONS: If you contribute to the project by providing feedback (“Contributions”) by, for example, making comments or a pull request, you agree to grant, and hereby grant, NAVER, without any restrictions or limitations, a non-exclusive, perpetual, irrevocable, royalty-free, paid-up, assignable and sub-licensable license, to reproduce, publicly perform or display, install, use, modify, adapt, prepare derivative works of, post, distribute, make and have made, sell and transfer your Contributions, and derivative works thereof, for any purpose. This grant by you does not change your rights to use your Contributions. Your Contributions may be used to update the Materials at Naver’s discretion. 16 | 17 | (5) NO IMPLIED LICENSE: Except as otherwise expressly stated in this License, nothing herein shall be construed to grant you any license, by implication, estoppel, or otherwise, to any intellectual property of NAVER, including trademarks, copyrights, patents, or trade secrets. 18 | 19 | (6) LIMITATION OF LIABILITY: THE MATERIALS ARE PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL NAVER BE LIABLE FOR ANY CLAIM, DAMAGES (INCLUDING, BUT NOT LIMITED TO LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 20 | 21 | LICENSE EXCEPTIONS: If the Materials include subcomponents or dependencies with separate copyright notices or license terms, they will be set forth in the file ACKNOWLEDGEMENTS.txt. 22 | 23 | 24 | -------------------------------------------------------------------------------- /teachers/dinov2/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import logging 11 | import os 12 | import warnings 13 | from math import sqrt 14 | 15 | import torch.nn.functional as F 16 | from torch import Tensor, nn, softmax 17 | 18 | 19 | logger = logging.getLogger("dinov2") 20 | 21 | 22 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 23 | try: 24 | if XFORMERS_ENABLED: 25 | from xformers.ops import memory_efficient_attention, unbind 26 | 27 | XFORMERS_AVAILABLE = True 28 | warnings.warn("xFormers is available (Attention)") 29 | else: 30 | warnings.warn("xFormers is disabled (Attention)") 31 | raise ImportError 32 | except ImportError: 33 | XFORMERS_AVAILABLE = False 34 | warnings.warn("xFormers is not available (Attention)") 35 | 36 | 37 | class Attention(nn.Module): 38 | def __init__( 39 | self, 40 | dim: int, 41 | num_heads: int = 8, 42 | qkv_bias: bool = False, 43 | proj_bias: bool = True, 44 | attn_drop: float = 0.0, 45 | proj_drop: float = 0.0, 46 | qk_norm: bool = False, 47 | norm_layer: nn.Module = nn.LayerNorm, 48 | ) -> None: 49 | super().__init__() 50 | self.num_heads = num_heads 51 | head_dim = dim // num_heads 52 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 53 | self.attn_drop = attn_drop 54 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 55 | self.proj_drop = nn.Dropout(proj_drop) 56 | self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity() 57 | self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity() 58 | 59 | def forward(self, x: Tensor, return_attention=False) -> Tensor: 60 | B, N, C = x.shape 61 | qkv = ( 62 | self.qkv(x) 63 | .reshape(B, N, 3, self.num_heads, C // self.num_heads) 64 | .permute(2, 0, 3, 1, 4) 65 | ) 66 | 67 | q, k, v = qkv[0], qkv[1], qkv[2] 68 | q, k = self.q_norm(q), self.k_norm(k) 69 | 70 | if return_attention: 71 | scale_factor = 1 / sqrt(q.size(-1)) 72 | attn_weight = q @ k.transpose(-2, -1) * scale_factor 73 | return softmax(attn_weight, dim=-1) 74 | 75 | x = ( 76 | F.scaled_dot_product_attention( 77 | q, k, v, attn_mask=None, dropout_p=self.attn_drop, is_causal=False 78 | ) 79 | .transpose(1, 2) 80 | .reshape(B, N, C) 81 | ) 82 | 83 | x = self.proj(x) 84 | x = self.proj_drop(x) 85 | return x 86 | 87 | 88 | class MemEffAttention(Attention): 89 | def forward(self, x: Tensor, attn_bias=None, return_attention=False) -> Tensor: 90 | if not XFORMERS_AVAILABLE: 91 | if attn_bias is not None: 92 | raise AssertionError("xFormers is required for using nested tensors") 93 | return super().forward(x, return_attention=return_attention) 94 | 95 | B, N, C = x.shape 96 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 97 | 98 | q, k, v = unbind(qkv, 2) 99 | q, k = self.q_norm(q), self.k_norm(k) 100 | 101 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 102 | x = x.reshape([B, N, C]) 103 | 104 | x = self.proj(x) 105 | x = self.proj_drop(x) 106 | return x 107 | -------------------------------------------------------------------------------- /data/multihmr.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | 5 | from .dataset import ( 6 | EmptyDataset, 7 | ImageFolderV2, 8 | ImageListV2, 9 | ImageOneFolderV2, 10 | my_pil_loader, 11 | ) 12 | from .paths import HMR_DATASET_PATHS 13 | 14 | 15 | logger = logging.getLogger() 16 | 17 | 18 | 
################################################## 19 | # Dataset getters 20 | 21 | 22 | def get_bedlam(dataset_name, split, transform, **kwargs): 23 | assert split in ["train", "val"], "split must be 'train' or 'val', got: {split}" 24 | 25 | split_dir = {"train": "training", "val": "validation"}[split] 26 | dataset_path = HMR_DATASET_PATHS[dataset_name] 27 | data_dir = os.path.join(dataset_path, split_dir) 28 | assert os.path.isdir(data_dir), data_dir 29 | 30 | def target_transform(t): 31 | return -1 32 | 33 | def is_valid_file(x): 34 | return x.endswith(".png") and not os.path.basename(x).startswith(".") 35 | 36 | def my_pil_loader_with_bedlamfix_rotation(path): 37 | img = my_pil_loader(path) 38 | if "closeup" in path: 39 | img = img.rotate(-90) 40 | return img 41 | 42 | dataset = ImageFolderV2( 43 | dataset_name, 44 | data_dir, 45 | transform=transform, 46 | target_transform=target_transform, 47 | is_valid_file=is_valid_file, 48 | loader=my_pil_loader_with_bedlamfix_rotation, 49 | ) 50 | 51 | return dataset 52 | 53 | 54 | def get_agora(dataset_name, split, transform, **kwargs): 55 | assert split in ["train", "val"], "split must be 'train' or 'val', got: {split}" 56 | 57 | dataset_path = HMR_DATASET_PATHS[dataset_name] 58 | data_dir = { 59 | "train": os.path.join(dataset_path, "train"), 60 | "val": os.path.join(dataset_path, "validation"), 61 | }[split] 62 | 63 | def is_valid_file(x): 64 | return not os.path.basename(x).startswith(".") 65 | 66 | dataset = ImageOneFolderV2( 67 | dataset_name, data_dir, transform=transform, is_valid_file=is_valid_file 68 | ) 69 | 70 | return dataset 71 | 72 | 73 | def get_cuffs(dataset_name, split, transform, **kwargs): 74 | assert split in ["train", "val"], "split must be 'train' or 'val', got: {split}" 75 | 76 | if split == "val": 77 | logging.info(f"Dataset {dataset_name} does not have a {split} split") 78 | return EmptyDataset(dataset_name) 79 | 80 | dataset_path = HMR_DATASET_PATHS[dataset_name] 81 | dataset = ImageOneFolderV2(dataset_name, dataset_path, transform=transform) 82 | 83 | return dataset 84 | 85 | 86 | def get_ubody(dataset_name, split, transform, **kwargs): 87 | assert split in ["train", "val"], "split must be 'train' or 'val', got: {split}" 88 | 89 | dataset_path = HMR_DATASET_PATHS[dataset_name] 90 | imroot = os.path.join(dataset_path, "videos") 91 | pkl_fname = f"{HMR_DATASET_PATHS['ubody_pkl']}/ubody_intra_{'train' if split=='train' else 'test'}.pkl" 92 | 93 | with open(pkl_fname, "rb") as fid: 94 | annot = pickle.load(fid) 95 | 96 | imlist = sorted(list(annot.keys())) 97 | dataset = ImageListV2(dataset_name, imroot, imlist, transform=transform) 98 | 99 | return dataset 100 | 101 | 102 | AVAILABLE_DATASETS = { 103 | "bedlam": { 104 | "train": 353_116, 105 | "val": 31_381, 106 | "getter": get_bedlam, 107 | }, # "hidden" training images are not counted 108 | "agora": { 109 | "train": 14_314, 110 | "val": 1_225, 111 | "getter": get_agora, 112 | }, # "hidden" training images are not counted 113 | "cuffs": {"train": 54_944, "val": 0, "getter": get_cuffs}, 114 | "ubody": {"train": 54_234, "val": 2_016, "getter": get_ubody}, 115 | } 116 | 117 | ################################################## 118 | -------------------------------------------------------------------------------- /data/sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from PyTorch's DistributedSampler from here 3 | 
https://github.com/pytorch/pytorch/blob/c2637a7b2656d95712078532c2bc2dd72c4143ff/torch/utils/data/distributed.py 4 | """ 5 | 6 | import logging 7 | import math 8 | from typing import Iterator, Optional, TypeVar 9 | 10 | import torch 11 | import torch.distributed as dist 12 | from torch.utils.data.dataset import Dataset 13 | from torch.utils.data.sampler import Sampler 14 | 15 | 16 | _T_co = TypeVar("_T_co", covariant=True) 17 | logger = logging.getLogger() 18 | 19 | 20 | class InfiniteDistributedSampler(Sampler[_T_co]): 21 | def __init__( 22 | self, 23 | dataset: Dataset, 24 | num_replicas: Optional[int] = None, 25 | rank: Optional[int] = None, 26 | shuffle: bool = True, 27 | seed: int = 0, 28 | drop_last: bool = False, 29 | ) -> None: 30 | if num_replicas is None: 31 | if not dist.is_available(): 32 | raise RuntimeError("Requires distributed package to be available.") 33 | num_replicas = dist.get_world_size() 34 | if rank is None: 35 | if not dist.is_available(): 36 | raise RuntimeError("Requires distributed package to be available.") 37 | rank = dist.get_rank() 38 | if rank >= num_replicas or rank < 0: 39 | raise ValueError( 40 | f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]" 41 | ) 42 | self.dataset = dataset 43 | self.num_replicas = num_replicas 44 | self.rank = rank 45 | self.epoch = 0 # Epoch counter 46 | self.drop_last = drop_last 47 | self.shuffle = shuffle 48 | self.seed = seed 49 | 50 | # Calculate number of samples per replica 51 | dataset_length = len(self.dataset) 52 | if self.drop_last and dataset_length % self.num_replicas != 0: 53 | # Split to nearest length that is evenly divisible 54 | self.num_samples = math.floor(dataset_length / self.num_replicas) 55 | else: 56 | self.num_samples = math.ceil(dataset_length / self.num_replicas) 57 | self.total_size = self.num_samples * self.num_replicas 58 | 59 | def __iter__(self) -> Iterator[_T_co]: 60 | 61 | while True: 62 | if self.shuffle: 63 | # Deterministically shuffle based on epoch and seed 64 | g = torch.Generator() 65 | g.manual_seed(self.seed + self.epoch) 66 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 67 | else: 68 | indices = list(range(len(self.dataset))) 69 | 70 | if not self.drop_last: 71 | # Add extra samples to make it evenly divisible 72 | padding_size = self.total_size - len(indices) 73 | if padding_size <= len(indices): 74 | indices += indices[:padding_size] 75 | else: 76 | indices += (indices * math.ceil(padding_size / len(indices)))[ 77 | :padding_size 78 | ] 79 | else: 80 | # Remove tail of data to make it evenly divisible 81 | indices = indices[: self.total_size] 82 | assert len(indices) == self.total_size 83 | 84 | # Subsample 85 | indices = indices[self.rank : self.total_size : self.num_replicas] 86 | assert len(indices) == self.num_samples 87 | 88 | # Yield indices for the current epoch 89 | for idx in indices: 90 | yield idx 91 | 92 | self.epoch += 1 # Move to the next epoch 93 | 94 | logging.info( 95 | "{} epoch set to {} (rank: {})".format( 96 | self.__class__.__name__, self.epoch, self.rank 97 | ) 98 | ) 99 | 100 | def __len__(self) -> int: 101 | # Return the number of samples per epoch 102 | return self.num_samples 103 | 104 | def set_epoch(self, epoch: int) -> None: 105 | r""" 106 | Sets the epoch for this sampler. 107 | 108 | When :attr:`shuffle=True`, this ensures all replicas use a different random ordering 109 | for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering. 
110 | 111 | Args: 112 | epoch (int): Epoch number. 113 | """ 114 | self.epoch = epoch 115 | -------------------------------------------------------------------------------- /model/teacher_dropping.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List, Dict 3 | 4 | import torch 5 | 6 | from data import dataset_to_teacher 7 | 8 | 9 | class TeacherDropping: 10 | drop_methods = ["none", "lowest_loss", "own_data", "own+generic_data"] 11 | 12 | def __init__(self, method="lowest_loss", p=0.5): 13 | assert method in self.drop_methods, "Unknown method: {}".format(method) 14 | assert p == 0 or method == "lowest_loss" 15 | self.method = method 16 | self.p = p 17 | 18 | def __call__(self, loss_dict: Dict[str, torch.Tensor], dset_name: List[str]): 19 | """ 20 | Given a dictionary of losses, 21 | where keys are teacher names, and values are 2D tensors of shape (B, N) 22 | (B: batch size and N: number of tokens) 23 | this function aggregates the losses into a single loss tensor. 24 | 25 | Args: 26 | loss: (Dict[str, torch.Tensor] 27 | Loss incurred on each image from each teacher. 28 | dset_name: (List[str], of shape [B]) 29 | Dataset name for each image 30 | """ 31 | teachers = sorted(loss_dict.keys()) 32 | 33 | loss_tensor = torch.stack([loss_dict[key] for key in teachers]) 34 | B = loss_dict[teachers[0]].shape[0] 35 | assert loss_tensor.shape == (len(loss_dict), B) 36 | assert len(dset_name) == B 37 | 38 | if self.method == "none": 39 | # no drop, all teachers contribute to the loss 40 | coeffs = torch.ones_like(loss_tensor) 41 | 42 | elif self.method == "lowest_loss": 43 | # drop teachers with lowest loss, 44 | # i.e., teacher drop as in UNIC 45 | coeffs = torch.stack( 46 | [ 47 | _get_teacher_coefficients_by_loss( 48 | loss_tensor[:, lix], drop_prob=self.p 49 | ) 50 | for lix in range(loss_tensor.shape[1]) 51 | ] 52 | ).t() 53 | 54 | elif self.method in ["own_data", "own+generic_data"]: 55 | # drop teachers that do not match the dataset name 56 | coeffs = torch.zeros_like(loss_tensor) 57 | for si, dname in enumerate(dset_name): 58 | tname_si = dataset_to_teacher(dname) 59 | for ti, tname in enumerate(teachers): 60 | if tname.startswith(tname_si) or ( 61 | self.method == "own+generic_data" 62 | and tname_si.startswith("dino2") 63 | ): 64 | coeffs[ti, si] = 1.0 65 | 66 | else: 67 | raise NotImplementedError( 68 | "{}(method={}) is not implemented".format( 69 | self.__class__.__name__, self.method 70 | ) 71 | ) 72 | 73 | assert coeffs.shape == loss_tensor.shape 74 | 75 | # make sure each image is assigned to at least one teacher 76 | assert torch.all( 77 | coeffs.sum(dim=0) >= 1 78 | ), "{} images in the batch are not assigned to any of the teachers".format( 79 | (coeffs.sum(dim=0) == 0).int().sum() 80 | ) 81 | 82 | ##### 83 | # normalize coefficients such that 84 | # each image contributes to the loss with equal weight 85 | coeffs.div_(coeffs.sum()) 86 | loss = (coeffs.clone().detach() * loss_tensor).sum() 87 | 88 | # sum the coefficients for each teacher 89 | # for logging purposes 90 | coeff_dict = {key: coeff for key, coeff in zip(teachers, coeffs.sum(dim=1))} 91 | 92 | return loss, coeff_dict 93 | 94 | 95 | @torch.no_grad() 96 | def _get_teacher_coefficients_by_loss(losses, drop_prob=0.5): 97 | """ 98 | Given a list of losses from all teachers, return a list for their loss coefficients. 99 | Initially, all coefficients are 1. 
100 | Then we flip coefficients for teachers with lowest loss to zeros with a probability drop_prob. 101 | """ 102 | if isinstance(losses, (list, tuple)): 103 | losses = torch.stack(losses) 104 | 105 | # make sure that losses are 1D 106 | assert len(losses.shape) == 1 107 | 108 | coeffs = torch.ones_like(losses, requires_grad=False) 109 | 110 | # find the teacher with the highest loss 111 | max_loss_idx = torch.argmax(losses) 112 | 113 | # go through other teachers and 114 | # flip their coefficients to zeros with a probability drop_prob 115 | for i in range(len(losses)): 116 | if i != max_loss_idx: 117 | p = random.random() 118 | if p < drop_prob: 119 | coeffs[i] = 0 120 | 121 | return coeffs 122 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import deque 3 | from enum import Enum 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | import utils.distributed as distributed 9 | 10 | 11 | logger = logging.getLogger() 12 | 13 | 14 | def accuracy(output, target, topk=(1,)): 15 | """Computes the accuracy over the k top predictions for the specified values of k""" 16 | maxk = max(topk) 17 | batch_size = target.size(0) 18 | _, pred = output.topk(maxk, 1, True, True) 19 | pred = pred.t() 20 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 21 | return [correct[:k].reshape(-1).float().sum(0) * 100.0 / batch_size for k in topk] 22 | 23 | 24 | class Summary(Enum): 25 | NONE = 0 26 | AVERAGE = 1 27 | SUM = 2 28 | COUNT = 3 29 | 30 | 31 | class AverageMeter(object): 32 | """Computes and stores the average and current value""" 33 | 34 | def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE): 35 | self.name = name 36 | self.fmt = fmt 37 | self.summary_type = summary_type 38 | self.reset() 39 | 40 | def reset(self): 41 | self.val = 0 42 | self.avg = 0 43 | self.sum = 0 44 | self.count = 0 45 | 46 | def update(self, val, n=1): 47 | self.val = val 48 | self.sum += val * n 49 | self.count += n 50 | self.avg = self.sum / self.count 51 | 52 | def all_reduce(self): 53 | if torch.cuda.is_available(): 54 | device = torch.device("cuda") 55 | elif torch.backends.mps.is_available(): 56 | device = torch.device("mps") 57 | else: 58 | device = torch.device("cpu") 59 | total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device) 60 | dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False) 61 | self.sum, self.count = total.tolist() 62 | self.avg = self.sum / self.count 63 | 64 | def __str__(self): 65 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 66 | return fmtstr.format(**self.__dict__) 67 | 68 | def summary(self): 69 | fmtstr = "" 70 | if self.summary_type is Summary.NONE: 71 | fmtstr = "" 72 | elif self.summary_type is Summary.AVERAGE: 73 | fmtstr = "{name} {avg:.3f}" 74 | elif self.summary_type is Summary.SUM: 75 | fmtstr = "{name} {sum:.3f}" 76 | elif self.summary_type is Summary.COUNT: 77 | fmtstr = "{name} {count:.3f}" 78 | else: 79 | raise ValueError("invalid summary type %r" % self.summary_type) 80 | 81 | return fmtstr.format(**self.__dict__) 82 | 83 | 84 | class ProgressMeter(object): 85 | def __init__(self, num_batches, meters, prefix=""): 86 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 87 | self.meters = meters 88 | self.prefix = prefix 89 | 90 | def display(self, batch): 91 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 92 | entries += 
[str(meter) for meter in self.meters] 93 | logger.info("\t".join(entries)) 94 | 95 | def display_summary(self): 96 | entries = [" *"] 97 | entries += [meter.summary() for meter in self.meters] 98 | logger.info(" ".join(entries)) 99 | 100 | def _get_batch_fmtstr(self, num_batches): 101 | num_digits = len(str(num_batches // 1)) 102 | fmt = "{:" + str(num_digits) + "d}" 103 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 104 | 105 | 106 | class SmoothedValue: 107 | """Track a series of values and provide access to smoothed values over a 108 | window or the global series average. 109 | """ 110 | 111 | def __init__(self, window_size=20, fmt=None): 112 | if fmt is None: 113 | fmt = "{median:.4f} ({global_avg:.4f})" 114 | self.fmt = fmt 115 | self.deque = deque(maxlen=window_size) 116 | self.total = 0.0 117 | self.count = 0 118 | 119 | def reset(self): 120 | self.deque.clear() 121 | self.count = 0 122 | self.total = 0.0 123 | 124 | def update(self, value, num=1): 125 | self.deque.append(value) 126 | self.count += num 127 | self.total += value * num 128 | 129 | def synchronize_between_processes(self): 130 | """ 131 | Distributed synchronization of the metric 132 | Warning: does not synchronize the deque! 133 | """ 134 | if not distributed.is_enabled(): 135 | return 136 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 137 | dist.barrier() 138 | dist.all_reduce(t) 139 | t = t.tolist() 140 | self.count = int(t[0]) 141 | self.total = t[1] 142 | 143 | @property 144 | def median(self): 145 | d = torch.tensor(list(self.deque)) 146 | return d.median().item() 147 | 148 | @property 149 | def avg(self): 150 | d = torch.tensor(list(self.deque), dtype=torch.float32) 151 | return d.mean().item() 152 | 153 | @property 154 | def global_avg(self): 155 | return self.total / self.count 156 | 157 | @property 158 | def max(self): 159 | return max(self.deque) 160 | 161 | @property 162 | def value(self): 163 | return self.deque[-1] 164 | 165 | def __str__(self): 166 | return self.fmt.format( 167 | median=self.median, 168 | avg=self.avg, 169 | global_avg=self.global_avg, 170 | max=self.max, 171 | value=self.value, 172 | ) 173 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | from collections import deque 5 | from pathlib import Path 6 | from typing import Dict, List, Tuple, Union 7 | 8 | from PIL import Image 9 | from torch.utils.data import ConcatDataset, Dataset 10 | from torchvision.datasets import ImageFolder 11 | 12 | from .paths import PROBLEMATIC_IMAGES_LOG_FILE 13 | from .utils import add_str_to_jsonfile 14 | 15 | 16 | logger = logging.getLogger() 17 | 18 | 19 | def my_pil_loader(path: str) -> Image.Image: 20 | img = Image.open(path).convert("RGB") 21 | return img 22 | 23 | 24 | class ImageFolderV2(ImageFolder): 25 | """ 26 | Same as ImageFolder, but also returns a dataset_name as third output 27 | """ 28 | 29 | def __init__(self, dataset_name, *args, **kwargs): 30 | if "loader" not in kwargs: 31 | kwargs["loader"] = my_pil_loader 32 | 33 | super().__init__(*args, **kwargs) 34 | self.dataset_name = dataset_name 35 | 36 | def __repr__(self): 37 | lines = super().__repr__().split("\n") 38 | lines.insert(1, " " * self._repr_indent + f"Dataset Name: {self.dataset_name}") 39 | return "\n".join(lines) 40 | 41 | def __getitem__(self, index: int): 42 | path, target = self.samples[index] 43 | try: 
44 | # Modify the __getitem__ for DatasetFolder 45 | # instead of calling it directly by "sample, target = super().__getitem__(index)" 46 | sample = self.loader(path) 47 | if self.transform is not None: 48 | sample = self.transform(sample) 49 | if self.target_transform is not None: 50 | target = self.target_transform(target) 51 | 52 | except Exception as e: 53 | logger.info(f"Error in {self.dataset_name} at index {index}: {e}") 54 | add_str_to_jsonfile(PROBLEMATIC_IMAGES_LOG_FILE, path) 55 | return None 56 | 57 | return sample, target, self.dataset_name 58 | 59 | 60 | class ImageOneFolderV2(ImageFolderV2): 61 | """ 62 | ImageFolder is nice but assume one folder per class. This class does the same but where all images are directly in the root folder. 63 | """ 64 | 65 | def find_classes( 66 | self, directory: Union[str, Path] 67 | ) -> Tuple[List[str], Dict[str, int]]: 68 | return [""], {"": -1} 69 | 70 | 71 | class ImageListV2(Dataset): 72 | 73 | def __init__( 74 | self, dataset_name, root, imlist, transform=None, loader=my_pil_loader 75 | ): 76 | self.dataset_name = dataset_name 77 | self.root = root 78 | self.imlist = imlist 79 | self.transform = transform 80 | self.loader = loader 81 | 82 | def __len__(self): 83 | return len(self.imlist) 84 | 85 | def __getitem__(self, index): 86 | impath = os.path.join(self.root, self.imlist[index]) 87 | 88 | sample = self.loader(impath) 89 | if self.transform is not None: 90 | sample = self.transform(sample) 91 | 92 | return sample, -1, self.dataset_name 93 | 94 | def __repr__(self): 95 | head = "Dataset " + self.__class__.__name__ 96 | body = [f"Dataset Name: {self.dataset_name}"] 97 | body += [f"Number of datapoints: {self.__len__()}"] 98 | if self.root is not None: 99 | body.append(f"Root location: {self.root}") 100 | if hasattr(self, "transform") and self.transform is not None: 101 | body += [repr(self.transform)] 102 | lines = [head] + [" " * 4 + line for line in body] 103 | return "\n".join(lines) 104 | 105 | 106 | class ConcatDatasetv2(ConcatDataset): 107 | def __repr__(self): 108 | head = ["Dataset " + self.__class__.__name__] 109 | body = [f"Number of data points: {self.__len__()} ({self.cumulative_sizes})"] 110 | body2 = [repr(d).replace("\n", "\n\t") for d in self.datasets] 111 | return "\n\t".join(head + body + body2) 112 | 113 | 114 | class DatasetGroup(Dataset): 115 | """ 116 | Given a group of datasets (in a form of dictionary), 117 | returns one sample from each dataset. 118 | Indexing of the group is determined by the largest dataset in the group. 119 | The index of the smaller ones are cyclic. 120 | Warning: "Residual indices" for the smaller datasets are randomly shuffled. 
121 | """ 122 | 123 | def __init__(self, datasets: Dict[str, Dataset]): 124 | super().__init__() 125 | self.datasets = datasets 126 | self.res_inds = {k: deque([]) for k in datasets.keys()} 127 | 128 | def init_group_res_index(self, dset_key): 129 | assert dset_key in self.res_inds 130 | 131 | if len(self.res_inds[dset_key]) > 0: 132 | return 133 | 134 | order = list(range(len(self.datasets[dset_key]))) 135 | random.shuffle(order) 136 | self.res_inds[dset_key] = deque(order) 137 | 138 | def __repr__(self): 139 | head = [self.__class__.__name__] 140 | body = [f"Number of data points: {self.__len__()}"] 141 | body2 = [ 142 | "{}: {}".format(k, repr(v).replace("\n", "\n\t")) 143 | for k, v in self.datasets.items() 144 | ] 145 | return "\n\t".join(head + body + body2) 146 | 147 | def __len__(self): 148 | # Indexing is determined by the largest group 149 | return max(len(dataset) for dataset in self.datasets.values()) 150 | 151 | def __getitem__(self, idx): 152 | samples = [] 153 | for dset_key, dataset in self.datasets.items(): 154 | i = idx % len(dataset) 155 | 156 | # residual index for the smaller groups 157 | if idx >= (len(self) // len(dataset)) * len(dataset): 158 | self.init_group_res_index(dset_key) 159 | i = self.res_inds[dset_key].pop() 160 | 161 | samples.append(dataset[i]) 162 | 163 | return samples 164 | 165 | 166 | class EmptyDataset: 167 | def __init__(self, dataset_name): 168 | self.dataset_name = dataset_name 169 | 170 | def __repr__(self): 171 | return "{}(dataset_name={}, len={})".format( 172 | self.__class__.__name__, self.dataset_name, len(self) 173 | ) 174 | 175 | def __len__(self): 176 | return 0 177 | 178 | def __getitem__(self, index): 179 | raise IndexError("Empty dataset has no data") 180 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | 4 | from .dataset import ConcatDatasetv2, DatasetGroup 5 | from .dino2 import AVAILABLE_DATASETS as AVAILABLE_DATASETS_DINO2 6 | from .imagenet import AVAILABLE_DATASETS as AVAILABLE_DATASETS_IMAGENET 7 | from .mast3r import AVAILABLE_DATASETS as AVAILABLE_DATASETS_MAST3R 8 | from .multihmr import AVAILABLE_DATASETS as AVAILABLE_DATASETS_MULTIHMR 9 | from .transform import get_test_transform, get_train_transform 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | AVAILABLE_DATASETS = { 15 | **AVAILABLE_DATASETS_IMAGENET, 16 | **AVAILABLE_DATASETS_MULTIHMR, 17 | **AVAILABLE_DATASETS_DINO2, 18 | **AVAILABLE_DATASETS_MAST3R, 19 | } 20 | 21 | TEACHER_TO_DATASETS = { 22 | "multihmr": AVAILABLE_DATASETS_MULTIHMR, 23 | "mast3r": AVAILABLE_DATASETS_MAST3R, 24 | "dino2": AVAILABLE_DATASETS_DINO2, 25 | } 26 | 27 | 28 | def dataset_to_teacher(dataset: str) -> str: 29 | if dataset in list(AVAILABLE_DATASETS_DINO2.keys()) + ["in1k"]: 30 | return "dino2" 31 | elif dataset in AVAILABLE_DATASETS_MAST3R: 32 | return "mast3r" 33 | elif dataset in AVAILABLE_DATASETS_MULTIHMR: 34 | return "multihmr" 35 | raise ValueError(f"Unknown dataset: {dataset}") 36 | 37 | 38 | def get_dataset( 39 | dataset_name, split="train", image_size=224, rrc_scale=(0.08, 1.0), color_aug=True 40 | ): 41 | """ 42 | dataset_name can be a list of datasets separated by "," eg in1k,bedlam 43 | it can also be "teacher_balanced" 44 | """ 45 | 46 | if dataset_name == "teacher_balanced": 47 | assert split == "train", "teacher_balanced is only available for training" 48 | return 
get_teacher_balanced_dataset( 49 | image_size=image_size, rrc_scale=rrc_scale, color_aug=color_aug 50 | ) 51 | 52 | elif dataset_name == "all": 53 | all_teacher_datasets = get_all_teacher_datasets(split) 54 | dataset_name = ",".join(all_teacher_datasets) 55 | logger.info( 56 | "Using all datasets ('{}') for '{}' split".format(dataset_name, split) 57 | ) 58 | 59 | elif dataset_name.startswith("all_except_"): 60 | excluded_dataset = dataset_name.split("all_except_")[1] 61 | assert excluded_dataset in AVAILABLE_DATASETS, ( 62 | "Unknown dataset to exclude: " + excluded_dataset 63 | ) 64 | datasets_to_use = [ 65 | d for d in get_all_teacher_datasets(split) if d != excluded_dataset 66 | ] 67 | dataset_name = ",".join(datasets_to_use) 68 | logger.info( 69 | "Using all datasets except '{}' ('{}') for '{}' split".format( 70 | excluded_dataset, dataset_name, split 71 | ) 72 | ) 73 | 74 | elif dataset_name in ["mast3r", "multihmr", "dino2"]: 75 | _teacher_datasets = TEACHER_TO_DATASETS[dataset_name] 76 | dataset_name = ",".join( 77 | d for d in list(_teacher_datasets.keys()) if split in AVAILABLE_DATASETS[d] 78 | ) 79 | 80 | datasets = [ 81 | get_one_dataset( 82 | dname, 83 | split, 84 | image_size=image_size, 85 | rrc_scale=rrc_scale, 86 | color_aug=color_aug, 87 | ) 88 | for dname in dataset_name.split(",") 89 | ] 90 | 91 | if split == "train": 92 | # at training, we return a ConcatDataset except if there is a single dataset 93 | if len(datasets) == 1: 94 | return datasets[0] 95 | else: 96 | dataset = ConcatDatasetv2(datasets) 97 | return dataset 98 | else: 99 | # at validation/test, we return a list of datasets in any case 100 | return [d for d in datasets if len(d) > 0] 101 | 102 | 103 | def get_one_dataset( 104 | dataset_name, split, image_size=224, rrc_scale=(0.08, 1.0), color_aug=True 105 | ): 106 | 107 | assert dataset_name in AVAILABLE_DATASETS, "Unknown dataset_name: " + dataset_name 108 | expected_length = AVAILABLE_DATASETS.get(dataset_name).get(split) 109 | 110 | transform = ( 111 | get_train_transform( 112 | image_size=image_size, rrc_scale=rrc_scale, color_aug=color_aug 113 | ) 114 | if split == "train" 115 | else get_test_transform(image_size) 116 | ) 117 | 118 | dataset_getter = AVAILABLE_DATASETS.get(dataset_name).get("getter") 119 | logger.info( 120 | "Loading dataset {} (image_size:{}, rrc_scale:{})".format( 121 | dataset_name, image_size, rrc_scale 122 | ) 123 | ) 124 | dataset = dataset_getter(dataset_name, split, transform=transform) 125 | 126 | if len(dataset) != expected_length: 127 | raise Exception( 128 | "Unexpected length of {} for split {} of dataset {}, instead of {}".format( 129 | len(dataset), split, dataset.dataset_name, expected_length 130 | ) 131 | ) 132 | 133 | return dataset 134 | 135 | 136 | def get_teacher_balanced_dataset(image_size=224, rrc_scale=(0.08, 1.0), color_aug=True): 137 | # Concatenate all datasets for each teacher 138 | 139 | logger.info("=> Loading DINOv2 datasets") 140 | dino2_datasets = ConcatDatasetv2( 141 | [ 142 | get_one_dataset( 143 | k, 144 | "train", 145 | image_size=image_size, 146 | rrc_scale=rrc_scale, 147 | color_aug=color_aug, 148 | ) 149 | for k in AVAILABLE_DATASETS_DINO2.keys() 150 | ] 151 | ) 152 | 153 | logger.info("=> Loading MAST3R datasets") 154 | mast3r_datasets = ConcatDatasetv2( 155 | [ 156 | get_one_dataset( 157 | k, 158 | "train", 159 | image_size=image_size, 160 | rrc_scale=rrc_scale, 161 | color_aug=color_aug, 162 | ) 163 | for k in AVAILABLE_DATASETS_MAST3R.keys() 164 | ] 165 | ) 166 | 167 | logger.info("=> Loading 
Multi-HMR datasets") 168 | multihmr_datasets = ConcatDatasetv2( 169 | [ 170 | get_one_dataset( 171 | k, 172 | "train", 173 | image_size=image_size, 174 | rrc_scale=rrc_scale, 175 | color_aug=color_aug, 176 | ) 177 | for k in AVAILABLE_DATASETS_MULTIHMR.keys() 178 | ] 179 | ) 180 | 181 | # Create a dataset group for all teachers 182 | dataset_dict = { 183 | "dino2": dino2_datasets, 184 | "mast3r": mast3r_datasets, 185 | "multihmr": multihmr_datasets, 186 | } 187 | 188 | dataset = DatasetGroup(dataset_dict) 189 | 190 | return dataset 191 | 192 | 193 | def get_all_teacher_datasets(split: str = "train") -> List[str]: 194 | assert split in ["train", "val"], "split must be 'train' or 'val', got: {split}" 195 | dsets = list(AVAILABLE_DATASETS_DINO2.keys()) 196 | dsets.extend(list(AVAILABLE_DATASETS_MAST3R.keys())) 197 | dsets.extend(list(AVAILABLE_DATASETS_MULTIHMR.keys())) 198 | dsets = [d for d in dsets if split in AVAILABLE_DATASETS[d]] 199 | return dsets 200 | -------------------------------------------------------------------------------- /model/dune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from typing import Dict, List, Tuple 4 | from dataclasses import asdict 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from teachers.forward import get_teacher_outputs 10 | 11 | from .options import EncoderOptions, ProjectorOptions 12 | from .encoder import vision_transformer 13 | from .projector.tp import TransformerProjector 14 | from .teacher_norm import TeacherNorm 15 | from .teacher_dropping import TeacherDropping 16 | from .losses import unic_loss 17 | from .model_utils import extra_repr 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class DUNE(nn.Module): 24 | def __init__( 25 | self, 26 | encoder: nn.Module, 27 | projectors: nn.ModuleDict, 28 | teacher_norms: nn.ModuleDict, 29 | apply_last_enc_norm: bool = False, 30 | ): 31 | super().__init__() 32 | self.encoder = encoder 33 | self.projectors = projectors 34 | self.teacher_norms = teacher_norms 35 | self.apply_last_enc_norm = apply_last_enc_norm 36 | 37 | def extra_repr(self) -> str: 38 | return extra_repr(self) 39 | 40 | @property 41 | def patch_size(self) -> int: 42 | return self.encoder.patch_size # type: ignore 43 | 44 | def get_encoder_output( 45 | self, image: torch.Tensor, concat_cls_patch=True 46 | ) -> torch.Tensor | Dict[str, torch.Tensor]: 47 | out_dict = self.encoder(image, apply_norm=self.apply_last_enc_norm) 48 | 49 | enc_out = out_dict 50 | if concat_cls_patch: 51 | enc_out = torch.cat( 52 | [ 53 | out_dict["x_norm_clstoken"].unsqueeze(1), 54 | out_dict["x_norm_patchtokens"], 55 | ], 56 | dim=1, 57 | ) 58 | 59 | return enc_out 60 | 61 | def get_projector_output( 62 | self, 63 | image: torch.Tensor, 64 | teacher: str = "dino2reg_vitlarge_14", 65 | reshape_patch_tokens: bool = True, 66 | # for compatibility with dense evaluations 67 | return_cls_token: bool = True, 68 | return_as_list: bool = False, 69 | ): 70 | 71 | enc_out = self.get_encoder_output(image) 72 | 73 | proj_out = { 74 | "cls": (pout := self.projectors[teacher](enc_out))[:, 0], 75 | "patch": pout[:, 1:], 76 | } 77 | 78 | cls_token = proj_out["cls"] 79 | patch_tokens = proj_out["patch"] 80 | 81 | if reshape_patch_tokens: 82 | B, _, w, h = image.shape 83 | patch_tokens = ( 84 | patch_tokens.reshape(B, w // self.patch_size, h // self.patch_size, -1) 85 | .permute(0, 3, 1, 2) 86 | .contiguous() 87 | ) 88 | 89 | if return_cls_token and return_as_list: 90 | return 
[cls_token, patch_tokens] 91 | elif return_cls_token: 92 | return cls_token, patch_tokens 93 | elif return_as_list: 94 | return [patch_tokens] 95 | else: 96 | return patch_tokens 97 | 98 | def forward( 99 | self, 100 | image: torch.Tensor, 101 | dset_name: List[str], 102 | teachers: Dict[str, nn.Module], 103 | tdrop: TeacherDropping, 104 | tnorms_ema_mom: float = 0.0, 105 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 106 | enc_out = self.get_encoder_output(image) 107 | proj_out = { 108 | tname: {"cls": (pout := proj(enc_out))[:, 0], "patch": pout[:, 1:]} 109 | for tname, proj in self.projectors.items() 110 | } 111 | 112 | teacher_out = get_teacher_outputs( 113 | image, 114 | teachers, 115 | self.patch_size, 116 | self.teacher_norms, 117 | tnorms_ema_mom if self.training else 0.0, 118 | ) 119 | 120 | loss, loss_dict = unic_loss( 121 | proj_out, 122 | teacher_out, 123 | dset_name, 124 | tdrop, 125 | ) 126 | 127 | return loss, loss_dict 128 | 129 | 130 | def build_encoder(image_size: int, extra_args: Dict) -> nn.Module: 131 | args = asdict(EncoderOptions()) 132 | args.update(extra_args) 133 | args["image_size"] = image_size 134 | return vision_transformer.get_model(**args) 135 | 136 | 137 | def build_projector(input_dim: int, output_dim: int, extra_args: Dict) -> nn.Module: 138 | args = asdict(ProjectorOptions(input_dim, output_dim)) 139 | args.update(extra_args) 140 | return TransformerProjector(**args) 141 | 142 | 143 | def load_student_encoder_from_checkpoint(ckpt_fname, ckpt_key="model"): 144 | assert os.path.isfile(ckpt_fname), "Student checkpoint ({}) not found!".format( 145 | ckpt_fname 146 | ) 147 | ckpt = torch.load(ckpt_fname, "cpu", weights_only=False) 148 | 149 | encoder = build_encoder(ckpt["args"].image_size, ckpt["args"].enc_args) 150 | 151 | state_dict = ckpt.get(ckpt_key, ckpt) 152 | encoder.load_state_dict( 153 | { 154 | k.replace("module.", "") 155 | .replace("_orig_mod.", "") 156 | .replace("encoder.", ""): v 157 | for k, v in state_dict.items() 158 | if "encoder." 
in k 159 | } 160 | ) 161 | 162 | iter = ckpt.get("iter", 0) 163 | logger.info( 164 | "Loaded student encoder from checkpoint {} trained for {} iterations".format( 165 | ckpt_fname, iter 166 | ) 167 | ) 168 | 169 | return encoder, iter 170 | 171 | 172 | def build_student_from_args(args): 173 | from teachers import TEACHER_CFG 174 | 175 | encoder = build_encoder(args.image_size, args.enc_args) 176 | if not hasattr(args.enc_args, "num_heads"): 177 | args.proj_args["num_heads"] = encoder.num_heads 178 | 179 | projectors = {} 180 | teacher_norms = {} 181 | 182 | for tname in args.teachers: 183 | proj_indim: int = encoder.embed_dim 184 | proj_outdim: int = TEACHER_CFG[tname]["num_features"] 185 | proj = build_projector(proj_indim, proj_outdim, args.proj_args) 186 | projectors[tname] = proj 187 | 188 | teacher_norms[tname] = TeacherNorm( 189 | TEACHER_CFG[tname]["token_types"], proj_outdim 190 | ) 191 | 192 | projectors = nn.ModuleDict(projectors) 193 | teacher_norms = nn.ModuleDict(teacher_norms) 194 | 195 | model = DUNE(encoder, projectors, teacher_norms) 196 | 197 | return model 198 | 199 | 200 | def load_student_from_checkpoint(ckpt_fname, ckpt_key="model"): 201 | assert os.path.isfile(ckpt_fname), ckpt_fname 202 | ckpt = torch.load(ckpt_fname, "cpu", weights_only=False) 203 | 204 | model = build_student_from_args(ckpt["args"]) 205 | 206 | state_dict = ckpt.get(ckpt_key, ckpt) 207 | state_dict = { 208 | k.replace("module.", "").replace("_orig_mod.", ""): v 209 | for k, v in state_dict.items() 210 | } 211 | model.load_state_dict(state_dict) 212 | 213 | iter = ckpt.get("iter", 0) 214 | 215 | logger.info( 216 | "Loaded student from checkpoint {} trained for {} iterations".format( 217 | ckpt_fname, iter 218 | ) 219 | ) 220 | 221 | return model, iter 222 | 223 | 224 | def load_dune_encoder_from_checkpoint(*args, **kwargs): 225 | """ 226 | Loads only the encoder part of the DUNE model from a checkpoint. 227 | """ 228 | return load_student_encoder_from_checkpoint(*args, **kwargs) 229 | 230 | 231 | def load_dune_from_checkpoint(*args, **kwargs): 232 | """ 233 | Loads the complete DUNE model from a checkpoint. 234 | """ 235 | return load_student_from_checkpoint(*args, **kwargs) 236 | -------------------------------------------------------------------------------- /model/common/block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
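# NOTE (descriptive comment added for this listing, not part of the original file):
# this appears to be the student encoder's transformer block, a trimmed copy of
# teachers/dinov2/layers/block.py. Unlike that teacher-side copy, this file does
# not import the xFormers ops (fmha, scaled_index_add, index_select_cat) that the
# nested-tensor helpers below (get_attn_bias_and_cat, add_residual with a
# scaling_vector, drop_add_residual_stochastic_depth_list) refer to, and
# NestedTensorBlock.forward here only accepts a single Tensor. The nested-list
# path is therefore effectively unused in this file; if it were ever needed, the
# guarded xFormers import block from the teacher copy would have to be added here.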
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | import logging 11 | from typing import Callable, List, Any, Tuple, Dict 12 | 13 | import torch 14 | from torch import nn, Tensor 15 | 16 | from .attention import Attention, MemEffAttention 17 | from .drop_path import DropPath 18 | from .layer_scale import LayerScale 19 | from .mlp import Mlp 20 | 21 | 22 | logger = logging.getLogger("dinov2") 23 | 24 | 25 | def get_layerscale(dim, init_values=None): 26 | return ( 27 | LayerScale(dim, init_values=init_values) 28 | if init_values is not None and init_values > 0.0 29 | else nn.Identity() 30 | ) 31 | 32 | 33 | class Block(nn.Module): 34 | def __init__( 35 | self, 36 | dim: int, 37 | num_heads: int, 38 | mlp_ratio: float = 4.0, 39 | qkv_bias: bool = False, 40 | proj_bias: bool = True, 41 | ffn_bias: bool = True, 42 | drop: float = 0.0, 43 | attn_drop: float = 0.0, 44 | layerscale_init=None, 45 | drop_path: float = 0.0, 46 | qk_norm: bool = False, 47 | act_layer: Callable[..., nn.Module] = nn.GELU, 48 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 49 | attn_class: Callable[..., nn.Module] = Attention, 50 | ffn_layer: Callable[..., nn.Module] = Mlp, 51 | ) -> None: 52 | super().__init__() 53 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") 54 | self.norm1 = norm_layer(dim) 55 | self.attn = attn_class( 56 | dim, 57 | num_heads=num_heads, 58 | qkv_bias=qkv_bias, 59 | proj_bias=proj_bias, 60 | attn_drop=attn_drop, 61 | proj_drop=drop, 62 | qk_norm=qk_norm, 63 | ) 64 | self.ls1 = get_layerscale(dim, layerscale_init) 65 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 66 | 67 | self.norm2 = norm_layer(dim) 68 | mlp_hidden_dim = int(dim * mlp_ratio) 69 | self.mlp = ffn_layer( 70 | in_features=dim, 71 | hidden_features=mlp_hidden_dim, 72 | act_layer=act_layer, 73 | drop=drop, 74 | bias=ffn_bias, 75 | ) 76 | self.ls2 = get_layerscale(dim, layerscale_init) 77 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 78 | 79 | self.sample_drop_ratio = drop_path 80 | 81 | def forward(self, x: Tensor, return_attention=False) -> Tensor: 82 | def attn_residual_func(x: Tensor) -> Tensor: 83 | return self.ls1(self.attn(self.norm1(x))) 84 | 85 | def ffn_residual_func(x: Tensor) -> Tensor: 86 | return self.ls2(self.mlp(self.norm2(x))) 87 | 88 | if return_attention: 89 | return self.attn(self.norm1(x), return_attention=True) 90 | elif self.training and self.sample_drop_ratio > 0.1: 91 | # the overhead is compensated only for a drop path rate larger than 0.1 92 | x = drop_add_residual_stochastic_depth( 93 | x, 94 | residual_func=attn_residual_func, 95 | sample_drop_ratio=self.sample_drop_ratio, 96 | ) 97 | x = drop_add_residual_stochastic_depth( 98 | x, 99 | residual_func=ffn_residual_func, 100 | sample_drop_ratio=self.sample_drop_ratio, 101 | ) 102 | elif self.training and self.sample_drop_ratio > 0.0: 103 | x = x + self.drop_path1(attn_residual_func(x)) 104 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 105 | else: 106 | x = x + attn_residual_func(x) 107 | x = x + ffn_residual_func(x) 108 | return x 109 | 110 | 111 | def drop_add_residual_stochastic_depth( 112 | x: Tensor, 113 | residual_func: Callable[[Tensor], Tensor], 114 | sample_drop_ratio: float = 0.0, 115 | ) -> Tensor: 116 | # 1) extract subset using permutation 117 | b, n, d = x.shape 118 | 
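# Worked example (added comment, not in the original): with a batch of b = 8 and
# sample_drop_ratio = 0.25, only max(int(8 * 0.75), 1) = 6 rows are kept; the
# residual is computed for those 6 rows and added back via torch.index_add with
# alpha = b / sample_subset_size = 8 / 6 ≈ 1.33, so the residual's expected
# contribution matches that of a full, un-dropped batch.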
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 119 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 120 | x_subset = x[brange] 121 | 122 | # 2) apply residual_func to get residual 123 | residual = residual_func(x_subset) 124 | 125 | x_flat = x.flatten(1) 126 | residual = residual.flatten(1) 127 | 128 | residual_scale_factor = b / sample_subset_size 129 | 130 | # 3) add the residual 131 | x_plus_residual = torch.index_add( 132 | x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor 133 | ) 134 | return x_plus_residual.view_as(x) 135 | 136 | 137 | def get_branges_scales(x, sample_drop_ratio=0.0): 138 | b, n, d = x.shape 139 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 140 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 141 | residual_scale_factor = b / sample_subset_size 142 | return brange, residual_scale_factor 143 | 144 | 145 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 146 | if scaling_vector is None: 147 | x_flat = x.flatten(1) 148 | residual = residual.flatten(1) 149 | x_plus_residual = torch.index_add( 150 | x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor 151 | ) 152 | else: 153 | x_plus_residual = scaled_index_add( 154 | x, 155 | brange, 156 | residual.to(dtype=x.dtype), 157 | scaling=scaling_vector, 158 | alpha=residual_scale_factor, 159 | ) 160 | return x_plus_residual 161 | 162 | 163 | attn_bias_cache: Dict[Tuple, Any] = {} 164 | 165 | 166 | def get_attn_bias_and_cat(x_list, branges=None): 167 | """ 168 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 169 | """ 170 | batch_sizes = ( 171 | [b.shape[0] for b in branges] 172 | if branges is not None 173 | else [x.shape[0] for x in x_list] 174 | ) 175 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 176 | if all_shapes not in attn_bias_cache.keys(): 177 | seqlens = [] 178 | for b, x in zip(batch_sizes, x_list): 179 | for _ in range(b): 180 | seqlens.append(x.shape[1]) 181 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 182 | attn_bias._batch_sizes = batch_sizes 183 | attn_bias_cache[all_shapes] = attn_bias 184 | 185 | if branges is not None: 186 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view( 187 | 1, -1, x_list[0].shape[-1] 188 | ) 189 | else: 190 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 191 | cat_tensors = torch.cat(tensors_bs1, dim=1) 192 | 193 | return attn_bias_cache[all_shapes], cat_tensors 194 | 195 | 196 | def drop_add_residual_stochastic_depth_list( 197 | x_list: List[Tensor], 198 | residual_func: Callable[[Tensor, Any], Tensor], 199 | sample_drop_ratio: float = 0.0, 200 | scaling_vector=None, 201 | ) -> Tensor: 202 | # 1) generate random set of indices for dropping samples in the batch 203 | branges_scales = [ 204 | get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list 205 | ] 206 | branges = [s[0] for s in branges_scales] 207 | residual_scale_factors = [s[1] for s in branges_scales] 208 | 209 | # 2) get attention bias and index+concat the tensors 210 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 211 | 212 | # 3) apply residual_func to get residual, and split the result 213 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 214 | 215 | outputs = [] 216 | for x, brange, residual, residual_scale_factor in zip( 217 | x_list, branges, residual_list, 
residual_scale_factors 218 | ): 219 | outputs.append( 220 | add_residual( 221 | x, brange, residual, residual_scale_factor, scaling_vector 222 | ).view_as(x) 223 | ) 224 | return outputs 225 | 226 | 227 | class NestedTensorBlock(Block): 228 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 229 | """ 230 | x_list contains a list of tensors to nest together and run 231 | """ 232 | assert isinstance(self.attn, MemEffAttention) 233 | 234 | if self.training and self.sample_drop_ratio > 0.0: 235 | 236 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 237 | return self.attn(self.norm1(x), attn_bias=attn_bias) 238 | 239 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 240 | return self.mlp(self.norm2(x)) 241 | 242 | x_list = drop_add_residual_stochastic_depth_list( 243 | x_list, 244 | residual_func=attn_residual_func, 245 | sample_drop_ratio=self.sample_drop_ratio, 246 | scaling_vector=( 247 | self.ls1.gamma if isinstance(self.ls1, LayerScale) else None 248 | ), 249 | ) 250 | x_list = drop_add_residual_stochastic_depth_list( 251 | x_list, 252 | residual_func=ffn_residual_func, 253 | sample_drop_ratio=self.sample_drop_ratio, 254 | scaling_vector=( 255 | self.ls2.gamma if isinstance(self.ls1, LayerScale) else None 256 | ), 257 | ) 258 | return x_list 259 | else: 260 | 261 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 262 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 263 | 264 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 265 | return self.ls2(self.mlp(self.norm2(x))) 266 | 267 | attn_bias, x = get_attn_bias_and_cat(x_list) 268 | x = x + attn_residual_func(x, attn_bias=attn_bias) 269 | x = x + ffn_residual_func(x) 270 | return attn_bias.split(x) 271 | 272 | def forward(self, x_or_x_list, **kwargs): 273 | if isinstance(x_or_x_list, Tensor): 274 | return super().forward(x_or_x_list, **kwargs) 275 | else: 276 | raise ValueError( 277 | "Type of x_or_x_list ({}) is not recognized".format(type(x_or_x_list)) 278 | ) 279 | -------------------------------------------------------------------------------- /teachers/dinov2/layers/block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
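# NOTE (descriptive comment added for this listing, not part of the original file):
# this is the teacher-side copy of the transformer block used by the DINOv2-based
# teacher models. Unlike model/common/block.py above, it retains the optional
# xFormers code path (fmha, scaled_index_add, index_select_cat) so that
# NestedTensorBlock can also process lists of tensors. Setting the
# XFORMERS_DISABLED environment variable before import disables that path: list
# inputs then raise an AssertionError, while single-tensor inputs still go through
# the standard pure-PyTorch forward.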
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | import logging 11 | import os 12 | import warnings 13 | from typing import Callable, List, Any, Tuple, Dict 14 | 15 | import torch 16 | from torch import nn, Tensor 17 | 18 | from .attention import Attention, MemEffAttention 19 | from .drop_path import DropPath 20 | from .layer_scale import LayerScale 21 | from .mlp import Mlp 22 | 23 | 24 | logger = logging.getLogger("dinov2") 25 | 26 | 27 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 28 | try: 29 | if XFORMERS_ENABLED: 30 | from xformers.ops import fmha, scaled_index_add, index_select_cat 31 | 32 | XFORMERS_AVAILABLE = True 33 | warnings.warn("xFormers is available (Block)") 34 | else: 35 | warnings.warn("xFormers is disabled (Block)") 36 | raise ImportError 37 | except ImportError: 38 | XFORMERS_AVAILABLE = False 39 | 40 | warnings.warn("xFormers is not available (Block)") 41 | 42 | 43 | def get_layerscale(dim, init_values=None): 44 | return ( 45 | LayerScale(dim, init_values=init_values) 46 | if init_values is not None and init_values > 0.0 47 | else nn.Identity() 48 | ) 49 | 50 | 51 | class Block(nn.Module): 52 | def __init__( 53 | self, 54 | dim: int, 55 | num_heads: int, 56 | mlp_ratio: float = 4.0, 57 | qkv_bias: bool = False, 58 | proj_bias: bool = True, 59 | ffn_bias: bool = True, 60 | drop: float = 0.0, 61 | attn_drop: float = 0.0, 62 | init_values=None, 63 | drop_path: float = 0.0, 64 | qk_norm: bool = False, 65 | act_layer: Callable[..., nn.Module] = nn.GELU, 66 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 67 | attn_class: Callable[..., nn.Module] = Attention, 68 | ffn_layer: Callable[..., nn.Module] = Mlp, 69 | ) -> None: 70 | super().__init__() 71 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") 72 | self.norm1 = norm_layer(dim) 73 | self.attn = attn_class( 74 | dim, 75 | num_heads=num_heads, 76 | qkv_bias=qkv_bias, 77 | proj_bias=proj_bias, 78 | attn_drop=attn_drop, 79 | proj_drop=drop, 80 | qk_norm=qk_norm, 81 | ) 82 | self.ls1 = get_layerscale(dim, init_values) 83 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 84 | 85 | self.norm2 = norm_layer(dim) 86 | mlp_hidden_dim = int(dim * mlp_ratio) 87 | self.mlp = ffn_layer( 88 | in_features=dim, 89 | hidden_features=mlp_hidden_dim, 90 | act_layer=act_layer, 91 | drop=drop, 92 | bias=ffn_bias, 93 | ) 94 | self.ls2 = get_layerscale(dim, init_values) 95 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 96 | 97 | self.sample_drop_ratio = drop_path 98 | 99 | def forward(self, x: Tensor, return_attention=False) -> Tensor: 100 | def attn_residual_func(x: Tensor) -> Tensor: 101 | return self.ls1(self.attn(self.norm1(x))) 102 | 103 | def ffn_residual_func(x: Tensor) -> Tensor: 104 | return self.ls2(self.mlp(self.norm2(x))) 105 | 106 | if return_attention: 107 | return self.attn(self.norm1(x), return_attention=True) 108 | elif self.training and self.sample_drop_ratio > 0.1: 109 | # the overhead is compensated only for a drop path rate larger than 0.1 110 | x = drop_add_residual_stochastic_depth( 111 | x, 112 | residual_func=attn_residual_func, 113 | sample_drop_ratio=self.sample_drop_ratio, 114 | ) 115 | x = drop_add_residual_stochastic_depth( 116 | x, 117 | residual_func=ffn_residual_func, 118 | sample_drop_ratio=self.sample_drop_ratio, 119 | ) 120 | elif self.training and 
self.sample_drop_ratio > 0.0: 121 | x = x + self.drop_path1(attn_residual_func(x)) 122 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 123 | else: 124 | x = x + attn_residual_func(x) 125 | x = x + ffn_residual_func(x) 126 | return x 127 | 128 | 129 | def drop_add_residual_stochastic_depth( 130 | x: Tensor, 131 | residual_func: Callable[[Tensor], Tensor], 132 | sample_drop_ratio: float = 0.0, 133 | ) -> Tensor: 134 | # 1) extract subset using permutation 135 | b, n, d = x.shape 136 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 137 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 138 | x_subset = x[brange] 139 | 140 | # 2) apply residual_func to get residual 141 | residual = residual_func(x_subset) 142 | 143 | x_flat = x.flatten(1) 144 | residual = residual.flatten(1) 145 | 146 | residual_scale_factor = b / sample_subset_size 147 | 148 | # 3) add the residual 149 | x_plus_residual = torch.index_add( 150 | x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor 151 | ) 152 | return x_plus_residual.view_as(x) 153 | 154 | 155 | def get_branges_scales(x, sample_drop_ratio=0.0): 156 | b, n, d = x.shape 157 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 158 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 159 | residual_scale_factor = b / sample_subset_size 160 | return brange, residual_scale_factor 161 | 162 | 163 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 164 | if scaling_vector is None: 165 | x_flat = x.flatten(1) 166 | residual = residual.flatten(1) 167 | x_plus_residual = torch.index_add( 168 | x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor 169 | ) 170 | else: 171 | x_plus_residual = scaled_index_add( 172 | x, 173 | brange, 174 | residual.to(dtype=x.dtype), 175 | scaling=scaling_vector, 176 | alpha=residual_scale_factor, 177 | ) 178 | return x_plus_residual 179 | 180 | 181 | attn_bias_cache: Dict[Tuple, Any] = {} 182 | 183 | 184 | def get_attn_bias_and_cat(x_list, branges=None): 185 | """ 186 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 187 | """ 188 | batch_sizes = ( 189 | [b.shape[0] for b in branges] 190 | if branges is not None 191 | else [x.shape[0] for x in x_list] 192 | ) 193 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 194 | if all_shapes not in attn_bias_cache.keys(): 195 | seqlens = [] 196 | for b, x in zip(batch_sizes, x_list): 197 | for _ in range(b): 198 | seqlens.append(x.shape[1]) 199 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 200 | attn_bias._batch_sizes = batch_sizes 201 | attn_bias_cache[all_shapes] = attn_bias 202 | 203 | if branges is not None: 204 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view( 205 | 1, -1, x_list[0].shape[-1] 206 | ) 207 | else: 208 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 209 | cat_tensors = torch.cat(tensors_bs1, dim=1) 210 | 211 | return attn_bias_cache[all_shapes], cat_tensors 212 | 213 | 214 | def drop_add_residual_stochastic_depth_list( 215 | x_list: List[Tensor], 216 | residual_func: Callable[[Tensor, Any], Tensor], 217 | sample_drop_ratio: float = 0.0, 218 | scaling_vector=None, 219 | ) -> Tensor: 220 | # 1) generate random set of indices for dropping samples in the batch 221 | branges_scales = [ 222 | get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list 223 | ] 224 | branges = [s[0] for s 
in branges_scales] 225 | residual_scale_factors = [s[1] for s in branges_scales] 226 | 227 | # 2) get attention bias and index+concat the tensors 228 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 229 | 230 | # 3) apply residual_func to get residual, and split the result 231 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 232 | 233 | outputs = [] 234 | for x, brange, residual, residual_scale_factor in zip( 235 | x_list, branges, residual_list, residual_scale_factors 236 | ): 237 | outputs.append( 238 | add_residual( 239 | x, brange, residual, residual_scale_factor, scaling_vector 240 | ).view_as(x) 241 | ) 242 | return outputs 243 | 244 | 245 | class NestedTensorBlock(Block): 246 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 247 | """ 248 | x_list contains a list of tensors to nest together and run 249 | """ 250 | assert isinstance(self.attn, MemEffAttention) 251 | 252 | if self.training and self.sample_drop_ratio > 0.0: 253 | 254 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 255 | return self.attn(self.norm1(x), attn_bias=attn_bias) 256 | 257 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 258 | return self.mlp(self.norm2(x)) 259 | 260 | x_list = drop_add_residual_stochastic_depth_list( 261 | x_list, 262 | residual_func=attn_residual_func, 263 | sample_drop_ratio=self.sample_drop_ratio, 264 | scaling_vector=( 265 | self.ls1.gamma if isinstance(self.ls1, LayerScale) else None 266 | ), 267 | ) 268 | x_list = drop_add_residual_stochastic_depth_list( 269 | x_list, 270 | residual_func=ffn_residual_func, 271 | sample_drop_ratio=self.sample_drop_ratio, 272 | scaling_vector=( 273 | self.ls2.gamma if isinstance(self.ls1, LayerScale) else None 274 | ), 275 | ) 276 | return x_list 277 | else: 278 | 279 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 280 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 281 | 282 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 283 | return self.ls2(self.mlp(self.norm2(x))) 284 | 285 | attn_bias, x = get_attn_bias_and_cat(x_list) 286 | x = x + attn_residual_func(x, attn_bias=attn_bias) 287 | x = x + ffn_residual_func(x) 288 | return attn_bias.split(x) 289 | 290 | def forward(self, x_or_x_list, **kwargs): 291 | if isinstance(x_or_x_list, Tensor): 292 | return super().forward(x_or_x_list, **kwargs) 293 | elif isinstance(x_or_x_list, list): 294 | if not XFORMERS_AVAILABLE: 295 | raise AssertionError("xFormers is required for using nested tensors") 296 | return self.forward_nested(x_or_x_list) 297 | else: 298 | raise ValueError( 299 | "Type of x_or_x_list ({}) is not recognized".format(type(x_or_x_list)) 300 | ) 301 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |