├── .gitignore ├── ModelCreated ├── ViTAS │ ├── ViTAS_model.py │ ├── ViTAS_module │ │ ├── PAFM_fuse │ │ │ └── PAFM.py │ │ ├── SDFA_fuse │ │ │ └── SDFA.py │ │ ├── Uni_CA │ │ │ ├── CA_transformer.py │ │ │ ├── CA_utils.py │ │ │ ├── attention.py │ │ │ └── position.py │ │ └── VFM_fuse │ │ │ └── VFM.py │ └── args │ │ └── ViTAS_args.py └── ViTASIGEV │ ├── ViTASIGEV.py │ └── igev │ └── igev_model.py ├── README.md ├── model_pack ├── IGEV_Stereo │ ├── core │ │ ├── __init__.py │ │ ├── extractor.py │ │ ├── geometry.py │ │ ├── igev_stereo.py │ │ ├── stereo_datasets.py │ │ ├── submodule.py │ │ ├── update.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── augmentor.py │ │ │ ├── frame_utils.py │ │ │ └── utils.py │ ├── demo_imgs.py │ ├── demo_video.py │ ├── evaluate_stereo.py │ ├── save_disp.py │ └── train_stereo.py └── dinoV2 │ ├── CODE_OF_CONDUCT.md │ ├── CONTRIBUTING.md │ ├── LICENSE │ ├── MODEL_CARD.md │ ├── README.md │ ├── conda-extras.yaml │ ├── conda.yaml │ ├── dinov2 │ ├── __init__.py │ ├── configs │ │ ├── __init__.py │ │ ├── eval │ │ │ ├── vitb14_pretrain.yaml │ │ │ ├── vitb14_reg4_pretrain.yaml │ │ │ ├── vitg14_pretrain.yaml │ │ │ ├── vitg14_reg4_pretrain.yaml │ │ │ ├── vitl14_pretrain.yaml │ │ │ ├── vitl14_reg4_pretrain.yaml │ │ │ ├── vitl16_pretrain.yaml │ │ │ ├── vits14_pretrain.yaml │ │ │ └── vits14_reg4_pretrain.yaml │ │ ├── ssl_default_config.yaml │ │ └── train │ │ │ ├── vitg14.yaml │ │ │ ├── vitl14.yaml │ │ │ └── vitl16_short.yaml │ ├── data │ │ ├── __init__.py │ │ ├── adapters.py │ │ ├── augmentations.py │ │ ├── collate.py │ │ ├── datasets │ │ │ ├── __init__.py │ │ │ ├── decoders.py │ │ │ ├── extended.py │ │ │ ├── image_net.py │ │ │ └── image_net_22k.py │ │ ├── loaders.py │ │ ├── masking.py │ │ ├── samplers.py │ │ └── transforms.py │ ├── distributed │ │ └── __init__.py │ ├── eval │ │ ├── __init__.py │ │ ├── depth │ │ │ ├── __init__.py │ │ │ ├── models │ │ │ │ ├── __init__.py │ │ │ │ ├── backbones │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── vision_transformer.py │ │ │ │ ├── builder.py │ │ │ │ ├── decode_heads │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── decode_head.py │ │ │ │ │ ├── dpt_head.py │ │ │ │ │ └── linear_head.py │ │ │ │ ├── depther │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── base.py │ │ │ │ │ └── encoder_decoder.py │ │ │ │ └── losses │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── gradientloss.py │ │ │ │ │ └── sigloss.py │ │ │ └── ops │ │ │ │ ├── __init__.py │ │ │ │ └── wrappers.py │ │ ├── knn.py │ │ ├── linear.py │ │ ├── log_regression.py │ │ ├── metrics.py │ │ ├── segmentation │ │ │ ├── __init__.py │ │ │ ├── hooks │ │ │ │ ├── __init__.py │ │ │ │ └── optimizer.py │ │ │ ├── models │ │ │ │ ├── __init__.py │ │ │ │ ├── backbones │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── vision_transformer.py │ │ │ │ └── decode_heads │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── linear_head.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ └── colormaps.py │ │ ├── segmentation_m2f │ │ │ ├── __init__.py │ │ │ ├── core │ │ │ │ ├── __init__.py │ │ │ │ ├── anchor │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── builder.py │ │ │ │ │ └── point_generator.py │ │ │ │ ├── box │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── builder.py │ │ │ │ │ └── samplers │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── base_sampler.py │ │ │ │ │ │ ├── mask_pseudo_sampler.py │ │ │ │ │ │ ├── mask_sampling_result.py │ │ │ │ │ │ └── sampling_result.py │ │ │ │ └── utils │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── dist_utils.py │ │ │ │ │ └── misc.py │ │ │ ├── models │ │ │ │ ├── __init__.py │ │ │ │ ├── backbones │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── adapter_modules.py │ │ │ │ │ ├── 
drop_path.py │ │ │ │ │ ├── vit.py │ │ │ │ │ └── vit_adapter.py │ │ │ │ ├── builder.py │ │ │ │ ├── decode_heads │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── mask2former_head.py │ │ │ │ ├── losses │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── cross_entropy_loss.py │ │ │ │ │ ├── dice_loss.py │ │ │ │ │ └── match_costs.py │ │ │ │ ├── plugins │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── msdeformattn_pixel_decoder.py │ │ │ │ ├── segmentors │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── encoder_decoder_mask2former.py │ │ │ │ └── utils │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── assigner.py │ │ │ │ │ ├── point_sample.py │ │ │ │ │ ├── positional_encoding.py │ │ │ │ │ └── transformer.py │ │ │ └── ops │ │ │ │ └── modules │ │ │ │ ├── __init__.py │ │ │ │ └── ms_deform_attn.py │ │ ├── setup.py │ │ └── utils.py │ ├── fsdp │ │ └── __init__.py │ ├── hub │ │ ├── __init__.py │ │ ├── backbones.py │ │ ├── classifiers.py │ │ ├── depth │ │ │ ├── __init__.py │ │ │ ├── decode_heads.py │ │ │ ├── encoder_decoder.py │ │ │ └── ops.py │ │ ├── depthers.py │ │ └── utils.py │ ├── layers │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── block.py │ │ ├── dino_head.py │ │ ├── drop_path.py │ │ ├── layer_scale.py │ │ ├── mlp.py │ │ ├── patch_embed.py │ │ └── swiglu_ffn.py │ ├── logging │ │ ├── __init__.py │ │ └── helpers.py │ ├── loss │ │ ├── __init__.py │ │ ├── dino_clstoken_loss.py │ │ ├── ibot_patch_loss.py │ │ └── koleo_loss.py │ ├── models │ │ ├── __init__.py │ │ └── vision_transformer.py │ ├── run │ │ ├── __init__.py │ │ ├── eval │ │ │ ├── knn.py │ │ │ ├── linear.py │ │ │ └── log_regression.py │ │ ├── submit.py │ │ └── train │ │ │ └── train.py │ ├── train │ │ ├── __init__.py │ │ ├── ssl_meta_arch.py │ │ └── train.py │ └── utils │ │ ├── __init__.py │ │ ├── cluster.py │ │ ├── config.py │ │ ├── dtype.py │ │ ├── param_groups.py │ │ └── utils.py │ ├── hubconf.py │ ├── notebooks │ ├── depth_estimation.ipynb │ └── semantic_segmentation.ipynb │ ├── pyproject.toml │ ├── requirements-dev.txt │ ├── requirements-extras.txt │ ├── requirements.txt │ ├── scripts │ └── lint.sh │ ├── setup.cfg │ └── setup.py ├── requirements.txt └── toolkit ├── args ├── args_default.py └── model_args.py ├── data_loader ├── __init__.py ├── dataloader.py ├── dataset_function.py └── transforms.py ├── function ├── base_function.py ├── evaluator.py └── models.py ├── main.py └── torch_lightning ├── data_modules └── custom.py ├── lightning_function.py └── pl_modules ├── ViTASIGEV.py ├── evaluate.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.tar 2 | *.pth 3 | *.rar 4 | *.zip 5 | *.png 6 | *.jpg 7 | *.out 8 | *.log 9 | *.pyc 10 | *.mge 11 | *.pickle 12 | *.pkl 13 | *.pyc 14 | *.mge 15 | *.npz 16 | *.npy 17 | *.ckpt 18 | *.th 19 | *.jpeg 20 | *.pfm 21 | *.pfm 22 | /toolkit/models/* 23 | /toolkit/generate_images/* 24 | /datasets/* 25 | /toolkit/torch_lightning/ckpts/* 26 | /toolkit/lightning_logs/* 27 | -------------------------------------------------------------------------------- /ModelCreated/ViTAS/ViTAS_module/Uni_CA/position.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | # https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py 3 | 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | 8 | 9 | class PositionEmbeddingSine(nn.Module): 10 | """ 11 | This is a more standard version of the position embedding, very similar to the one 12 | used by the Attention is all you need paper, generalized to work on images. 13 | """ 14 | 15 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): 16 | super().__init__() 17 | self.num_pos_feats = num_pos_feats 18 | self.temperature = temperature 19 | self.normalize = normalize 20 | if scale is not None and normalize is False: 21 | raise ValueError("normalize should be True if scale is passed") 22 | if scale is None: 23 | scale = 2 * math.pi 24 | self.scale = scale 25 | 26 | def forward(self, x): 27 | # x = tensor_list.tensors # [B, C, H, W] 28 | # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 29 | b, c, h, w = x.size() 30 | mask = torch.ones((b, h, w), device=x.device) # [B, H, W] 31 | y_embed = mask.cumsum(1, dtype=torch.float32) 32 | x_embed = mask.cumsum(2, dtype=torch.float32) 33 | if self.normalize: 34 | eps = 1e-6 35 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 36 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 37 | 38 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 39 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 40 | 41 | pos_x = x_embed[:, :, :, None] / dim_t 42 | pos_y = y_embed[:, :, :, None] / dim_t 43 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 44 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 46 | return pos 47 | -------------------------------------------------------------------------------- /ModelCreated/ViTAS/ViTAS_module/VFM_fuse/VFM.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from ModelCreated.ViTAS.ViTAS_module.Uni_CA.CA_utils import feature_add_position_cross_feature 4 | from ModelCreated.ViTAS.ViTAS_module.Uni_CA.CA_transformer import FeatureFuseTransformer 5 | from itertools import repeat 6 | import collections.abc 7 | 8 | 9 | def _ntuple(n): 10 | def parse(x): 11 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 12 | return x 13 | return tuple(repeat(x, n)) 14 | return parse 15 | to_2tuple = _ntuple(2) 16 | 17 | class VFM(nn.Module): 18 | def __init__(self, 19 | channels = [48,64,192,160], 20 | num_heads = 4, 21 | UsePatch = False, 22 | ): 23 | super().__init__() 24 | self.channels = channels 25 | 26 | self.CA = nn.ModuleList([FeatureFuseTransformer(1,self.channels[i],1,2,self.channels[i+1]) for i in range(len(self.channels)-1)]) # i = 0,1,2 27 | 28 | self.conv1 = nn.ModuleList([nn.Conv2d(self.channels[i], self.channels[i], kernel_size=(1,1), stride=1) for i in range(len(self.channels)-1)]) 29 | self.conv3 = nn.ModuleList([nn.Conv2d(self.channels[i], self.channels[i], kernel_size=(3,3), stride=1,padding=1,padding_mode='replicate') for i in range(len(self.channels)-1)]) 30 | 31 | def fuse(self,i,x_l, x_r, x_l_high, x_r_high): # i=3,2,1 32 | 33 | x_l = F.interpolate(x_l, scale_factor=2., mode='bilinear', align_corners=True) 34 | x_r = F.interpolate(x_r, scale_factor=2., mode='bilinear', align_corners=True) 35 | 36 | 
x_l_high = self.conv1[i-1](x_l_high) 37 | x_r_high = self.conv1[i-1](x_r_high) 38 | 39 | x_l, x_r, x_l_high, x_r_high, = feature_add_position_cross_feature(x_l, x_r, x_l_high, x_r_high,self.channels[i],self.channels[i-1]) 40 | 41 | x_l_out = self.CA[i-1](x_l_high,x_l,attn_type='self_swin2d_cross_swin1d',attn_num_splits=4) 42 | x_r_out = self.CA[i-1](x_r_high,x_r,attn_type='self_swin2d_cross_swin1d',attn_num_splits=4) 43 | 44 | x_l_out = self.conv3[i-1](x_l_out) 45 | x_r_out = self.conv3[i-1](x_r_out) 46 | return x_l_out,x_r_out 47 | -------------------------------------------------------------------------------- /ModelCreated/ViTAS/args/ViTAS_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def config_ViTAS_args(ViTAS_dic): 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--ViTAS_channel', type=int, default = 1024, help='channels of the feature from VFM') 6 | parser.add_argument('--scales', type=int, nargs='+', default = [4,8,16,32], help='select the resolution of the final outputs') # undone 7 | parser.add_argument('--pre_channels', type=int, nargs='+', default = [64,128,256,256], help='select the channels of the mid-features after preprocess') 8 | parser.add_argument('--VFM_type', type=str, default = 'DINOv2', choices=['DINOv2'],help='args for load VFM') 9 | parser.add_argument('--out_channels', type=int, nargs='+', default = [48,64,192,256], help='select the channels of the final outputs') 10 | parser.add_argument('--attn_splits_list', type=int, nargs='+', default = [8,4,2,2], help='select the split for swin-transformer') 11 | parser.add_argument('--CA_layers', type=int, default = 2, help='select the number of CA blocks') 12 | parser.add_argument('--wo_fuse', type=bool, default = False, help='if need the fuse module') 13 | parser.add_argument('--wo_SDM', type=bool, default = False, help='if need the SDM module') 14 | parser.add_argument('--ViTAS_fuse', type=str, default = 'PAFM', help='select the fuse method') 15 | parser.add_argument('--ViTAS_model', type=str, default = 'vit_l', choices=['vit_l','vit_b','vit_s'], help='select the ViTAS_model') 16 | parser.add_argument('--ViTAS_hooks', type=int, nargs='+', default = [5,11,17,23], help='select the layers of outputs from VFM') 17 | parser.add_argument('--ViTAS_unfreeze', default = '4', type=str, help='if freeze the last n blocks of VFM') 18 | parser.add_argument('--ViTAS_pure', default = False, type=bool, help='if load pure VFM') 19 | parser.add_argument('--fuse_dic', type=dict, default = {}, help='args for fuse module') 20 | args = parser.parse_args() 21 | args = args_refine(args,**ViTAS_dic) 22 | return args 23 | 24 | def config_ViTASUni_args(ViTAS_dic): 25 | args = config_ViTAS_args(ViTAS_dic) 26 | args.pre_channels = [128,256,512,512] 27 | args.out_channels = [128,128,256,512] 28 | return args 29 | 30 | def config_ViTASIGEV_args(ViTAS_dic): 31 | args = config_ViTAS_args(ViTAS_dic) 32 | args.pre_channels = [64,128,256,256] 33 | args.out_channels = [48,64,192,160] 34 | return args 35 | 36 | def config_ViTASCre_args(ViTAS_dic): 37 | args = config_ViTAS_args(ViTAS_dic) 38 | args.pre_channels = [256,256,256,256] 39 | args.out_channels = [256,256,256,256] 40 | # args.pre_channels = [256,256,256,256] 41 | # args.out_channels = [256,256,256,256] 42 | return args 43 | 44 | def config_ViTASCroco_args(ViTAS_dic): 45 | args = config_ViTAS_args(ViTAS_dic) 46 | args.ViTAS_pure = True 47 | assert args.ViTAS_channel == 1024 48 | return args 49 | 50 | 51 | ViTAS_config 
= {'DINOv2':{'vit_l': {'hook_select':[5,11,17,23],'channel':1024,'hook_check':23}, 52 | 'vit_b': {'hook_select':[2,5,8,11],'channel':768,'hook_check':11},}, 53 | } 54 | 55 | def ViTAS_hooks_init(args): # select the init VFM hooks 56 | assert args.ViTAS_model in ViTAS_config[args.VFM_type].keys() 57 | args.ViTAS_hooks = ViTAS_config[args.VFM_type][args.ViTAS_model]['hook_select'] 58 | return args 59 | 60 | def ViTAS_channels_init(args): # select the dinoV2 feature channels 61 | assert args.ViTAS_model in ViTAS_config[args.VFM_type].keys() 62 | args.ViTAS_channel = ViTAS_config[args.VFM_type][args.ViTAS_model]['channel'] 63 | return args 64 | 65 | def check_dino_hooks(args): # check if the hooks is valid 66 | assert args.ViTAS_model in ViTAS_config[args.VFM_type].keys() 67 | assert args.ViTAS_hooks[-1] <= ViTAS_config[args.VFM_type][args.ViTAS_model]['hook_check'] 68 | 69 | def args_refine(args,VFM_type,ViTAS_model,ViTAS_hooks,ViTAS_unfreeze,wo_fuse,wo_SDM,ViTAS_fuse,ViTAS_fuse_patch,ViTAS_fuse_weight,CA_layers): 70 | if not ViTAS_model is None: 71 | args.ViTAS_model = ViTAS_model 72 | args = ViTAS_hooks_init(args) 73 | args = ViTAS_channels_init(args) 74 | if not wo_fuse is None: 75 | args.wo_fuse = wo_fuse 76 | if not wo_SDM is None: 77 | args.wo_SDM = wo_SDM 78 | if not VFM_type is None: 79 | args.VFM_type = VFM_type 80 | if not ViTAS_hooks is None: 81 | args.ViTAS_hooks = ViTAS_hooks 82 | if not ViTAS_unfreeze is None: 83 | args.ViTAS_unfreeze = ViTAS_unfreeze 84 | if not ViTAS_fuse is None: 85 | args.ViTAS_fuse = ViTAS_fuse 86 | if not CA_layers is None: 87 | args.CA_layers = CA_layers 88 | if not ViTAS_fuse_patch is None: 89 | args.fuse_dic['UsePatch'] = ViTAS_fuse_patch # use patch in PAFM module 90 | if not ViTAS_fuse_patch is None: 91 | args.fuse_dic['UseWeight'] = ViTAS_fuse_weight # use DANE weight in PAFM module 92 | assert args.ViTAS_fuse in ['SDFA','PAFM','VFM'] 93 | assert args.ViTAS_unfreeze in ['0','1','2','3','4','5'], str(args.ViTAS_unfreeze) 94 | check_dino_hooks(args) 95 | return args -------------------------------------------------------------------------------- /ModelCreated/ViTASIGEV/ViTASIGEV.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from ModelCreated.ViTAS.args.ViTAS_args import config_ViTASIGEV_args 3 | from ModelCreated.ViTAS.ViTAS_model import ViTASBaseModel 4 | from ModelCreated.ViTASIGEV.igev.igev_model import IGEVStereo,config_IGEV_args 5 | 6 | def get_parameter_number(model,name=None): 7 | total_num = sum(p.numel() for p in model.parameters()) 8 | trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad) 9 | print(name, ' Total:', total_num, 'Trainable:', trainable_num) 10 | # return {'Total': total_num, 'Trainable': trainable_num} 11 | 12 | class ViTASIGEVModel(nn.Module): 13 | def __init__(self,ViTAS_dic): 14 | super().__init__() 15 | self.ViTAS = self.load_ViTAS(ViTAS_dic) 16 | get_parameter_number(self.ViTAS,'ViTAS') 17 | self.igev = self.load_IGEV() 18 | 19 | def load_ViTAS(self,ViTAS_dic): 20 | args = config_ViTASIGEV_args(ViTAS_dic) 21 | model = ViTASBaseModel(**vars(args)) 22 | return model 23 | 24 | def load_IGEV(self): 25 | igev_args = config_IGEV_args() 26 | model = IGEVStereo(igev_args) 27 | return model 28 | 29 | def forward(self, img1, img2, iters, test_mode=False): 30 | features_out = self.ViTAS(img1,img2) # 31 | feature1_list = features_out['feature0_out_list'] 32 | feature2_list = features_out['feature1_out_list'] 33 | 34 | feature1_list.reverse() 35 
| feature2_list.reverse() 36 | pred_disp = self.igev(feature1_list, feature2_list, img1, img2, iters=iters, test_mode=test_mode) 37 | return pred_disp 38 | 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ViTAS 2 | This is the official repo for our work 'Playing to Vision Foundation Model's Strengths in Stereo Matching'. 3 | [Paper](https://arxiv.org/abs/2404.06261) 4 | 5 | [Demo Video](https://www.youtube.com/watch?v=vL3xjHFYgP0) 6 | 7 | ## Setup 8 | We built and ran the repo with CUDA 11.8, Python 3.9.0, and Pytorch 2.1.0. For using this repo, please follow the instructions below: 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | If you have any problem with installing xFormers package, please follow the guidance in [DINOv2](https://github.com/facebookresearch/dinov2). 14 | 15 | ## Pre-trained models 16 | 17 | Pretrained models leading to our [SoTA KITTI benchmark results](https://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php?benchmark=stereo) can be downloaded from [google drive](https://drive.google.com/file/d/15hwxnvN53PhV3-7GGH0N_AM4XlD9DEsQ/view?usp=sharing), and is supposed to be under dir: `toolkit/models/ViTASIGEV`. 18 | 19 | Results of our KITTI benchmark results can be downloaded from [2012(google drive)](https://drive.google.com/file/d/1HxZIrBZvYjt8g4NOXisj3TxkQmrjvQ66/view?usp=drive_link) and [2015(google drive)](https://drive.google.com/file/d/15jCzdIa_2gxF7LLoulsdY3vwRNwBB6_t/view?usp=drive_link). 20 | 21 | ## Dataset Preparation 22 | To train/evaluate ViTAStereo, you will need to download the required datasets. 23 | 24 | * [Scene Flow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html#:~:text=on%20Academic%20Torrents-,FlyingThings3D,-Driving) (Includes FlyingThings3D, Driving) 25 | * [Virtual KITTI 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/) 26 | * [Middlebury](https://vision.middlebury.edu/stereo/data/) 27 | * [ETH3D](https://www.eth3d.net/datasets#low-res-two-view-test-data) 28 | * [KITTI 2012](http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php?benchmark=stereo) 29 | * [KITTI 2015](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=stereo) 30 | 31 | You can create symbolic links to wherever the datasets were downloaded in the `$root/datasets` folder: 32 | 33 | ```shell 34 | ln -s $YOUR_DATASET_ROOT datasets 35 | ``` 36 | 37 | Our folder structure is as follows: 38 | 39 | ``` 40 | ├── datasets 41 | ├── ETH3D 42 | │ ├── testing 43 | │ ├── lakeside_1l 44 |    │   └── ... 45 |    │   └── training 46 | │ ├── delivery_area_1l 47 |    │   └── ... 48 |    │ 49 | ├── KITTI 50 | │ ├── 2012 51 | │ │ ├── testing 52 |    │   │ └── training 53 |    │   └── 2015 54 | │ ├── testing 55 |    │   └── training 56 | ├── middlebury 57 | │ ├── 2005 58 | │ ├── 2006 59 | │ ├── 2014 60 | │ ├── 2021 61 | │   └── MiddEval3 62 | └── sceneflow 63 | ├── driving 64 | │ ├── 15mm_focallength 65 | │ │ ├── scene_backwards 66 | │ │ └── scene_fowwards 67 |    │ └── 35mm_focallength 68 | ├── flying 69 | │ ├── TRAIN 70 | │ │ ├── A 71 | │ │ ├── B 72 | │ │ └── C 73 |    │ └── Test 74 | ``` 75 | ## Training & Evaluation 76 | The demos of training and evaluating function of our ViTAS are integrated in the `toolkit/main.py`. 
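A minimal way to launch them from the repository root is sketched below; the actual command-line options are defined in `toolkit/args/args_default.py` and `toolkit/args/model_args.py`, so treat this flagless call as an assumption rather than the definitive interface:

```shell
# hypothetical invocation relying entirely on the defaults in toolkit/args/args_default.py
python toolkit/main.py
```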
77 | 78 | ## Citation 79 | 80 | If you find our works useful in your research, please consider citing: 81 | 82 | ``` 83 | @article{li2024roadformer, 84 | title={Playing to Vision Foundation Model's Strengths in Stereo Matching}, 85 | author={Chuang-Wei Liu and Chen, Qijun and Fan, Rui}, 86 | journal={IEEE Transactions on Intelligent Vehicles}, 87 | year={2024}, 88 | publisher={IEEE}, 89 | note={{DOI}:{10.1109/TIV.2024.3467287}}, 90 | } 91 | ``` 92 | 93 | ## Acknowledgment 94 | 95 | Some of this repo come from [IGEV-Stereo](https://github.com/gangweiX/IGEV),[GMStereo](https://github.com/autonomousvision/unimatch), and [DINOv2](https://github.com/facebookresearch/dinov2). 96 | -------------------------------------------------------------------------------- /model_pack/IGEV_Stereo/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CwLiuzzZ/ViTAS/d2e82761e72cde8cbaae78c7117542d39330489d/model_pack/IGEV_Stereo/core/__init__.py -------------------------------------------------------------------------------- /model_pack/IGEV_Stereo/core/geometry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import sys 4 | sys.path.append('../model_pack/IGEV_Stereo') 5 | from core.utils.utils import bilinear_sampler 6 | 7 | 8 | class Combined_Geo_Encoding_Volume: 9 | def __init__(self, init_fmap1, init_fmap2, geo_volume, num_levels=2, radius=4): 10 | self.num_levels = num_levels 11 | self.radius = radius 12 | self.geo_volume_pyramid = [] 13 | self.init_corr_pyramid = [] 14 | 15 | # all pairs correlation 16 | init_corr = Combined_Geo_Encoding_Volume.corr(init_fmap1, init_fmap2) 17 | 18 | b, h, w, _, w2 = init_corr.shape 19 | b, c, d, h, w = geo_volume.shape 20 | geo_volume = geo_volume.permute(0, 3, 4, 1, 2).reshape(b*h*w, c, 1, d) 21 | 22 | init_corr = init_corr.reshape(b*h*w, 1, 1, w2) 23 | self.geo_volume_pyramid.append(geo_volume) 24 | self.init_corr_pyramid.append(init_corr) 25 | for i in range(self.num_levels-1): 26 | geo_volume = F.avg_pool2d(geo_volume, [1,2], stride=[1,2]) 27 | self.geo_volume_pyramid.append(geo_volume) 28 | 29 | for i in range(self.num_levels-1): 30 | init_corr = F.avg_pool2d(init_corr, [1,2], stride=[1,2]) 31 | self.init_corr_pyramid.append(init_corr) 32 | 33 | 34 | 35 | 36 | def __call__(self, disp, coords): 37 | r = self.radius 38 | b, _, h, w = disp.shape 39 | out_pyramid = [] 40 | for i in range(self.num_levels): 41 | geo_volume = self.geo_volume_pyramid[i] 42 | dx = torch.linspace(-r, r, 2*r+1) 43 | dx = dx.view(1, 1, 2*r+1, 1).to(disp.device) 44 | x0 = dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i 45 | y0 = torch.zeros_like(x0) 46 | 47 | disp_lvl = torch.cat([x0,y0], dim=-1) 48 | geo_volume = bilinear_sampler(geo_volume, disp_lvl) 49 | geo_volume = geo_volume.view(b, h, w, -1) 50 | 51 | init_corr = self.init_corr_pyramid[i] 52 | init_x0 = coords.reshape(b*h*w, 1, 1, 1)/2**i - disp.reshape(b*h*w, 1, 1, 1) / 2**i + dx 53 | init_coords_lvl = torch.cat([init_x0,y0], dim=-1) 54 | init_corr = bilinear_sampler(init_corr, init_coords_lvl) 55 | init_corr = init_corr.view(b, h, w, -1) 56 | 57 | out_pyramid.append(geo_volume) 58 | out_pyramid.append(init_corr) 59 | out = torch.cat(out_pyramid, dim=-1) 60 | return out.permute(0, 3, 1, 2).contiguous().float() 61 | 62 | 63 | @staticmethod 64 | def corr(fmap1, fmap2): 65 | B, D, H, W1 = fmap1.shape 66 | _, _, _, W2 = fmap2.shape 67 | fmap1 = fmap1.view(B, D, H, W1) 68 | fmap2 
= fmap2.view(B, D, H, W2) 69 | corr = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2) 70 | corr = corr.reshape(B, H, W1, 1, W2).contiguous() 71 | return corr -------------------------------------------------------------------------------- /model_pack/IGEV_Stereo/core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CwLiuzzZ/ViTAS/d2e82761e72cde8cbaae78c7117542d39330489d/model_pack/IGEV_Stereo/core/utils/__init__.py -------------------------------------------------------------------------------- /model_pack/IGEV_Stereo/core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel', divis_by=8): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by 12 | pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | assert all((x.ndim == 4) for x in inputs) 20 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 21 | 22 | def unpad(self, x): 23 | assert x.ndim == 4 24 | ht, wd = x.shape[-2:] 25 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 26 | return x[..., c[0]:c[1], c[2]:c[3]] 27 | 28 | def forward_interpolate(flow): 29 | flow = flow.detach().cpu().numpy() 30 | dx, dy = flow[0], flow[1] 31 | 32 | ht, wd = dx.shape 33 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 34 | 35 | x1 = x0 + dx 36 | y1 = y0 + dy 37 | 38 | x1 = x1.reshape(-1) 39 | y1 = y1.reshape(-1) 40 | dx = dx.reshape(-1) 41 | dy = dy.reshape(-1) 42 | 43 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 44 | x1 = x1[valid] 45 | y1 = y1[valid] 46 | dx = dx[valid] 47 | dy = dy[valid] 48 | 49 | flow_x = interpolate.griddata( 50 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 51 | 52 | flow_y = interpolate.griddata( 53 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 54 | 55 | flow = np.stack([flow_x, flow_y], axis=0) 56 | return torch.from_numpy(flow).float() 57 | 58 | 59 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 60 | """ Wrapper for grid_sample, uses pixel coordinates """ 61 | H, W = img.shape[-2:] 62 | 63 | # print("$$$55555", img.shape, coords.shape) 64 | xgrid, ygrid = coords.split([1,1], dim=-1) 65 | xgrid = 2*xgrid/(W-1) - 1 66 | 67 | # print("######88888", xgrid) 68 | assert torch.unique(ygrid).numel() == 1 and H == 1 # This is a stereo problem 69 | 70 | grid = torch.cat([xgrid, ygrid], dim=-1) 71 | # print("###37777", grid.shape) 72 | img = F.grid_sample(img, grid, align_corners=True) 73 | if mask: 74 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 75 | return img, mask.float() 76 | 77 | return img 78 | 79 | 80 | def coords_grid(batch, ht, wd): 81 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 82 | coords = torch.stack(coords[::-1], dim=0).float() 83 | return coords[None].repeat(batch, 1, 1, 1) 84 | 85 | 86 | def upflow8(flow, mode='bilinear'): 87 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 88 | return 8 * F.interpolate(flow, size=new_size, mode=mode, 
align_corners=True) 89 | 90 | def gauss_blur(input, N=5, std=1): 91 | B, D, H, W = input.shape 92 | x, y = torch.meshgrid(torch.arange(N).float() - N//2, torch.arange(N).float() - N//2) 93 | unnormalized_gaussian = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * std ** 2)) 94 | weights = unnormalized_gaussian / unnormalized_gaussian.sum().clamp(min=1e-4) 95 | weights = weights.view(1,1,N,N).to(input) 96 | output = F.conv2d(input.reshape(B*D,1,H,W), weights, padding=N//2) 97 | return output.view(B, D, H, W) -------------------------------------------------------------------------------- /model_pack/IGEV_Stereo/demo_imgs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | DEVICE = 'cuda' 4 | import os 5 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 6 | import argparse 7 | import glob 8 | import numpy as np 9 | import torch 10 | from tqdm import tqdm 11 | from pathlib import Path 12 | from igev_stereo import IGEVStereo 13 | from utils.utils import InputPadder 14 | from PIL import Image 15 | from matplotlib import pyplot as plt 16 | import os 17 | import cv2 18 | 19 | def load_image(imfile): 20 | img = np.array(Image.open(imfile)).astype(np.uint8) 21 | img = torch.from_numpy(img).permute(2, 0, 1).float() 22 | return img[None].to(DEVICE) 23 | 24 | def demo(args): 25 | model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0]) 26 | model.load_state_dict(torch.load(args.restore_ckpt)) 27 | 28 | model = model.module 29 | model.to(DEVICE) 30 | model.eval() 31 | 32 | output_directory = Path(args.output_directory) 33 | output_directory.mkdir(exist_ok=True) 34 | 35 | with torch.no_grad(): 36 | left_images = sorted(glob.glob(args.left_imgs, recursive=True)) 37 | right_images = sorted(glob.glob(args.right_imgs, recursive=True)) 38 | print(f"Found {len(left_images)} images. 
Saving files to {output_directory}/") 39 | 40 | for (imfile1, imfile2) in tqdm(list(zip(left_images, right_images))): 41 | image1 = load_image(imfile1) 42 | image2 = load_image(imfile2) 43 | 44 | padder = InputPadder(image1.shape, divis_by=32) 45 | image1, image2 = padder.pad(image1, image2) 46 | 47 | disp = model(image1, image2, iters=args.valid_iters, test_mode=True) 48 | disp = disp.cpu().numpy() 49 | disp = padder.unpad(disp) 50 | file_stem = imfile1.split('/')[-2] 51 | filename = os.path.join(output_directory, f"{file_stem}.png") 52 | plt.imsave(output_directory / f"{file_stem}.png", disp.squeeze(), cmap='jet') 53 | # disp = np.round(disp * 256).astype(np.uint16) 54 | # cv2.imwrite(filename, cv2.applyColorMap(cv2.convertScaleAbs(disp.squeeze(), alpha=0.01),cv2.COLORMAP_JET), [int(cv2.IMWRITE_PNG_COMPRESSION), 0]) 55 | 56 | 57 | if __name__ == '__main__': 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument('--restore_ckpt', help="restore checkpoint", default='./pretrained_models/sceneflow/sceneflow.pth') 60 | parser.add_argument('--save_numpy', action='store_true', help='save output as numpy arrays') 61 | 62 | parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="./demo-imgs/*/im0.png") 63 | parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="./demo-imgs/*/im1.png") 64 | 65 | # parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="/data/Middlebury/trainingH/*/im0.png") 66 | # parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="/data/Middlebury/trainingH/*/im1.png") 67 | # parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="/data/ETH3D/two_view_training/*/im0.png") 68 | # parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="/data/ETH3D/two_view_training/*/im1.png") 69 | parser.add_argument('--output_directory', help="directory to save output", default="./demo-output/") 70 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 71 | parser.add_argument('--valid_iters', type=int, default=32, help='number of flow-field updates during forward pass') 72 | 73 | # Architecture choices 74 | parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions") 75 | parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation") 76 | parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders") 77 | parser.add_argument('--corr_levels', type=int, default=2, help="number of levels in the correlation pyramid") 78 | parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid") 79 | parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)") 80 | parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently") 81 | parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels") 82 | parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume") 83 | 84 | args = parser.parse_args() 85 | 86 | Path(args.output_directory).mkdir(exist_ok=True, parents=True) 87 | 88 | demo(args) 89 | 
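For quick reference, the demo above condenses to the single-pair inference sketch below. It is a hedged summary rather than an additional script in the repo: the checkpoint and image paths are placeholders, and `args` is assumed to carry the same fields as the argparse defaults in `demo_imgs.py` (`hidden_dims`, `corr_levels`, `corr_radius`, `n_downsample`, `n_gru_layers`, `max_disp`, `valid_iters`, ...).

```python
import numpy as np
import torch
from PIL import Image
from igev_stereo import IGEVStereo
from utils.utils import InputPadder

def load_image(imfile, device='cuda'):
    # HWC uint8 image -> 1x3xHxW float tensor on the GPU, as in demo_imgs.py
    img = np.array(Image.open(imfile)).astype(np.uint8)
    return torch.from_numpy(img).permute(2, 0, 1).float()[None].to(device)

def run_pair(args, left_path, right_path, ckpt='./pretrained_models/sceneflow/sceneflow.pth'):
    # load the state dict into the DataParallel wrapper first, mirroring demo_imgs.py above
    model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])
    model.load_state_dict(torch.load(ckpt))
    model = model.module.to('cuda').eval()

    image1, image2 = load_image(left_path), load_image(right_path)
    padder = InputPadder(image1.shape, divis_by=32)   # pad H and W up to multiples of 32
    image1, image2 = padder.pad(image1, image2)
    with torch.no_grad():
        disp = model(image1, image2, iters=args.valid_iters, test_mode=True)
    return padder.unpad(disp).cpu().numpy().squeeze()  # H x W disparity map
```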
-------------------------------------------------------------------------------- /model_pack/IGEV_Stereo/demo_video.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | import cv2 4 | import numpy as np 5 | import glob 6 | from pathlib import Path 7 | from tqdm import tqdm 8 | import torch 9 | from PIL import Image 10 | from igev_stereo import IGEVStereo 11 | import os 12 | import argparse 13 | from utils.utils import InputPadder 14 | torch.backends.cudnn.benchmark = True 15 | half_precision = True 16 | 17 | 18 | DEVICE = 'cuda' 19 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 20 | 21 | parser = argparse.ArgumentParser(description='Iterative Geometry Encoding Volume for Stereo Matching and Multi-View Stereo (IGEV-Stereo)') 22 | parser.add_argument('--restore_ckpt', help="restore checkpoint", default='./pretrained_models/kitti/kitti15.pth') 23 | parser.add_argument('--save_numpy', action='store_true', help='save output as numpy arrays') 24 | parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="/data/KITTI_raw/2011_09_26/2011_09_26_drive_0005_sync/image_02/data/*.png") 25 | parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="/data/KITTI_raw/2011_09_26/2011_09_26_drive_0005_sync/image_03/data/*.png") 26 | parser.add_argument('--mixed_precision', default=True, action='store_true', help='use mixed precision') 27 | parser.add_argument('--valid_iters', type=int, default=16, help='number of flow-field updates during forward pass') 28 | parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions") 29 | parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation") 30 | parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders") 31 | parser.add_argument('--corr_levels', type=int, default=2, help="number of levels in the correlation pyramid") 32 | parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid") 33 | parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)") 34 | parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently") 35 | parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels") 36 | parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume") 37 | 38 | args = parser.parse_args() 39 | model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0]) 40 | model.load_state_dict(torch.load(args.restore_ckpt)) 41 | model = model.module 42 | model.to(DEVICE) 43 | model.eval() 44 | 45 | left_images = sorted(glob.glob(args.left_imgs, recursive=True)) 46 | right_images = sorted(glob.glob(args.right_imgs, recursive=True)) 47 | print(f"Found {len(left_images)} images.") 48 | 49 | 50 | def load_image(imfile): 51 | img = np.array(Image.open(imfile)).astype(np.uint8) 52 | img = torch.from_numpy(img).permute(2, 0, 1).float() 53 | return img[None].to(DEVICE) 54 | 55 | if __name__ == '__main__': 56 | 57 | fps_list = np.array([]) 58 | videoWrite = cv2.VideoWriter('./IGEV_Stereo.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 10, (1242, 750)) 59 | for (imfile1, imfile2) in tqdm(list(zip(left_images, right_images))): 60 | image1 = 
load_image(imfile1) 61 | image2 = load_image(imfile2) 62 | padder = InputPadder(image1.shape, divis_by=32) 63 | image1_pad, image2_pad = padder.pad(image1, image2) 64 | torch.cuda.synchronize() 65 | 66 | start = torch.cuda.Event(enable_timing=True) 67 | end = torch.cuda.Event(enable_timing=True) 68 | start.record() 69 | with torch.no_grad(): 70 | with torch.cuda.amp.autocast(enabled=half_precision): 71 | disp = model(image1_pad, image2_pad, iters=16, test_mode=True) 72 | disp = padder.unpad(disp) 73 | end.record() 74 | torch.cuda.synchronize() 75 | runtime = start.elapsed_time(end) 76 | fps = 1000/runtime 77 | fps_list = np.append(fps_list, fps) 78 | if len(fps_list) > 5: 79 | fps_list = fps_list[-5:] 80 | avg_fps = np.mean(fps_list) 81 | print('Stereo runtime: {:.3f}'.format(1000/avg_fps)) 82 | 83 | disp_np = (2*disp).data.cpu().numpy().squeeze().astype(np.uint8) 84 | disp_np = cv2.applyColorMap(disp_np, cv2.COLORMAP_PLASMA) 85 | image_np = np.array(Image.open(imfile1)).astype(np.uint8) 86 | out_img = np.concatenate((image_np, disp_np), 0) 87 | cv2.putText( 88 | out_img, 89 | "%.1f fps" % (avg_fps), 90 | (10, image_np.shape[0]+30), 91 | cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA) 92 | cv2.imshow('img', out_img) 93 | cv2.waitKey(1) 94 | videoWrite.write(out_img) 95 | videoWrite.release() 96 | -------------------------------------------------------------------------------- /model_pack/IGEV_Stereo/save_disp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | import argparse 5 | import glob 6 | import numpy as np 7 | import torch 8 | from tqdm import tqdm 9 | from pathlib import Path 10 | from igev_stereo import IGEVStereo 11 | from utils.utils import InputPadder 12 | from PIL import Image 13 | from matplotlib import pyplot as plt 14 | import os 15 | import skimage.io 16 | import cv2 17 | 18 | 19 | DEVICE = 'cuda' 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 22 | 23 | def load_image(imfile): 24 | img = np.array(Image.open(imfile)).astype(np.uint8) 25 | img = torch.from_numpy(img).permute(2, 0, 1).float() 26 | return img[None].to(DEVICE) 27 | 28 | def demo(args): 29 | model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0]) 30 | model.load_state_dict(torch.load(args.restore_ckpt)) 31 | 32 | model = model.module 33 | model.to(DEVICE) 34 | model.eval() 35 | 36 | output_directory = Path(args.output_directory) 37 | output_directory.mkdir(exist_ok=True) 38 | 39 | with torch.no_grad(): 40 | left_images = sorted(glob.glob(args.left_imgs, recursive=True)) 41 | right_images = sorted(glob.glob(args.right_imgs, recursive=True)) 42 | print(f"Found {len(left_images)} images. 
Saving files to {output_directory}/") 43 | 44 | for (imfile1, imfile2) in tqdm(list(zip(left_images, right_images))): 45 | image1 = load_image(imfile1) 46 | image2 = load_image(imfile2) 47 | padder = InputPadder(image1.shape, divis_by=32) 48 | image1, image2 = padder.pad(image1, image2) 49 | disp = model(image1, image2, iters=args.valid_iters, test_mode=True) 50 | disp = padder.unpad(disp) 51 | file_stem = os.path.join(output_directory, imfile1.split('/')[-1]) 52 | disp = disp.cpu().numpy().squeeze() 53 | disp = np.round(disp * 256).astype(np.uint16) 54 | skimage.io.imsave(file_stem, disp) 55 | 56 | 57 | if __name__ == '__main__': 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument('--restore_ckpt', help="restore checkpoint", default='./pretrained_models/kitti/kitti15.pth') 60 | parser.add_argument('--save_numpy', action='store_true', help='save output as numpy arrays') 61 | parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="/data/KITTI/KITTI_2015/testing/image_2/*_10.png") 62 | parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="/data/KITTI/KITTI_2015/testing/image_3/*_10.png") 63 | # parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="/data/KITTI/KITTI_2012/testing/colored_0/*_10.png") 64 | # parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="/data/KITTI/KITTI_2012/testing/colored_1/*_10.png") 65 | parser.add_argument('--output_directory', help="directory to save output", default="output") 66 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 67 | parser.add_argument('--valid_iters', type=int, default=16, help='number of flow-field updates during forward pass') 68 | 69 | # Architecture choices 70 | parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions") 71 | parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation") 72 | parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders") 73 | parser.add_argument('--corr_levels', type=int, default=2, help="number of levels in the correlation pyramid") 74 | parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid") 75 | parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)") 76 | parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently") 77 | parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels") 78 | parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume") 79 | 80 | args = parser.parse_args() 81 | 82 | demo(args) 83 | -------------------------------------------------------------------------------- /model_pack/dinoV2/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of 
experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /model_pack/dinoV2/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to DINOv2 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to DINOv2, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 
32 | -------------------------------------------------------------------------------- /model_pack/dinoV2/conda-extras.yaml: -------------------------------------------------------------------------------- 1 | name: dinov2-extras 2 | channels: 3 | - defaults 4 | - pytorch 5 | - nvidia 6 | - xformers 7 | - conda-forge 8 | dependencies: 9 | - python=3.9 10 | - pytorch::pytorch=2.0.0 11 | - pytorch::pytorch-cuda=11.7.0 12 | - pytorch::torchvision=0.15.0 13 | - omegaconf 14 | - torchmetrics=0.10.3 15 | - fvcore 16 | - iopath 17 | - xformers::xformers=0.0.18 18 | - pip 19 | - pip: 20 | - git+https://github.com/facebookincubator/submitit 21 | - --extra-index-url https://pypi.nvidia.com 22 | - cuml-cu11 23 | - mmcv-full==1.5.0 24 | - mmsegmentation==0.27.0 25 | -------------------------------------------------------------------------------- /model_pack/dinoV2/conda.yaml: -------------------------------------------------------------------------------- 1 | name: dinov2 2 | channels: 3 | - defaults 4 | - pytorch 5 | - nvidia 6 | - xformers 7 | - conda-forge 8 | dependencies: 9 | - python=3.9 10 | - pytorch::pytorch=2.0.0 11 | - pytorch::pytorch-cuda=11.7.0 12 | - pytorch::torchvision=0.15.0 13 | - omegaconf 14 | - torchmetrics=0.10.3 15 | - fvcore 16 | - iopath 17 | - xformers::xformers=0.0.18 18 | - pip 19 | - pip: 20 | - git+https://github.com/facebookincubator/submitit 21 | - --extra-index-url https://pypi.nvidia.com 22 | - cuml-cu11 23 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | __version__ = "0.0.1" 7 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
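# A hedged usage sketch of this module: load_and_merge_config() (defined below) reads
# <name>.yaml relative to this directory and overlays it onto ssl_default_config.yaml.
# The config name used here is just one of the eval files listed later in this dump:
#
#     from dinov2.configs import load_and_merge_config
#     cfg = load_and_merge_config("eval/vitl14_pretrain")
#     print(cfg.student.arch, cfg.student.patch_size)  # -> vit_large 14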
5 | 6 | import pathlib 7 | 8 | from omegaconf import OmegaConf 9 | 10 | 11 | def load_config(config_name: str): 12 | config_filename = config_name + ".yaml" 13 | return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename) 14 | 15 | 16 | dinov2_default_config = load_config("ssl_default_config") 17 | 18 | 19 | def load_and_merge_config(config_name: str): 20 | default_config = OmegaConf.create(dinov2_default_config) 21 | loaded_config = load_config(config_name) 22 | return OmegaConf.merge(default_config, loaded_config) 23 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/configs/eval/vitb14_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_base 3 | patch_size: 14 4 | crops: 5 | global_crops_size: 518 # this is to set up the position embeddings properly 6 | local_crops_size: 98 -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/configs/eval/vitb14_reg4_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_base 3 | patch_size: 14 4 | num_register_tokens: 4 5 | interpolate_antialias: true 6 | interpolate_offset: 0.0 7 | crops: 8 | global_crops_size: 518 # this is to set up the position embeddings properly 9 | local_crops_size: 98 -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/configs/eval/vitg14_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_giant2 3 | patch_size: 14 4 | ffn_layer: swiglufused 5 | crops: 6 | global_crops_size: 518 # this is to set up the position embeddings properly 7 | local_crops_size: 98 -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/configs/eval/vitg14_reg4_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_giant2 3 | patch_size: 14 4 | ffn_layer: swiglufused 5 | num_register_tokens: 4 6 | interpolate_antialias: true 7 | interpolate_offset: 0.0 8 | crops: 9 | global_crops_size: 518 # this is to set up the position embeddings properly 10 | local_crops_size: 98 -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/configs/eval/vitl14_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_large 3 | patch_size: 14 4 | crops: 5 | global_crops_size: 518 # this is to set up the position embeddings properly 6 | local_crops_size: 98 -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/configs/eval/vitl14_reg4_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_large 3 | patch_size: 14 4 | num_register_tokens: 4 5 | interpolate_antialias: true 6 | interpolate_offset: 0.0 7 | crops: 8 | global_crops_size: 518 # this is to set up the position embeddings properly 9 | local_crops_size: 98 -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/configs/eval/vitl16_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_large 3 | patch_size: 16 4 | crops: 5 | global_crops_size: 518 # this is to set 
up the position embeddings properly 6 | local_crops_size: 98 -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/configs/eval/vits14_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_small 3 | patch_size: 14 4 | crops: 5 | global_crops_size: 518 # this is to set up the position embeddings properly 6 | local_crops_size: 98 -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/configs/eval/vits14_reg4_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_small 3 | patch_size: 14 4 | num_register_tokens: 4 5 | interpolate_antialias: true 6 | interpolate_offset: 0.0 7 | crops: 8 | global_crops_size: 518 # this is to set up the position embeddings properly 9 | local_crops_size: 98 -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/configs/ssl_default_config.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | WEIGHTS: '' 3 | compute_precision: 4 | grad_scaler: true 5 | teacher: 6 | backbone: 7 | sharding_strategy: SHARD_GRAD_OP 8 | mixed_precision: 9 | param_dtype: fp16 10 | reduce_dtype: fp16 11 | buffer_dtype: fp32 12 | dino_head: 13 | sharding_strategy: SHARD_GRAD_OP 14 | mixed_precision: 15 | param_dtype: fp16 16 | reduce_dtype: fp16 17 | buffer_dtype: fp32 18 | ibot_head: 19 | sharding_strategy: SHARD_GRAD_OP 20 | mixed_precision: 21 | param_dtype: fp16 22 | reduce_dtype: fp16 23 | buffer_dtype: fp32 24 | student: 25 | backbone: 26 | sharding_strategy: SHARD_GRAD_OP 27 | mixed_precision: 28 | param_dtype: fp16 29 | reduce_dtype: fp16 30 | buffer_dtype: fp32 31 | dino_head: 32 | sharding_strategy: SHARD_GRAD_OP 33 | mixed_precision: 34 | param_dtype: fp16 35 | reduce_dtype: fp32 36 | buffer_dtype: fp32 37 | ibot_head: 38 | sharding_strategy: SHARD_GRAD_OP 39 | mixed_precision: 40 | param_dtype: fp16 41 | reduce_dtype: fp32 42 | buffer_dtype: fp32 43 | dino: 44 | loss_weight: 1.0 45 | head_n_prototypes: 65536 46 | head_bottleneck_dim: 256 47 | head_nlayers: 3 48 | head_hidden_dim: 2048 49 | koleo_loss_weight: 0.1 50 | ibot: 51 | loss_weight: 1.0 52 | mask_sample_probability: 0.5 53 | mask_ratio_min_max: 54 | - 0.1 55 | - 0.5 56 | separate_head: false 57 | head_n_prototypes: 65536 58 | head_bottleneck_dim: 256 59 | head_nlayers: 3 60 | head_hidden_dim: 2048 61 | train: 62 | batch_size_per_gpu: 64 63 | dataset_path: ImageNet:split=TRAIN 64 | output_dir: . 65 | saveckp_freq: 20 66 | seed: 0 67 | num_workers: 10 68 | OFFICIAL_EPOCH_LENGTH: 1250 69 | cache_dataset: true 70 | centering: "centering" # or "sinkhorn_knopp" 71 | student: 72 | arch: vit_large 73 | patch_size: 16 74 | drop_path_rate: 0.3 75 | layerscale: 1.0e-05 76 | drop_path_uniform: true 77 | pretrained_weights: '' 78 | ffn_layer: "mlp" 79 | block_chunks: 0 80 | qkv_bias: true 81 | proj_bias: true 82 | ffn_bias: true 83 | num_register_tokens: 0 84 | interpolate_antialias: false 85 | interpolate_offset: 0.1 86 | teacher: 87 | momentum_teacher: 0.992 88 | final_momentum_teacher: 1 89 | warmup_teacher_temp: 0.04 90 | teacher_temp: 0.07 91 | warmup_teacher_temp_epochs: 30 92 | optim: 93 | epochs: 100 94 | weight_decay: 0.04 95 | weight_decay_end: 0.4 96 | base_lr: 0.004 # learning rate for a batch size of 1024 97 | lr: 0. 
# will be set after applying scaling rule 98 | warmup_epochs: 10 99 | min_lr: 1.0e-06 100 | clip_grad: 3.0 101 | freeze_last_layer_epochs: 1 102 | scaling_rule: sqrt_wrt_1024 103 | patch_embed_lr_mult: 0.2 104 | layerwise_decay: 0.9 105 | adamw_beta1: 0.9 106 | adamw_beta2: 0.999 107 | crops: 108 | global_crops_scale: 109 | - 0.32 110 | - 1.0 111 | local_crops_number: 8 112 | local_crops_scale: 113 | - 0.05 114 | - 0.32 115 | global_crops_size: 224 116 | local_crops_size: 96 117 | evaluation: 118 | eval_period_iterations: 12500 119 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/configs/train/vitg14.yaml: -------------------------------------------------------------------------------- 1 | dino: 2 | head_n_prototypes: 131072 3 | head_bottleneck_dim: 384 4 | ibot: 5 | separate_head: true 6 | head_n_prototypes: 131072 7 | train: 8 | batch_size_per_gpu: 12 9 | dataset_path: ImageNet22k 10 | centering: sinkhorn_knopp 11 | student: 12 | arch: vit_giant2 13 | patch_size: 14 14 | drop_path_rate: 0.4 15 | ffn_layer: swiglufused 16 | block_chunks: 4 17 | teacher: 18 | momentum_teacher: 0.994 19 | optim: 20 | epochs: 500 21 | weight_decay_end: 0.2 22 | base_lr: 2.0e-04 # learning rate for a batch size of 1024 23 | warmup_epochs: 80 24 | layerwise_decay: 1.0 25 | crops: 26 | local_crops_size: 98 -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/configs/train/vitl14.yaml: -------------------------------------------------------------------------------- 1 | dino: 2 | head_n_prototypes: 131072 3 | head_bottleneck_dim: 384 4 | ibot: 5 | separate_head: true 6 | head_n_prototypes: 131072 7 | train: 8 | batch_size_per_gpu: 32 9 | dataset_path: ImageNet22k 10 | centering: sinkhorn_knopp 11 | student: 12 | arch: vit_large 13 | patch_size: 14 14 | drop_path_rate: 0.4 15 | ffn_layer: swiglufused 16 | block_chunks: 4 17 | teacher: 18 | momentum_teacher: 0.994 19 | optim: 20 | epochs: 500 21 | weight_decay_end: 0.2 22 | base_lr: 2.0e-04 # learning rate for a batch size of 1024 23 | warmup_epochs: 80 24 | layerwise_decay: 1.0 25 | crops: 26 | local_crops_size: 98 -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/configs/train/vitl16_short.yaml: -------------------------------------------------------------------------------- 1 | # this corresponds to the default config 2 | train: 3 | dataset_path: ImageNet:split=TRAIN 4 | batch_size_per_gpu: 64 5 | student: 6 | block_chunks: 4 7 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .adapters import DatasetWithEnumeratedTargets 7 | from .loaders import make_data_loader, make_dataset, SamplerType 8 | from .collate import collate_data_and_cast 9 | from .masking import MaskingGenerator 10 | from .augmentations import DataAugmentationDINO 11 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/data/adapters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from typing import Any, Tuple 7 | 8 | from torch.utils.data import Dataset 9 | 10 | 11 | class DatasetWithEnumeratedTargets(Dataset): 12 | def __init__(self, dataset): 13 | self._dataset = dataset 14 | 15 | def get_image_data(self, index: int) -> bytes: 16 | return self._dataset.get_image_data(index) 17 | 18 | def get_target(self, index: int) -> Tuple[Any, int]: 19 | target = self._dataset.get_target(index) 20 | return (index, target) 21 | 22 | def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]: 23 | image, target = self._dataset[index] 24 | target = index if target is None else target 25 | return image, (index, target) 26 | 27 | def __len__(self) -> int: 28 | return len(self._dataset) 29 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/data/augmentations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | 8 | from torchvision import transforms 9 | 10 | from .transforms import ( 11 | GaussianBlur, 12 | make_normalize_transform, 13 | ) 14 | 15 | 16 | logger = logging.getLogger("dinov2") 17 | 18 | 19 | class DataAugmentationDINO(object): 20 | def __init__( 21 | self, 22 | global_crops_scale, 23 | local_crops_scale, 24 | local_crops_number, 25 | global_crops_size=224, 26 | local_crops_size=96, 27 | ): 28 | self.global_crops_scale = global_crops_scale 29 | self.local_crops_scale = local_crops_scale 30 | self.local_crops_number = local_crops_number 31 | self.global_crops_size = global_crops_size 32 | self.local_crops_size = local_crops_size 33 | 34 | logger.info("###################################") 35 | logger.info("Using data augmentation parameters:") 36 | logger.info(f"global_crops_scale: {global_crops_scale}") 37 | logger.info(f"local_crops_scale: {local_crops_scale}") 38 | logger.info(f"local_crops_number: {local_crops_number}") 39 | logger.info(f"global_crops_size: {global_crops_size}") 40 | logger.info(f"local_crops_size: {local_crops_size}") 41 | logger.info("###################################") 42 | 43 | # random resized crop and flip 44 | self.geometric_augmentation_global = transforms.Compose( 45 | [ 46 | transforms.RandomResizedCrop( 47 | global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC 48 | ), 49 | transforms.RandomHorizontalFlip(p=0.5), 50 | ] 51 | ) 52 | 53 | self.geometric_augmentation_local = transforms.Compose( 54 | [ 55 | transforms.RandomResizedCrop( 56 | local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC 57 | ), 58 | transforms.RandomHorizontalFlip(p=0.5), 59 | ] 60 | ) 61 | 62 | # color distorsions / blurring 63 | color_jittering = transforms.Compose( 64 | [ 65 | transforms.RandomApply( 66 | [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], 67 | p=0.8, 68 | ), 69 | transforms.RandomGrayscale(p=0.2), 70 | ] 71 | ) 72 | 73 | global_transfo1_extra = GaussianBlur(p=1.0) 74 | 75 | global_transfo2_extra = transforms.Compose( 76 | [ 77 | GaussianBlur(p=0.1), 78 | transforms.RandomSolarize(threshold=128, p=0.2), 79 | ] 80 | ) 81 | 82 | local_transfo_extra = 
GaussianBlur(p=0.5) 83 | 84 | # normalization 85 | self.normalize = transforms.Compose( 86 | [ 87 | transforms.ToTensor(), 88 | make_normalize_transform(), 89 | ] 90 | ) 91 | 92 | self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize]) 93 | self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize]) 94 | self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize]) 95 | 96 | def __call__(self, image): 97 | output = {} 98 | 99 | # global crops: 100 | im1_base = self.geometric_augmentation_global(image) 101 | global_crop_1 = self.global_transfo1(im1_base) 102 | 103 | im2_base = self.geometric_augmentation_global(image) 104 | global_crop_2 = self.global_transfo2(im2_base) 105 | 106 | output["global_crops"] = [global_crop_1, global_crop_2] 107 | 108 | # global crops for teacher: 109 | output["global_crops_teacher"] = [global_crop_1, global_crop_2] 110 | 111 | # local crops: 112 | local_crops = [ 113 | self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number) 114 | ] 115 | output["local_crops"] = local_crops 116 | output["offsets"] = () 117 | 118 | return output 119 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/data/collate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import random 8 | 9 | 10 | def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None): 11 | # dtype = torch.half # TODO: Remove 12 | 13 | n_global_crops = len(samples_list[0][0]["global_crops"]) 14 | n_local_crops = len(samples_list[0][0]["local_crops"]) 15 | 16 | collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list]) 17 | 18 | collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list]) 19 | 20 | B = len(collated_global_crops) 21 | N = n_tokens 22 | n_samples_masked = int(B * mask_probability) 23 | probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1) 24 | upperbound = 0 25 | masks_list = [] 26 | for i in range(0, n_samples_masked): 27 | prob_min = probs[i] 28 | prob_max = probs[i + 1] 29 | masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max))))) 30 | upperbound += int(N * prob_max) 31 | for i in range(n_samples_masked, B): 32 | masks_list.append(torch.BoolTensor(mask_generator(0))) 33 | 34 | random.shuffle(masks_list) 35 | 36 | collated_masks = torch.stack(masks_list).flatten(1) 37 | mask_indices_list = collated_masks.flatten().nonzero().flatten() 38 | 39 | masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks] 40 | 41 | return { 42 | "collated_global_crops": collated_global_crops.to(dtype), 43 | "collated_local_crops": collated_local_crops.to(dtype), 44 | "collated_masks": collated_masks, 45 | "mask_indices_list": mask_indices_list, 46 | "masks_weight": masks_weight, 47 | "upperbound": upperbound, 48 | "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long), 49 | } 50 | 
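A minimal wiring sketch (not part of the repository) showing how the pieces of this data package — DataAugmentationDINO, MaskingGenerator, and collate_data_and_cast — are typically combined for a DataLoader. It assumes the dinov2 package directory is importable; the image/patch sizes, crop scales, mask ratios, and the dataset placeholder are illustrative assumptions, not values taken from this repo.

# Hypothetical usage sketch; parameter values and the dataset are assumptions.
from functools import partial

import torch
from torch.utils.data import DataLoader

from dinov2.data.augmentations import DataAugmentationDINO
from dinov2.data.collate import collate_data_and_cast
from dinov2.data.masking import MaskingGenerator

img_size, patch_size = 224, 16                 # assumed ViT-style sizes
n_tokens = (img_size // patch_size) ** 2       # patch tokens per global crop

# Multi-crop augmentation: two global crops + several local crops per image.
transform = DataAugmentationDINO(
    global_crops_scale=(0.32, 1.0),
    local_crops_scale=(0.05, 0.32),
    local_crops_number=8,
    global_crops_size=img_size,
    local_crops_size=96,
)

# Block-wise mask generator over the patch grid of a global crop.
mask_generator = MaskingGenerator(
    input_size=(img_size // patch_size, img_size // patch_size),
    max_num_patches=int(0.5 * n_tokens),
)

# Bind the collate function's masking/casting arguments once, so the
# DataLoader only needs to pass the list of samples.
collate_fn = partial(
    collate_data_and_cast,
    mask_ratio_tuple=(0.1, 0.5),
    mask_probability=0.5,
    dtype=torch.half,
    n_tokens=n_tokens,
    mask_generator=mask_generator,
)

# `some_dataset` is a placeholder: each item must be (transform(image), target),
# so sample[0] carries the "global_crops"/"local_crops" dict the collate expects.
# loader = DataLoader(some_dataset, batch_size=64, collate_fn=collate_fn)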
-------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .image_net import ImageNet 7 | from .image_net_22k import ImageNet22k 8 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/data/datasets/decoders.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from io import BytesIO 7 | from typing import Any 8 | 9 | from PIL import Image 10 | 11 | 12 | class Decoder: 13 | def decode(self) -> Any: 14 | raise NotImplementedError 15 | 16 | 17 | class ImageDataDecoder(Decoder): 18 | def __init__(self, image_data: bytes) -> None: 19 | self._image_data = image_data 20 | 21 | def decode(self) -> Image: 22 | f = BytesIO(self._image_data) 23 | return Image.open(f).convert(mode="RGB") 24 | 25 | class TargetDecoder(Decoder): 26 | def __init__(self, target: Any): 27 | self._target = target 28 | 29 | def decode(self) -> Any: 30 | return self._target 31 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/data/datasets/extended.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from typing import Any, Tuple 7 | 8 | from torchvision.datasets import VisionDataset 9 | 10 | from .decoders import TargetDecoder, ImageDataDecoder 11 | 12 | 13 | class ExtendedVisionDataset(VisionDataset): 14 | def __init__(self, *args, **kwargs) -> None: 15 | super().__init__(*args, **kwargs) # type: ignore 16 | 17 | def get_image_data(self, index: int) -> bytes: 18 | raise NotImplementedError 19 | 20 | def get_target(self, index: int) -> Any: 21 | raise NotImplementedError 22 | 23 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 24 | try: 25 | image_data = self.get_image_data(index) 26 | image = ImageDataDecoder(image_data).decode() 27 | except Exception as e: 28 | raise RuntimeError(f"can not read image for sample {index}") from e 29 | target = self.get_target(index) 30 | target = TargetDecoder(target).decode() 31 | 32 | if self.transforms is not None: 33 | image, target = self.transforms(image, target) 34 | 35 | return image, target 36 | 37 | def __len__(self) -> int: 38 | raise NotImplementedError 39 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/data/masking.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import random 7 | import math 8 | import numpy as np 9 | 10 | 11 | class MaskingGenerator: 12 | def __init__( 13 | self, 14 | input_size, 15 | num_masking_patches=None, 16 | min_num_patches=4, 17 | max_num_patches=None, 18 | min_aspect=0.3, 19 | max_aspect=None, 20 | ): 21 | if not isinstance(input_size, tuple): 22 | input_size = (input_size,) * 2 23 | self.height, self.width = input_size 24 | 25 | self.num_patches = self.height * self.width 26 | self.num_masking_patches = num_masking_patches 27 | 28 | self.min_num_patches = min_num_patches 29 | self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches 30 | 31 | max_aspect = max_aspect or 1 / min_aspect 32 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 33 | 34 | def __repr__(self): 35 | repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( 36 | self.height, 37 | self.width, 38 | self.min_num_patches, 39 | self.max_num_patches, 40 | self.num_masking_patches, 41 | self.log_aspect_ratio[0], 42 | self.log_aspect_ratio[1], 43 | ) 44 | return repr_str 45 | 46 | def get_shape(self): 47 | return self.height, self.width 48 | 49 | def _mask(self, mask, max_mask_patches): 50 | delta = 0 51 | for _ in range(10): 52 | target_area = random.uniform(self.min_num_patches, max_mask_patches) 53 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 54 | h = int(round(math.sqrt(target_area * aspect_ratio))) 55 | w = int(round(math.sqrt(target_area / aspect_ratio))) 56 | if w < self.width and h < self.height: 57 | top = random.randint(0, self.height - h) 58 | left = random.randint(0, self.width - w) 59 | 60 | num_masked = mask[top : top + h, left : left + w].sum() 61 | # Overlap 62 | if 0 < h * w - num_masked <= max_mask_patches: 63 | for i in range(top, top + h): 64 | for j in range(left, left + w): 65 | if mask[i, j] == 0: 66 | mask[i, j] = 1 67 | delta += 1 68 | 69 | if delta > 0: 70 | break 71 | return delta 72 | 73 | def __call__(self, num_masking_patches=0): 74 | mask = np.zeros(shape=self.get_shape(), dtype=bool) 75 | mask_count = 0 76 | while mask_count < num_masking_patches: 77 | max_mask_patches = num_masking_patches - mask_count 78 | max_mask_patches = min(max_mask_patches, self.max_num_patches) 79 | 80 | delta = self._mask(mask, max_mask_patches) 81 | if delta == 0: 82 | break 83 | else: 84 | mask_count += delta 85 | 86 | return mask 87 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/data/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from typing import Sequence 7 | 8 | import torch 9 | from torchvision import transforms 10 | 11 | 12 | class GaussianBlur(transforms.RandomApply): 13 | """ 14 | Apply Gaussian Blur to the PIL image. 
15 | """ 16 | 17 | def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0): 18 | # NOTE: torchvision is applying 1 - probability to return the original image 19 | keep_p = 1 - p 20 | transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max)) 21 | super().__init__(transforms=[transform], p=keep_p) 22 | 23 | 24 | class MaybeToTensor(transforms.ToTensor): 25 | """ 26 | Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor. 27 | """ 28 | 29 | def __call__(self, pic): 30 | """ 31 | Args: 32 | pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor. 33 | Returns: 34 | Tensor: Converted image. 35 | """ 36 | if isinstance(pic, torch.Tensor): 37 | return pic 38 | return super().__call__(pic) 39 | 40 | 41 | # Use timm's names 42 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 43 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 44 | 45 | 46 | def make_normalize_transform( 47 | mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, 48 | std: Sequence[float] = IMAGENET_DEFAULT_STD, 49 | ) -> transforms.Normalize: 50 | return transforms.Normalize(mean=mean, std=std) 51 | 52 | 53 | # This roughly matches torchvision's preset for classification training: 54 | # https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44 55 | def make_classification_train_transform( 56 | *, 57 | crop_size: int = 224, 58 | interpolation=transforms.InterpolationMode.BICUBIC, 59 | hflip_prob: float = 0.5, 60 | mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, 61 | std: Sequence[float] = IMAGENET_DEFAULT_STD, 62 | ): 63 | transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] 64 | if hflip_prob > 0.0: 65 | transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob)) 66 | transforms_list.extend( 67 | [ 68 | MaybeToTensor(), 69 | make_normalize_transform(mean=mean, std=std), 70 | ] 71 | ) 72 | return transforms.Compose(transforms_list) 73 | 74 | 75 | # This matches (roughly) torchvision's preset for classification evaluation: 76 | # https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69 77 | def make_classification_eval_transform( 78 | *, 79 | resize_size: int = 256, 80 | interpolation=transforms.InterpolationMode.BICUBIC, 81 | crop_size: int = 224, 82 | mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, 83 | std: Sequence[float] = IMAGENET_DEFAULT_STD, 84 | ) -> transforms.Compose: 85 | transforms_list = [ 86 | transforms.Resize(resize_size, interpolation=interpolation), 87 | transforms.CenterCrop(crop_size), 88 | MaybeToTensor(), 89 | make_normalize_transform(mean=mean, std=std), 90 | ] 91 | return transforms.Compose(transforms_list) 92 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/depth/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/depth/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .backbones import * # noqa: F403 7 | from .builder import BACKBONES, DEPTHER, HEADS, LOSSES, build_backbone, build_depther, build_head, build_loss 8 | from .decode_heads import * # noqa: F403 9 | from .depther import * # noqa: F403 10 | from .losses import * # noqa: F403 11 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/depth/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .vision_transformer import DinoVisionTransformer 7 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/depth/models/backbones/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from mmcv.runner import BaseModule 7 | 8 | from ..builder import BACKBONES 9 | 10 | 11 | @BACKBONES.register_module() 12 | class DinoVisionTransformer(BaseModule): 13 | """Vision Transformer.""" 14 | 15 | def __init__(self, *args, **kwargs): 16 | super().__init__() 17 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/depth/models/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import warnings 7 | 8 | from mmcv.cnn import MODELS as MMCV_MODELS 9 | from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION 10 | from mmcv.utils import Registry 11 | 12 | MODELS = Registry("models", parent=MMCV_MODELS) 13 | ATTENTION = Registry("attention", parent=MMCV_ATTENTION) 14 | 15 | 16 | BACKBONES = MODELS 17 | NECKS = MODELS 18 | HEADS = MODELS 19 | LOSSES = MODELS 20 | DEPTHER = MODELS 21 | 22 | 23 | def build_backbone(cfg): 24 | """Build backbone.""" 25 | return BACKBONES.build(cfg) 26 | 27 | 28 | def build_neck(cfg): 29 | """Build neck.""" 30 | return NECKS.build(cfg) 31 | 32 | 33 | def build_head(cfg): 34 | """Build head.""" 35 | return HEADS.build(cfg) 36 | 37 | 38 | def build_loss(cfg): 39 | """Build loss.""" 40 | return LOSSES.build(cfg) 41 | 42 | 43 | def build_depther(cfg, train_cfg=None, test_cfg=None): 44 | """Build depther.""" 45 | if train_cfg is not None or test_cfg is not None: 46 | warnings.warn("train_cfg and test_cfg is deprecated, " "please specify them in model", UserWarning) 47 | assert cfg.get("train_cfg") is None or train_cfg is None, "train_cfg specified in both outer field and model field " 48 | assert cfg.get("test_cfg") is None or test_cfg is None, "test_cfg specified in both outer field and model field " 49 | return DEPTHER.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) 50 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/depth/models/decode_heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .dpt_head import DPTHead 7 | from .linear_head import BNHead 8 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/depth/models/decode_heads/linear_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from ...ops import resize 10 | from ..builder import HEADS 11 | from .decode_head import DepthBaseDecodeHead 12 | 13 | 14 | @HEADS.register_module() 15 | class BNHead(DepthBaseDecodeHead): 16 | """Just a batchnorm.""" 17 | 18 | def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs): 19 | super().__init__(**kwargs) 20 | self.input_transform = input_transform 21 | self.in_index = in_index 22 | self.upsample = upsample 23 | # self.bn = nn.SyncBatchNorm(self.in_channels) 24 | if self.classify: 25 | self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1) 26 | else: 27 | self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1) 28 | 29 | def _transform_inputs(self, inputs): 30 | """Transform inputs for decoder. 31 | Args: 32 | inputs (list[Tensor]): List of multi-level img features. 
33 | Returns: 34 | Tensor: The transformed inputs 35 | """ 36 | 37 | if "concat" in self.input_transform: 38 | inputs = [inputs[i] for i in self.in_index] 39 | if "resize" in self.input_transform: 40 | inputs = [ 41 | resize( 42 | input=x, 43 | size=[s * self.upsample for s in inputs[0].shape[2:]], 44 | mode="bilinear", 45 | align_corners=self.align_corners, 46 | ) 47 | for x in inputs 48 | ] 49 | inputs = torch.cat(inputs, dim=1) 50 | elif self.input_transform == "multiple_select": 51 | inputs = [inputs[i] for i in self.in_index] 52 | else: 53 | inputs = inputs[self.in_index] 54 | 55 | return inputs 56 | 57 | def _forward_feature(self, inputs, img_metas=None, **kwargs): 58 | """Forward function for feature maps before classifying each pixel with 59 | ``self.cls_seg`` fc. 60 | Args: 61 | inputs (list[Tensor]): List of multi-level img features. 62 | Returns: 63 | feats (Tensor): A tensor of shape (batch_size, self.channels, 64 | H, W) which is feature map for last layer of decoder head. 65 | """ 66 | # accept lists (for cls token) 67 | inputs = list(inputs) 68 | for i, x in enumerate(inputs): 69 | if len(x) == 2: 70 | x, cls_token = x[0], x[1] 71 | if len(x.shape) == 2: 72 | x = x[:, :, None, None] 73 | cls_token = cls_token[:, :, None, None].expand_as(x) 74 | inputs[i] = torch.cat((x, cls_token), 1) 75 | else: 76 | x = x[0] 77 | if len(x.shape) == 2: 78 | x = x[:, :, None, None] 79 | inputs[i] = x 80 | x = self._transform_inputs(inputs) 81 | # feats = self.bn(x) 82 | return x 83 | 84 | def forward(self, inputs, img_metas=None, **kwargs): 85 | """Forward function.""" 86 | output = self._forward_feature(inputs, img_metas=img_metas, **kwargs) 87 | output = self.depth_pred(output) 88 | 89 | return output 90 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/depth/models/depther/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .base import BaseDepther 7 | from .encoder_decoder import DepthEncoderDecoder 8 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/depth/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .gradientloss import GradientLoss 7 | from .sigloss import SigLoss 8 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/depth/models/losses/gradientloss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from ...models.builder import LOSSES 10 | 11 | 12 | @LOSSES.register_module() 13 | class GradientLoss(nn.Module): 14 | """GradientLoss. 
15 | 16 | Adapted from https://www.cs.cornell.edu/projects/megadepth/ 17 | 18 | Args: 19 | valid_mask (bool): Whether filter invalid gt (gt > 0). Default: True. 20 | loss_weight (float): Weight of the loss. Default: 1.0. 21 | max_depth (int): When filtering invalid gt, set a max threshold. Default: None. 22 | """ 23 | 24 | def __init__(self, valid_mask=True, loss_weight=1.0, max_depth=None, loss_name="loss_grad"): 25 | super(GradientLoss, self).__init__() 26 | self.valid_mask = valid_mask 27 | self.loss_weight = loss_weight 28 | self.max_depth = max_depth 29 | self.loss_name = loss_name 30 | 31 | self.eps = 0.001 # avoid grad explode 32 | 33 | def gradientloss(self, input, target): 34 | input_downscaled = [input] + [input[:: 2 * i, :: 2 * i] for i in range(1, 4)] 35 | target_downscaled = [target] + [target[:: 2 * i, :: 2 * i] for i in range(1, 4)] 36 | 37 | gradient_loss = 0 38 | for input, target in zip(input_downscaled, target_downscaled): 39 | if self.valid_mask: 40 | mask = target > 0 41 | if self.max_depth is not None: 42 | mask = torch.logical_and(target > 0, target <= self.max_depth) 43 | N = torch.sum(mask) 44 | else: 45 | mask = torch.ones_like(target) 46 | N = input.numel() 47 | input_log = torch.log(input + self.eps) 48 | target_log = torch.log(target + self.eps) 49 | log_d_diff = input_log - target_log 50 | 51 | log_d_diff = torch.mul(log_d_diff, mask) 52 | 53 | v_gradient = torch.abs(log_d_diff[0:-2, :] - log_d_diff[2:, :]) 54 | v_mask = torch.mul(mask[0:-2, :], mask[2:, :]) 55 | v_gradient = torch.mul(v_gradient, v_mask) 56 | 57 | h_gradient = torch.abs(log_d_diff[:, 0:-2] - log_d_diff[:, 2:]) 58 | h_mask = torch.mul(mask[:, 0:-2], mask[:, 2:]) 59 | h_gradient = torch.mul(h_gradient, h_mask) 60 | 61 | gradient_loss += (torch.sum(h_gradient) + torch.sum(v_gradient)) / N 62 | 63 | return gradient_loss 64 | 65 | def forward(self, depth_pred, depth_gt): 66 | """Forward function.""" 67 | 68 | gradient_loss = self.loss_weight * self.gradientloss(depth_pred, depth_gt) 69 | return gradient_loss 70 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/depth/models/losses/sigloss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from ...models.builder import LOSSES 10 | 11 | 12 | @LOSSES.register_module() 13 | class SigLoss(nn.Module): 14 | """SigLoss. 15 | 16 | This follows `AdaBins `_. 17 | 18 | Args: 19 | valid_mask (bool): Whether filter invalid gt (gt > 0). Default: True. 20 | loss_weight (float): Weight of the loss. Default: 1.0. 21 | max_depth (int): When filtering invalid gt, set a max threshold. Default: None. 22 | warm_up (bool): A simple warm up stage to help convergence. Default: False. 23 | warm_iter (int): The number of warm up stage. Default: 100. 
24 | """ 25 | 26 | def __init__( 27 | self, valid_mask=True, loss_weight=1.0, max_depth=None, warm_up=False, warm_iter=100, loss_name="sigloss" 28 | ): 29 | super(SigLoss, self).__init__() 30 | self.valid_mask = valid_mask 31 | self.loss_weight = loss_weight 32 | self.max_depth = max_depth 33 | self.loss_name = loss_name 34 | 35 | self.eps = 0.001 # avoid grad explode 36 | 37 | # HACK: a hack implementation for warmup sigloss 38 | self.warm_up = warm_up 39 | self.warm_iter = warm_iter 40 | self.warm_up_counter = 0 41 | 42 | def sigloss(self, input, target): 43 | if self.valid_mask: 44 | valid_mask = target > 0 45 | if self.max_depth is not None: 46 | valid_mask = torch.logical_and(target > 0, target <= self.max_depth) 47 | input = input[valid_mask] 48 | target = target[valid_mask] 49 | 50 | if self.warm_up: 51 | if self.warm_up_counter < self.warm_iter: 52 | g = torch.log(input + self.eps) - torch.log(target + self.eps) 53 | g = 0.15 * torch.pow(torch.mean(g), 2) 54 | self.warm_up_counter += 1 55 | return torch.sqrt(g) 56 | 57 | g = torch.log(input + self.eps) - torch.log(target + self.eps) 58 | Dg = torch.var(g) + 0.15 * torch.pow(torch.mean(g), 2) 59 | return torch.sqrt(Dg) 60 | 61 | def forward(self, depth_pred, depth_gt): 62 | """Forward function.""" 63 | 64 | loss_depth = self.loss_weight * self.sigloss(depth_pred, depth_gt) 65 | return loss_depth 66 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/depth/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .wrappers import resize 7 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/depth/ops/wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import warnings 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False): 12 | if warning: 13 | if size is not None and align_corners: 14 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 15 | output_h, output_w = tuple(int(x) for x in size) 16 | if output_h > input_h or output_w > output_h: 17 | if ( 18 | (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) 19 | and (output_h - 1) % (input_h - 1) 20 | and (output_w - 1) % (input_w - 1) 21 | ): 22 | warnings.warn( 23 | f"When align_corners={align_corners}, " 24 | "the output would more aligned if " 25 | f"input size {(input_h, input_w)} is `x+1` and " 26 | f"out size {(output_h, output_w)} is `nx+1`" 27 | ) 28 | return F.interpolate(input, size, scale_factor, mode, align_corners) 29 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from enum import Enum 7 | import logging 8 | from typing import Any, Dict, Optional 9 | 10 | import torch 11 | from torch import Tensor 12 | from torchmetrics import Metric, MetricCollection 13 | from torchmetrics.classification import MulticlassAccuracy 14 | from torchmetrics.utilities.data import dim_zero_cat, select_topk 15 | 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | class MetricType(Enum): 21 | MEAN_ACCURACY = "mean_accuracy" 22 | MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy" 23 | PER_CLASS_ACCURACY = "per_class_accuracy" 24 | IMAGENET_REAL_ACCURACY = "imagenet_real_accuracy" 25 | 26 | @property 27 | def accuracy_averaging(self): 28 | return getattr(AccuracyAveraging, self.name, None) 29 | 30 | def __str__(self): 31 | return self.value 32 | 33 | 34 | class AccuracyAveraging(Enum): 35 | MEAN_ACCURACY = "micro" 36 | MEAN_PER_CLASS_ACCURACY = "macro" 37 | PER_CLASS_ACCURACY = "none" 38 | 39 | def __str__(self): 40 | return self.value 41 | 42 | 43 | def build_metric(metric_type: MetricType, *, num_classes: int, ks: Optional[tuple] = None): 44 | if metric_type.accuracy_averaging is not None: 45 | return build_topk_accuracy_metric( 46 | average_type=metric_type.accuracy_averaging, 47 | num_classes=num_classes, 48 | ks=(1, 5) if ks is None else ks, 49 | ) 50 | elif metric_type == MetricType.IMAGENET_REAL_ACCURACY: 51 | return build_topk_imagenet_real_accuracy_metric( 52 | num_classes=num_classes, 53 | ks=(1, 5) if ks is None else ks, 54 | ) 55 | 56 | raise ValueError(f"Unknown metric type {metric_type}") 57 | 58 | 59 | def build_topk_accuracy_metric(average_type: AccuracyAveraging, num_classes: int, ks: tuple = (1, 5)): 60 | metrics: Dict[str, Metric] = { 61 | f"top-{k}": MulticlassAccuracy(top_k=k, num_classes=int(num_classes), average=average_type.value) for k in ks 62 | } 63 | return MetricCollection(metrics) 64 | 65 | 66 | def build_topk_imagenet_real_accuracy_metric(num_classes: int, ks: tuple = (1, 5)): 67 | metrics: Dict[str, Metric] = {f"top-{k}": ImageNetReaLAccuracy(top_k=k, num_classes=int(num_classes)) for k in ks} 68 | return MetricCollection(metrics) 69 | 70 | 71 | class ImageNetReaLAccuracy(Metric): 72 | is_differentiable: bool = False 73 | higher_is_better: Optional[bool] = None 74 | full_state_update: bool = False 75 | 76 | def __init__( 77 | self, 78 | num_classes: int, 79 | top_k: int = 1, 80 | **kwargs: Any, 81 | ) -> None: 82 | super().__init__(**kwargs) 83 | self.num_classes = num_classes 84 | self.top_k = top_k 85 | self.add_state("tp", [], dist_reduce_fx="cat") 86 | 87 | def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore 88 | # preds [B, D] 89 | # target [B, A] 90 | # preds_oh [B, D] with 0 and 1 91 | # select top K highest probabilities, use one hot representation 92 | preds_oh = select_topk(preds, self.top_k) 93 | # target_oh [B, D + 1] with 0 and 1 94 | target_oh = torch.zeros((preds_oh.shape[0], preds_oh.shape[1] + 1), device=target.device, dtype=torch.int32) 95 | target = target.long() 96 | # for undefined targets (-1) use a fake value `num_classes` 97 | target[target == -1] = self.num_classes 98 | # fill targets, use one hot representation 99 | target_oh.scatter_(1, target, 1) 100 | # target_oh [B, D] (remove the fake target at index `num_classes`) 101 | target_oh = target_oh[:, :-1] 102 | # tp [B] with 0 and 1 103 | tp = (preds_oh * target_oh == 
1).sum(dim=1) 104 | # at least one match between prediction and target 105 | tp.clip_(max=1) 106 | # ignore instances where no targets are defined 107 | mask = target_oh.sum(dim=1) > 0 108 | tp = tp[mask] 109 | self.tp.append(tp) # type: ignore 110 | 111 | def compute(self) -> Tensor: 112 | tp = dim_zero_cat(self.tp) # type: ignore 113 | return tp.float().mean() 114 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .optimizer import DistOptimizerHook 7 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation/hooks/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | try: 7 | import apex 8 | except ImportError: 9 | print("apex is not installed") 10 | 11 | from mmcv.runner import OptimizerHook, HOOKS 12 | 13 | 14 | @HOOKS.register_module() 15 | class DistOptimizerHook(OptimizerHook): 16 | """Optimizer hook for distributed training.""" 17 | 18 | def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False): 19 | self.grad_clip = grad_clip 20 | self.coalesce = coalesce 21 | self.bucket_size_mb = bucket_size_mb 22 | self.update_interval = update_interval 23 | self.use_fp16 = use_fp16 24 | 25 | def before_run(self, runner): 26 | runner.optimizer.zero_grad() 27 | 28 | def after_train_iter(self, runner): 29 | runner.outputs["loss"] /= self.update_interval 30 | if self.use_fp16: 31 | # runner.outputs['loss'].backward() 32 | with apex.amp.scale_loss(runner.outputs["loss"], runner.optimizer) as scaled_loss: 33 | scaled_loss.backward() 34 | else: 35 | runner.outputs["loss"].backward() 36 | if self.every_n_iters(runner, self.update_interval): 37 | if self.grad_clip is not None: 38 | self.clip_grads(runner.model.parameters()) 39 | runner.optimizer.step() 40 | runner.optimizer.zero_grad() 41 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | from .backbones import * # noqa: F403 7 | from .decode_heads import * # noqa: F403 8 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .vision_transformer import DinoVisionTransformer 7 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation/models/backbones/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from mmcv.runner import BaseModule 7 | from mmseg.models.builder import BACKBONES 8 | 9 | 10 | @BACKBONES.register_module() 11 | class DinoVisionTransformer(BaseModule): 12 | """Vision Transformer.""" 13 | 14 | def __init__( 15 | self, 16 | *args, 17 | **kwargs, 18 | ): 19 | super().__init__() 20 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation/models/decode_heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .linear_head import BNHead 7 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation/models/decode_heads/linear_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from mmseg.models.builder import HEADS 10 | from mmseg.models.decode_heads.decode_head import BaseDecodeHead 11 | from mmseg.ops import resize 12 | 13 | 14 | @HEADS.register_module() 15 | class BNHead(BaseDecodeHead): 16 | """Just a batchnorm.""" 17 | 18 | def __init__(self, resize_factors=None, **kwargs): 19 | super().__init__(**kwargs) 20 | assert self.in_channels == self.channels 21 | self.bn = nn.SyncBatchNorm(self.in_channels) 22 | self.resize_factors = resize_factors 23 | 24 | def _forward_feature(self, inputs): 25 | """Forward function for feature maps before classifying each pixel with 26 | ``self.cls_seg`` fc. 27 | 28 | Args: 29 | inputs (list[Tensor]): List of multi-level img features. 30 | 31 | Returns: 32 | feats (Tensor): A tensor of shape (batch_size, self.channels, 33 | H, W) which is feature map for last layer of decoder head. 34 | """ 35 | # print("inputs", [i.shape for i in inputs]) 36 | x = self._transform_inputs(inputs) 37 | # print("x", x.shape) 38 | feats = self.bn(x) 39 | # print("feats", feats.shape) 40 | return feats 41 | 42 | def _transform_inputs(self, inputs): 43 | """Transform inputs for decoder. 
44 | Args: 45 | inputs (list[Tensor]): List of multi-level img features. 46 | Returns: 47 | Tensor: The transformed inputs 48 | """ 49 | 50 | if self.input_transform == "resize_concat": 51 | # accept lists (for cls token) 52 | input_list = [] 53 | for x in inputs: 54 | if isinstance(x, list): 55 | input_list.extend(x) 56 | else: 57 | input_list.append(x) 58 | inputs = input_list 59 | # an image descriptor can be a local descriptor with resolution 1x1 60 | for i, x in enumerate(inputs): 61 | if len(x.shape) == 2: 62 | inputs[i] = x[:, :, None, None] 63 | # select indices 64 | inputs = [inputs[i] for i in self.in_index] 65 | # Resizing shenanigans 66 | # print("before", *(x.shape for x in inputs)) 67 | if self.resize_factors is not None: 68 | assert len(self.resize_factors) == len(inputs), (len(self.resize_factors), len(inputs)) 69 | inputs = [ 70 | resize(input=x, scale_factor=f, mode="bilinear" if f >= 1 else "area") 71 | for x, f in zip(inputs, self.resize_factors) 72 | ] 73 | # print("after", *(x.shape for x in inputs)) 74 | upsampled_inputs = [ 75 | resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners) 76 | for x in inputs 77 | ] 78 | inputs = torch.cat(upsampled_inputs, dim=1) 79 | elif self.input_transform == "multiple_select": 80 | inputs = [inputs[i] for i in self.in_index] 81 | else: 82 | inputs = inputs[self.in_index] 83 | 84 | return inputs 85 | 86 | def forward(self, inputs): 87 | """Forward function.""" 88 | output = self._forward_feature(inputs) 89 | output = self.cls_seg(output) 90 | return output 91 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .core import * # noqa: F403 7 | from .models import * # noqa: F403 8 | from .ops import * # noqa: F403 9 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from mmseg.core.evaluation import * # noqa: F403 7 | from mmseg.core.seg import * # noqa: F403 8 | 9 | from .anchor import * # noqa: F403 10 | from .box import * # noqa: F403 11 | from .utils import * # noqa: F403 12 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/core/anchor/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .point_generator import MlvlPointGenerator # noqa: F403 7 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/core/anchor/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import warnings 7 | 8 | from mmcv.utils import Registry, build_from_cfg 9 | 10 | PRIOR_GENERATORS = Registry("Generator for anchors and points") 11 | 12 | ANCHOR_GENERATORS = PRIOR_GENERATORS 13 | 14 | 15 | def build_prior_generator(cfg, default_args=None): 16 | return build_from_cfg(cfg, PRIOR_GENERATORS, default_args) 17 | 18 | 19 | def build_anchor_generator(cfg, default_args=None): 20 | warnings.warn("``build_anchor_generator`` would be deprecated soon, please use " "``build_prior_generator`` ") 21 | return build_prior_generator(cfg, default_args=default_args) 22 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/core/box/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .builder import * # noqa: F403 7 | from .samplers import MaskPseudoSampler # noqa: F403 8 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/core/box/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from mmcv.utils import Registry, build_from_cfg 7 | 8 | BBOX_SAMPLERS = Registry("bbox_sampler") 9 | BBOX_CODERS = Registry("bbox_coder") 10 | 11 | 12 | def build_sampler(cfg, **default_args): 13 | """Builder of box sampler.""" 14 | return build_from_cfg(cfg, BBOX_SAMPLERS, default_args) 15 | 16 | 17 | def build_bbox_coder(cfg, **default_args): 18 | """Builder of box coder.""" 19 | return build_from_cfg(cfg, BBOX_CODERS, default_args) 20 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .mask_pseudo_sampler import MaskPseudoSampler # noqa: F403 7 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from abc import ABCMeta, abstractmethod 7 | 8 | import torch 9 | 10 | from .sampling_result import SamplingResult 11 | 12 | 13 | class BaseSampler(metaclass=ABCMeta): 14 | """Base class of samplers.""" 15 | 16 | def __init__(self, num, pos_fraction, neg_pos_ub=-1, add_gt_as_proposals=True, **kwargs): 17 | self.num = num 18 | self.pos_fraction = pos_fraction 19 | self.neg_pos_ub = neg_pos_ub 20 | self.add_gt_as_proposals = add_gt_as_proposals 21 | self.pos_sampler = self 22 | self.neg_sampler = self 23 | 24 | @abstractmethod 25 | def _sample_pos(self, assign_result, num_expected, **kwargs): 26 | """Sample positive samples.""" 27 | pass 28 | 29 | @abstractmethod 30 | def _sample_neg(self, assign_result, num_expected, **kwargs): 31 | """Sample negative samples.""" 32 | pass 33 | 34 | def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None, **kwargs): 35 | """Sample positive and negative bboxes. 36 | 37 | This is a simple implementation of bbox sampling given candidates, 38 | assigning results and ground truth bboxes. 39 | 40 | Args: 41 | assign_result (:obj:`AssignResult`): Bbox assigning results. 42 | bboxes (Tensor): Boxes to be sampled from. 43 | gt_bboxes (Tensor): Ground truth bboxes. 44 | gt_labels (Tensor, optional): Class labels of ground truth bboxes. 45 | 46 | Returns: 47 | :obj:`SamplingResult`: Sampling result. 48 | 49 | Example: 50 | >>> from mmdet.core.bbox import RandomSampler 51 | >>> from mmdet.core.bbox import AssignResult 52 | >>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes 53 | >>> rng = ensure_rng(None) 54 | >>> assign_result = AssignResult.random(rng=rng) 55 | >>> bboxes = random_boxes(assign_result.num_preds, rng=rng) 56 | >>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng) 57 | >>> gt_labels = None 58 | >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1, 59 | >>> add_gt_as_proposals=False) 60 | >>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels) 61 | """ 62 | if len(bboxes.shape) < 2: 63 | bboxes = bboxes[None, :] 64 | 65 | bboxes = bboxes[:, :4] 66 | 67 | gt_flags = bboxes.new_zeros((bboxes.shape[0],), dtype=torch.uint8) 68 | if self.add_gt_as_proposals and len(gt_bboxes) > 0: 69 | if gt_labels is None: 70 | raise ValueError("gt_labels must be given when add_gt_as_proposals is True") 71 | bboxes = torch.cat([gt_bboxes, bboxes], dim=0) 72 | assign_result.add_gt_(gt_labels) 73 | gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8) 74 | gt_flags = torch.cat([gt_ones, gt_flags]) 75 | 76 | num_expected_pos = int(self.num * self.pos_fraction) 77 | pos_inds = self.pos_sampler._sample_pos(assign_result, num_expected_pos, bboxes=bboxes, **kwargs) 78 | # We found that sampled indices have duplicated items occasionally. 
79 | # (may be a bug of PyTorch) 80 | pos_inds = pos_inds.unique() 81 | num_sampled_pos = pos_inds.numel() 82 | num_expected_neg = self.num - num_sampled_pos 83 | if self.neg_pos_ub >= 0: 84 | _pos = max(1, num_sampled_pos) 85 | neg_upper_bound = int(self.neg_pos_ub * _pos) 86 | if num_expected_neg > neg_upper_bound: 87 | num_expected_neg = neg_upper_bound 88 | neg_inds = self.neg_sampler._sample_neg(assign_result, num_expected_neg, bboxes=bboxes, **kwargs) 89 | neg_inds = neg_inds.unique() 90 | 91 | sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags) 92 | return sampling_result 93 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py 8 | 9 | import torch 10 | 11 | from ..builder import BBOX_SAMPLERS 12 | from .base_sampler import BaseSampler 13 | from .mask_sampling_result import MaskSamplingResult 14 | 15 | 16 | @BBOX_SAMPLERS.register_module() 17 | class MaskPseudoSampler(BaseSampler): 18 | """A pseudo sampler that does not do sampling actually.""" 19 | 20 | def __init__(self, **kwargs): 21 | pass 22 | 23 | def _sample_pos(self, **kwargs): 24 | """Sample positive samples.""" 25 | raise NotImplementedError 26 | 27 | def _sample_neg(self, **kwargs): 28 | """Sample negative samples.""" 29 | raise NotImplementedError 30 | 31 | def sample(self, assign_result, masks, gt_masks, **kwargs): 32 | """Directly returns the positive and negative indices of samples. 33 | 34 | Args: 35 | assign_result (:obj:`AssignResult`): Assigned results 36 | masks (torch.Tensor): Bounding boxes 37 | gt_masks (torch.Tensor): Ground truth boxes 38 | Returns: 39 | :obj:`SamplingResult`: sampler results 40 | """ 41 | pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() 42 | neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique() 43 | gt_flags = masks.new_zeros(masks.shape[0], dtype=torch.uint8) 44 | sampling_result = MaskSamplingResult(pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags) 45 | return sampling_result 46 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # References: 7 | # https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py 8 | 9 | import torch 10 | 11 | from .sampling_result import SamplingResult 12 | 13 | 14 | class MaskSamplingResult(SamplingResult): 15 | """Mask sampling result.""" 16 | 17 | def __init__(self, pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags): 18 | self.pos_inds = pos_inds 19 | self.neg_inds = neg_inds 20 | self.pos_masks = masks[pos_inds] 21 | self.neg_masks = masks[neg_inds] 22 | self.pos_is_gt = gt_flags[pos_inds] 23 | 24 | self.num_gts = gt_masks.shape[0] 25 | self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 26 | 27 | if gt_masks.numel() == 0: 28 | # hack for index error case 29 | assert self.pos_assigned_gt_inds.numel() == 0 30 | self.pos_gt_masks = torch.empty_like(gt_masks) 31 | else: 32 | self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :] 33 | 34 | if assign_result.labels is not None: 35 | self.pos_gt_labels = assign_result.labels[pos_inds] 36 | else: 37 | self.pos_gt_labels = None 38 | 39 | @property 40 | def masks(self): 41 | """torch.Tensor: concatenated positive and negative boxes""" 42 | return torch.cat([self.pos_masks, self.neg_masks]) 43 | 44 | def __nice__(self): 45 | data = self.info.copy() 46 | data["pos_masks"] = data.pop("pos_masks").shape 47 | data["neg_masks"] = data.pop("neg_masks").shape 48 | parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] 49 | body = " " + ",\n ".join(parts) 50 | return "{\n" + body + "\n}" 51 | 52 | @property 53 | def info(self): 54 | """Returns a dictionary of info about the object.""" 55 | return { 56 | "pos_inds": self.pos_inds, 57 | "neg_inds": self.neg_inds, 58 | "pos_masks": self.pos_masks, 59 | "neg_masks": self.neg_masks, 60 | "pos_is_gt": self.pos_is_gt, 61 | "num_gts": self.num_gts, 62 | "pos_assigned_gt_inds": self.pos_assigned_gt_inds, 63 | } 64 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .dist_utils import reduce_mean 7 | from .misc import add_prefix, multi_apply 8 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import torch.distributed as dist 7 | 8 | 9 | def reduce_mean(tensor): 10 | """ "Obtain the mean of tensor on different GPUs.""" 11 | if not (dist.is_available() and dist.is_initialized()): 12 | return tensor 13 | tensor = tensor.clone() 14 | dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM) 15 | return tensor 16 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/core/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from functools import partial 7 | 8 | 9 | def multi_apply(func, *args, **kwargs): 10 | """Apply function to a list of arguments. 11 | 12 | Note: 13 | This function applies the ``func`` to multiple inputs and 14 | map the multiple outputs of the ``func`` into different 15 | list. Each list contains the same type of outputs corresponding 16 | to different inputs. 17 | 18 | Args: 19 | func (Function): A function that will be applied to a list of 20 | arguments 21 | 22 | Returns: 23 | tuple(list): A tuple containing multiple list, each list contains \ 24 | a kind of returned results by the function 25 | """ 26 | pfunc = partial(func, **kwargs) if kwargs else func 27 | map_results = map(pfunc, *args) 28 | return tuple(map(list, zip(*map_results))) 29 | 30 | 31 | def add_prefix(inputs, prefix): 32 | """Add prefix for dict. 33 | 34 | Args: 35 | inputs (dict): The input dict with str keys. 36 | prefix (str): The prefix to add. 37 | 38 | Returns: 39 | 40 | dict: The dict with keys updated with ``prefix``. 41 | """ 42 | 43 | outputs = dict() 44 | for name, value in inputs.items(): 45 | outputs[f"{prefix}.{name}"] = value 46 | 47 | return outputs 48 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .backbones import * # noqa: F403 7 | from .builder import MASK_ASSIGNERS, MATCH_COST, TRANSFORMER, build_assigner, build_match_cost 8 | from .decode_heads import * # noqa: F403 9 | from .losses import * # noqa: F403 10 | from .plugins import * # noqa: F403 11 | from .segmentors import * # noqa: F403 12 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .vit_adapter import ViTAdapter 7 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
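# Quick sketch of the two helpers defined in misc.py above (inputs are hypothetical):
#
#   def square_and_cube(x):
#       return x * x, x ** 3
#
#   squares, cubes = multi_apply(square_and_cube, [1, 2, 3])
#   # squares == [1, 4, 9], cubes == [1, 8, 27]
#
#   add_prefix({"loss_mask": 0.5}, "decode")
#   # {'decode.loss_mask': 0.5}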
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 9 | 10 | from torch import nn 11 | 12 | 13 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 14 | if drop_prob == 0.0 or not training: 15 | return x 16 | keep_prob = 1 - drop_prob 17 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 18 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 19 | if keep_prob > 0.0: 20 | random_tensor.div_(keep_prob) 21 | return x * random_tensor 22 | 23 | 24 | class DropPath(nn.Module): 25 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 26 | 27 | def __init__(self, drop_prob: float = 0.0): 28 | super(DropPath, self).__init__() 29 | self.drop_prob = drop_prob 30 | 31 | def forward(self, x): 32 | return drop_path(x, self.drop_prob, self.training) 33 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/models/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from mmcv.utils import Registry 7 | 8 | TRANSFORMER = Registry("Transformer") 9 | MASK_ASSIGNERS = Registry("mask_assigner") 10 | MATCH_COST = Registry("match_cost") 11 | 12 | 13 | def build_match_cost(cfg): 14 | """Build Match Cost.""" 15 | return MATCH_COST.build(cfg) 16 | 17 | 18 | def build_assigner(cfg): 19 | """Build Assigner.""" 20 | return MASK_ASSIGNERS.build(cfg) 21 | 22 | 23 | def build_transformer(cfg): 24 | """Build Transformer.""" 25 | return TRANSFORMER.build(cfg) 26 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .mask2former_head import Mask2FormerHead 7 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .cross_entropy_loss import CrossEntropyLoss, binary_cross_entropy, cross_entropy, mask_cross_entropy 7 | from .dice_loss import DiceLoss 8 | from .match_costs import ClassificationCost, CrossEntropyLossCost, DiceCost 9 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/models/plugins/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
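# Expected behaviour of the DropPath module defined above (shapes and probability are hypothetical):
#
#   import torch
#   dp = DropPath(drop_prob=0.2)
#   x = torch.ones(4, 3)
#   dp.eval()
#   assert torch.equal(dp(x), x)   # identity at inference time
#   dp.train()
#   y = dp(x)                      # each row is either all zeros or all 1 / 0.8 = 1.25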
2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .msdeformattn_pixel_decoder import MSDeformAttnPixelDecoder 7 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .encoder_decoder_mask2former import EncoderDecoderMask2Former 7 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .assigner import MaskHungarianAssigner 7 | from .point_sample import get_uncertain_point_coords_with_randomness 8 | from .positional_encoding import LearnedPositionalEncoding, SinePositionalEncoding 9 | from .transformer import DetrTransformerDecoder, DetrTransformerDecoderLayer, DynamicConv, Transformer 10 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/models/utils/point_sample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | from mmcv.ops import point_sample 8 | 9 | 10 | def get_uncertainty(mask_pred, labels): 11 | """Estimate uncertainty based on pred logits. 12 | 13 | We estimate uncertainty as L1 distance between 0.0 and the logits 14 | prediction in 'mask_pred' for the foreground class in `classes`. 15 | 16 | Args: 17 | mask_pred (Tensor): mask predication logits, shape (num_rois, 18 | num_classes, mask_height, mask_width). 19 | 20 | labels (list[Tensor]): Either predicted or ground truth label for 21 | each predicted mask, of length num_rois. 22 | 23 | Returns: 24 | scores (Tensor): Uncertainty scores with the most uncertain 25 | locations having the highest uncertainty score, 26 | shape (num_rois, 1, mask_height, mask_width) 27 | """ 28 | if mask_pred.shape[1] == 1: 29 | gt_class_logits = mask_pred.clone() 30 | else: 31 | inds = torch.arange(mask_pred.shape[0], device=mask_pred.device) 32 | gt_class_logits = mask_pred[inds, labels].unsqueeze(1) 33 | return -torch.abs(gt_class_logits) 34 | 35 | 36 | def get_uncertain_point_coords_with_randomness( 37 | mask_pred, labels, num_points, oversample_ratio, importance_sample_ratio 38 | ): 39 | """Get ``num_points`` most uncertain points with random points during 40 | train. 41 | 42 | Sample points in [0, 1] x [0, 1] coordinate space based on their 43 | uncertainty. The uncertainties are calculated for each point using 44 | 'get_uncertainty()' function that takes point's logit prediction as 45 | input. 
46 | 47 | Args: 48 | mask_pred (Tensor): A tensor of shape (num_rois, num_classes, 49 | mask_height, mask_width) for class-specific or class-agnostic 50 | prediction. 51 | labels (list): The ground truth class for each instance. 52 | num_points (int): The number of points to sample. 53 | oversample_ratio (int): Oversampling parameter. 54 | importance_sample_ratio (float): Ratio of points that are sampled 55 | via importnace sampling. 56 | 57 | Returns: 58 | point_coords (Tensor): A tensor of shape (num_rois, num_points, 2) 59 | that contains the coordinates sampled points. 60 | """ 61 | assert oversample_ratio >= 1 62 | assert 0 <= importance_sample_ratio <= 1 63 | batch_size = mask_pred.shape[0] 64 | num_sampled = int(num_points * oversample_ratio) 65 | point_coords = torch.rand(batch_size, num_sampled, 2, device=mask_pred.device) 66 | point_logits = point_sample(mask_pred, point_coords) 67 | # It is crucial to calculate uncertainty based on the sampled 68 | # prediction value for the points. Calculating uncertainties of the 69 | # coarse predictions first and sampling them for points leads to 70 | # incorrect results. To illustrate this: assume uncertainty func( 71 | # logits)=-abs(logits), a sampled point between two coarse 72 | # predictions with -1 and 1 logits has 0 logits, and therefore 0 73 | # uncertainty value. However, if we calculate uncertainties for the 74 | # coarse predictions first, both will have -1 uncertainty, 75 | # and sampled point will get -1 uncertainty. 76 | point_uncertainties = get_uncertainty(point_logits, labels) 77 | num_uncertain_points = int(importance_sample_ratio * num_points) 78 | num_random_points = num_points - num_uncertain_points 79 | idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] 80 | shift = num_sampled * torch.arange(batch_size, dtype=torch.long, device=mask_pred.device) 81 | idx += shift[:, None] 82 | point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(batch_size, num_uncertain_points, 2) 83 | if num_random_points > 0: 84 | rand_roi_coords = torch.rand(batch_size, num_random_points, 2, device=mask_pred.device) 85 | point_coords = torch.cat((point_coords, rand_roi_coords), dim=1) 86 | return point_coords 87 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/segmentation_m2f/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/fundamentalvision/Deformable-DETR/tree/main/models/ops/modules 8 | # https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 9 | 10 | from .ms_deform_attn import MSDeformAttn 11 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/eval/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
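# Worked example of the sampling budget above (the numbers are hypothetical):
# with num_points=12, oversample_ratio=3 and importance_sample_ratio=0.75,
# 36 candidate points are drawn uniformly, the 9 most uncertain are kept via top-k,
# and the remaining 3 are re-drawn uniformly at random, so the returned
# point_coords tensor has shape (num_rois, 12, 2).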
5 | 6 | import argparse 7 | from typing import Any, List, Optional, Tuple 8 | 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | 12 | from model_pack.dinoV2.dinov2.models import build_model_from_cfg 13 | from model_pack.dinoV2.dinov2.utils.config import setup 14 | import model_pack.dinoV2.dinov2.utils.utils as dinov2_utils 15 | 16 | 17 | def get_args_parser( 18 | description: Optional[str] = None, 19 | parents: Optional[List[argparse.ArgumentParser]] = None, 20 | add_help: bool = True, 21 | ): 22 | parser = argparse.ArgumentParser( 23 | description=description, 24 | parents=parents or [], 25 | add_help=add_help, 26 | ) 27 | parser.add_argument( 28 | "--config-file", 29 | type=str, 30 | help="Model configuration file", 31 | ) 32 | parser.add_argument( 33 | "--pretrained-weights", 34 | type=str, 35 | help="Pretrained model weights", 36 | ) 37 | parser.add_argument( 38 | "--output-dir", 39 | default="", 40 | type=str, 41 | help="Output directory to write results and logs", 42 | ) 43 | parser.add_argument( 44 | "--opts", 45 | help="Extra configuration options", 46 | default=[], 47 | nargs="+", 48 | ) 49 | return parser 50 | 51 | 52 | def get_autocast_dtype(config): 53 | teacher_dtype_str = config.compute_precision.teacher.backbone.mixed_precision.param_dtype 54 | if teacher_dtype_str == "fp16": 55 | return torch.half 56 | elif teacher_dtype_str == "bf16": 57 | return torch.bfloat16 58 | else: 59 | return torch.float 60 | 61 | 62 | def build_model_for_eval(config, pretrained_weights): 63 | model, _ = build_model_from_cfg(config, only_teacher=True) 64 | # dinov2_utils.load_pretrained_weights(model, pretrained_weights, "teacher") 65 | # model.eval() 66 | # model.cuda() 67 | return model 68 | 69 | 70 | def setup_and_build_model(args) -> Tuple[Any, torch.dtype]: 71 | cudnn.benchmark = True 72 | config = setup(args) 73 | model = build_model_for_eval(config, args.pretrained_weights) 74 | autocast_dtype = get_autocast_dtype(config) 75 | return model, autocast_dtype 76 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/hub/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/hub/backbones.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
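# Usage sketch for the factory functions defined below (a minimal example; the exact
# output format of the backbone's forward pass depends on the ViT implementation):
#
#   import torch
#   model = dinov2_vitb14(pretrained=False)   # skip the LVD-142M weight download
#   x = torch.randn(1, 3, 518, 518)           # 518 = 37 * 14, a multiple of the patch size
#   feats = model.forward_features(x)         # dict of CLS and patch tokens in the reference ViT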
5 | 6 | from enum import Enum 7 | from typing import Union 8 | 9 | import torch 10 | 11 | from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name 12 | 13 | 14 | class Weights(Enum): 15 | LVD142M = "LVD142M" 16 | 17 | 18 | def _make_dinov2_model( 19 | *, 20 | arch_name: str = "vit_large", 21 | img_size: int = 518, 22 | patch_size: int = 14, 23 | init_values: float = 1.0, 24 | ffn_layer: str = "mlp", 25 | block_chunks: int = 0, 26 | num_register_tokens: int = 0, 27 | interpolate_antialias: bool = False, 28 | interpolate_offset: float = 0.1, 29 | pretrained: bool = True, 30 | weights: Union[Weights, str] = Weights.LVD142M, 31 | **kwargs, 32 | ): 33 | from ..models import vision_transformer as vits 34 | 35 | if isinstance(weights, str): 36 | try: 37 | weights = Weights[weights] 38 | except KeyError: 39 | raise AssertionError(f"Unsupported weights: {weights}") 40 | 41 | model_base_name = _make_dinov2_model_name(arch_name, patch_size) 42 | vit_kwargs = dict( 43 | img_size=img_size, 44 | patch_size=patch_size, 45 | init_values=init_values, 46 | ffn_layer=ffn_layer, 47 | block_chunks=block_chunks, 48 | num_register_tokens=num_register_tokens, 49 | interpolate_antialias=interpolate_antialias, 50 | interpolate_offset=interpolate_offset, 51 | ) 52 | vit_kwargs.update(**kwargs) 53 | model = vits.__dict__[arch_name](**vit_kwargs) 54 | 55 | if pretrained: 56 | model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) 57 | url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" 58 | state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") 59 | model.load_state_dict(state_dict, strict=True) 60 | 61 | return model 62 | 63 | 64 | def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 65 | """ 66 | DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. 67 | """ 68 | return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) 69 | 70 | 71 | def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 72 | """ 73 | DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. 74 | """ 75 | return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) 76 | 77 | 78 | def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 79 | """ 80 | DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. 81 | """ 82 | return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) 83 | 84 | 85 | def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 86 | """ 87 | DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. 88 | """ 89 | return _make_dinov2_model( 90 | arch_name="vit_giant2", 91 | ffn_layer="swiglufused", 92 | weights=weights, 93 | pretrained=pretrained, 94 | **kwargs, 95 | ) 96 | 97 | 98 | def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 99 | """ 100 | DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. 
101 | """ 102 | return _make_dinov2_model( 103 | arch_name="vit_small", 104 | pretrained=pretrained, 105 | weights=weights, 106 | num_register_tokens=4, 107 | interpolate_antialias=True, 108 | interpolate_offset=0.0, 109 | **kwargs, 110 | ) 111 | 112 | 113 | def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 114 | """ 115 | DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. 116 | """ 117 | return _make_dinov2_model( 118 | arch_name="vit_base", 119 | pretrained=pretrained, 120 | weights=weights, 121 | num_register_tokens=4, 122 | interpolate_antialias=True, 123 | interpolate_offset=0.0, 124 | **kwargs, 125 | ) 126 | 127 | 128 | def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 129 | """ 130 | DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. 131 | """ 132 | return _make_dinov2_model( 133 | arch_name="vit_large", 134 | pretrained=pretrained, 135 | weights=weights, 136 | num_register_tokens=4, 137 | interpolate_antialias=True, 138 | interpolate_offset=0.0, 139 | **kwargs, 140 | ) 141 | 142 | 143 | def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 144 | """ 145 | DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. 146 | """ 147 | return _make_dinov2_model( 148 | arch_name="vit_giant2", 149 | ffn_layer="swiglufused", 150 | weights=weights, 151 | pretrained=pretrained, 152 | num_register_tokens=4, 153 | interpolate_antialias=True, 154 | interpolate_offset=0.0, 155 | **kwargs, 156 | ) 157 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/hub/depth/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .decode_heads import BNHead, DPTHead 7 | from .encoder_decoder import DepthEncoderDecoder 8 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/hub/depth/ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import warnings 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False): 12 | if warning: 13 | if size is not None and align_corners: 14 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 15 | output_h, output_w = tuple(int(x) for x in size) 16 | if output_h > input_h or output_w > output_h: 17 | if ( 18 | (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) 19 | and (output_h - 1) % (input_h - 1) 20 | and (output_w - 1) % (input_w - 1) 21 | ): 22 | warnings.warn( 23 | f"When align_corners={align_corners}, " 24 | "the output would more aligned if " 25 | f"input size {(input_h, input_w)} is `x+1` and " 26 | f"out size {(output_h, output_w)} is `nx+1`" 27 | ) 28 | return F.interpolate(input, size, scale_factor, mode, align_corners) 29 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/hub/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import itertools 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" 15 | 16 | 17 | def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: 18 | compact_arch_name = arch_name.replace("_", "")[:4] 19 | registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" 20 | return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" 21 | 22 | 23 | class CenterPadding(nn.Module): 24 | def __init__(self, multiple): 25 | super().__init__() 26 | self.multiple = multiple 27 | 28 | def _get_pad(self, size): 29 | new_size = math.ceil(size / self.multiple) * self.multiple 30 | pad_size = new_size - size 31 | pad_size_left = pad_size // 2 32 | pad_size_right = pad_size - pad_size_left 33 | return pad_size_left, pad_size_right 34 | 35 | @torch.inference_mode() 36 | def forward(self, x): 37 | pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) 38 | output = F.pad(x, pads) 39 | return output 40 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .dino_head import DINOHead 7 | from .mlp import Mlp 8 | from .patch_embed import PatchEmbed 9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 10 | from .block import NestedTensorBlock 11 | from .attention import MemEffAttention 12 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
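# Small sketch of the two hub utilities defined above (input sizes are hypothetical):
#
#   _make_dinov2_model_name("vit_large", 14)      # -> "dinov2_vitl14"
#   _make_dinov2_model_name("vit_large", 14, 4)   # -> "dinov2_vitl14_reg4"
#
#   import torch
#   pad = CenterPadding(multiple=14)
#   pad(torch.randn(1, 3, 500, 600)).shape        # -> torch.Size([1, 3, 504, 602])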
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import logging 11 | import os 12 | import warnings 13 | 14 | from torch import Tensor 15 | from torch import nn 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 22 | try: 23 | if XFORMERS_ENABLED: 24 | from xformers.ops import memory_efficient_attention, unbind 25 | 26 | XFORMERS_AVAILABLE = True 27 | # warnings.warn("xFormers is available (Attention)") 28 | else: 29 | warnings.warn("xFormers is disabled (Attention)") 30 | raise ImportError 31 | except ImportError: 32 | XFORMERS_AVAILABLE = False 33 | warnings.warn("xFormers is not available (Attention)") 34 | 35 | 36 | class Attention(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | num_heads: int = 8, 41 | qkv_bias: bool = False, 42 | proj_bias: bool = True, 43 | attn_drop: float = 0.0, 44 | proj_drop: float = 0.0, 45 | ) -> None: 46 | super().__init__() 47 | self.num_heads = num_heads 48 | head_dim = dim // num_heads 49 | self.scale = head_dim**-0.5 50 | 51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | 56 | def forward(self, x: Tensor) -> Tensor: 57 | B, N, C = x.shape 58 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 59 | 60 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 61 | attn = q @ k.transpose(-2, -1) 62 | 63 | attn = attn.softmax(dim=-1) 64 | attn = self.attn_drop(attn) 65 | 66 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 67 | x = self.proj(x) 68 | x = self.proj_drop(x) 69 | return x 70 | 71 | 72 | class MemEffAttention(Attention): 73 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 74 | if not XFORMERS_AVAILABLE: 75 | if attn_bias is not None: 76 | raise AssertionError("xFormers is required for using nested tensors") 77 | return super().forward(x) 78 | 79 | B, N, C = x.shape 80 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 81 | 82 | q, k, v = unbind(qkv, 2) 83 | 84 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 85 | x = x.reshape([B, N, C]) 86 | 87 | x = self.proj(x) 88 | x = self.proj_drop(x) 89 | return x 90 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
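# Shape sketch for the Attention module above (sizes are hypothetical):
#
#   import torch
#   attn = Attention(dim=384, num_heads=6)
#   tokens = torch.randn(2, 197, 384)   # (batch, sequence, embed_dim)
#   out = attn(tokens)                  # same shape: (2, 197, 384)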
5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.init import trunc_normal_ 9 | from torch.nn.utils import weight_norm 10 | 11 | 12 | class DINOHead(nn.Module): 13 | def __init__( 14 | self, 15 | in_dim, 16 | out_dim, 17 | use_bn=False, 18 | nlayers=3, 19 | hidden_dim=2048, 20 | bottleneck_dim=256, 21 | mlp_bias=True, 22 | ): 23 | super().__init__() 24 | nlayers = max(nlayers, 1) 25 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 26 | self.apply(self._init_weights) 27 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 28 | self.last_layer.weight_g.data.fill_(1) 29 | 30 | def _init_weights(self, m): 31 | if isinstance(m, nn.Linear): 32 | trunc_normal_(m.weight, std=0.02) 33 | if isinstance(m, nn.Linear) and m.bias is not None: 34 | nn.init.constant_(m.bias, 0) 35 | 36 | def forward(self, x): 37 | x = self.mlp(x) 38 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 39 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 40 | x = self.last_layer(x) 41 | return x 42 | 43 | 44 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 45 | if nlayers == 1: 46 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 47 | else: 48 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 49 | if use_bn: 50 | layers.append(nn.BatchNorm1d(hidden_dim)) 51 | layers.append(nn.GELU()) 52 | for _ in range(nlayers - 2): 53 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 54 | if use_bn: 55 | layers.append(nn.BatchNorm1d(hidden_dim)) 56 | layers.append(nn.GELU()) 57 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 58 | return nn.Sequential(*layers) 59 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 9 | 10 | 11 | from torch import nn 12 | 13 | 14 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 15 | if drop_prob == 0.0 or not training: 16 | return x 17 | keep_prob = 1 - drop_prob 18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 20 | if keep_prob > 0.0: 21 | random_tensor.div_(keep_prob) 22 | output = x * random_tensor 23 | return output 24 | 25 | 26 | class DropPath(nn.Module): 27 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 28 | 29 | def __init__(self, drop_prob=None): 30 | super(DropPath, self).__init__() 31 | self.drop_prob = drop_prob 32 | 33 | def forward(self, x): 34 | return drop_path(x, self.drop_prob, self.training) 35 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 7 | 8 | from typing import Union 9 | 10 | import torch 11 | from torch import Tensor 12 | from torch import nn 13 | 14 | 15 | class LayerScale(nn.Module): 16 | def __init__( 17 | self, 18 | dim: int, 19 | init_values: Union[float, Tensor] = 1e-5, 20 | inplace: bool = False, 21 | ) -> None: 22 | super().__init__() 23 | self.inplace = inplace 24 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 25 | 26 | def forward(self, x: Tensor) -> Tensor: 27 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 28 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 9 | 10 | 11 | from typing import Callable, Optional 12 | 13 | from torch import Tensor, nn 14 | 15 | 16 | class Mlp(nn.Module): 17 | def __init__( 18 | self, 19 | in_features: int, 20 | hidden_features: Optional[int] = None, 21 | out_features: Optional[int] = None, 22 | act_layer: Callable[..., nn.Module] = nn.GELU, 23 | drop: float = 0.0, 24 | bias: bool = True, 25 | ) -> None: 26 | super().__init__() 27 | out_features = out_features or in_features 28 | hidden_features = hidden_features or in_features 29 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 30 | self.act = act_layer() 31 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 32 | self.drop = nn.Dropout(drop) 33 | 34 | def forward(self, x: Tensor) -> Tensor: 35 | x = self.fc1(x) 36 | x = self.act(x) 37 | x = self.drop(x) 38 | x = self.fc2(x) 39 | x = self.drop(x) 40 | return x 41 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | from typing import Callable, Optional, Tuple, Union 11 | 12 | from torch import Tensor 13 | import torch.nn as nn 14 | 15 | 16 | def make_2tuple(x): 17 | if isinstance(x, tuple): 18 | assert len(x) == 2 19 | return x 20 | 21 | assert isinstance(x, int) 22 | return (x, x) 23 | 24 | 25 | class PatchEmbed(nn.Module): 26 | """ 27 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 28 | 29 | Args: 30 | img_size: Image size. 31 | patch_size: Patch token size. 32 | in_chans: Number of input image channels. 33 | embed_dim: Number of linear projection output channels. 34 | norm_layer: Normalization layer. 
35 | """ 36 | 37 | def __init__( 38 | self, 39 | img_size: Union[int, Tuple[int, int]] = 224, 40 | patch_size: Union[int, Tuple[int, int]] = 16, 41 | in_chans: int = 3, 42 | embed_dim: int = 768, 43 | norm_layer: Optional[Callable] = None, 44 | flatten_embedding: bool = True, 45 | ) -> None: 46 | super().__init__() 47 | 48 | image_HW = make_2tuple(img_size) 49 | patch_HW = make_2tuple(patch_size) # (14,14); patch_size = 14 50 | 51 | patch_grid_size = ( 52 | image_HW[0] // patch_HW[0], 53 | image_HW[1] // patch_HW[1], 54 | ) 55 | patch_grid_size_ori = ( 56 | image_HW[0] // 14, 57 | image_HW[1] // 14, 58 | ) 59 | 60 | self.img_size = image_HW 61 | self.patch_size = patch_HW 62 | self.patches_resolution = patch_grid_size 63 | self.num_patches = patch_grid_size_ori[0] * patch_grid_size_ori[1] 64 | 65 | self.in_chans = in_chans 66 | self.embed_dim = embed_dim 67 | 68 | self.flatten_embedding = flatten_embedding 69 | 70 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 71 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 72 | 73 | def forward(self, x: Tensor) -> Tensor: 74 | _, _, H, W = x.shape 75 | patch_H, patch_W = self.patch_size 76 | 77 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 78 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 79 | 80 | x = self.proj(x) # B C H W 81 | H, W = x.size(2), x.size(3) 82 | x = x.flatten(2).transpose(1, 2) # B HW C 83 | x = self.norm(x) 84 | if not self.flatten_embedding: 85 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 86 | return x 87 | 88 | def flops(self) -> float: 89 | Ho, Wo = self.patches_resolution 90 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 91 | if self.norm is not None: 92 | flops += Ho * Wo * self.embed_dim 93 | return flops 94 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import os 7 | from typing import Callable, Optional 8 | import warnings 9 | 10 | from torch import Tensor, nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class SwiGLUFFN(nn.Module): 15 | def __init__( 16 | self, 17 | in_features: int, 18 | hidden_features: Optional[int] = None, 19 | out_features: Optional[int] = None, 20 | act_layer: Callable[..., nn.Module] = None, 21 | drop: float = 0.0, 22 | bias: bool = True, 23 | ) -> None: 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 28 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 29 | 30 | def forward(self, x: Tensor) -> Tensor: 31 | x12 = self.w12(x) 32 | x1, x2 = x12.chunk(2, dim=-1) 33 | hidden = F.silu(x1) * x2 34 | return self.w3(hidden) 35 | 36 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 37 | try: 38 | if XFORMERS_ENABLED: 39 | from xformers.ops import SwiGLU 40 | 41 | XFORMERS_AVAILABLE = True 42 | # warnings.warn("xFormers is available (SwiGLU)") 43 | else: 44 | warnings.warn("xFormers is disabled (SwiGLU)") 45 | raise ImportError 46 | except ImportError: 47 | SwiGLU = SwiGLUFFN 48 | XFORMERS_AVAILABLE = False 49 | 50 | warnings.warn("xFormers is not available (SwiGLU)") 51 | 52 | 53 | class SwiGLUFFNFused(SwiGLU): 54 | def __init__( 55 | self, 56 | in_features: int, 57 | hidden_features: Optional[int] = None, 58 | out_features: Optional[int] = None, 59 | act_layer: Callable[..., nn.Module] = None, 60 | drop: float = 0.0, 61 | bias: bool = True, 62 | ) -> None: 63 | out_features = out_features or in_features 64 | hidden_features = hidden_features or in_features 65 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 66 | super().__init__( 67 | in_features=in_features, 68 | hidden_features=hidden_features, 69 | out_features=out_features, 70 | bias=bias, 71 | ) 72 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/logging/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import functools 7 | import logging 8 | import os 9 | import sys 10 | from typing import Optional 11 | 12 | import model_pack.dinoV2.dinov2.distributed as distributed 13 | from .helpers import MetricLogger, SmoothedValue 14 | 15 | 16 | # So that calling _configure_logger multiple times won't add many handlers 17 | @functools.lru_cache() 18 | def _configure_logger( 19 | name: Optional[str] = None, 20 | *, 21 | level: int = logging.DEBUG, 22 | output: Optional[str] = None, 23 | ): 24 | """ 25 | Configure a logger. 26 | 27 | Adapted from Detectron2. 28 | 29 | Args: 30 | name: The name of the logger to configure. 31 | level: The logging level to use. 32 | output: A file name or a directory to save log. If None, will not save log file. 33 | If ends with ".txt" or ".log", assumed to be a file name. 34 | Otherwise, logs will be saved to `output/log.txt`. 35 | 36 | Returns: 37 | The configured logger. 
38 | """ 39 | 40 | logger = logging.getLogger(name) 41 | logger.setLevel(level) 42 | logger.propagate = False 43 | 44 | # Loosely match Google glog format: 45 | # [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg 46 | # but use a shorter timestamp and include the logger name: 47 | # [IWEF]yyyymmdd hh:mm:ss logger threadid file:line] msg 48 | fmt_prefix = "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] " 49 | fmt_message = "%(message)s" 50 | fmt = fmt_prefix + fmt_message 51 | datefmt = "%Y%m%d %H:%M:%S" 52 | formatter = logging.Formatter(fmt=fmt, datefmt=datefmt) 53 | 54 | # stdout logging for main worker only 55 | if distributed.is_main_process(): 56 | handler = logging.StreamHandler(stream=sys.stdout) 57 | handler.setLevel(logging.DEBUG) 58 | handler.setFormatter(formatter) 59 | logger.addHandler(handler) 60 | 61 | # file logging for all workers 62 | if output: 63 | if os.path.splitext(output)[-1] in (".txt", ".log"): 64 | filename = output 65 | else: 66 | filename = os.path.join(output, "logs", "log.txt") 67 | 68 | if not distributed.is_main_process(): 69 | global_rank = distributed.get_global_rank() 70 | filename = filename + ".rank{}".format(global_rank) 71 | 72 | os.makedirs(os.path.dirname(filename), exist_ok=True) 73 | 74 | handler = logging.StreamHandler(open(filename, "a")) 75 | handler.setLevel(logging.DEBUG) 76 | handler.setFormatter(formatter) 77 | logger.addHandler(handler) 78 | 79 | return logger 80 | 81 | 82 | def setup_logging( 83 | output: Optional[str] = None, 84 | *, 85 | name: Optional[str] = None, 86 | level: int = logging.DEBUG, 87 | capture_warnings: bool = True, 88 | ) -> None: 89 | """ 90 | Setup logging. 91 | 92 | Args: 93 | output: A file name or a directory to save log files. If None, log 94 | files will not be saved. If output ends with ".txt" or ".log", it 95 | is assumed to be a file name. 96 | Otherwise, logs will be saved to `output/log.txt`. 97 | name: The name of the logger to configure, by default the root logger. 98 | level: The logging level to use. 99 | capture_warnings: Whether warnings should be captured as logs. 100 | """ 101 | logging.captureWarnings(capture_warnings) 102 | _configure_logger(name, level=level, output=output) 103 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/loss/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .dino_clstoken_loss import DINOLoss 7 | from .ibot_patch_loss import iBOTPatchLoss 8 | from .koleo_loss import KoLeoLoss 9 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/loss/dino_clstoken_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import torch 7 | import torch.distributed as dist 8 | import torch.nn.functional as F 9 | from torch import nn 10 | 11 | 12 | class DINOLoss(nn.Module): 13 | def __init__( 14 | self, 15 | out_dim, 16 | student_temp=0.1, 17 | center_momentum=0.9, 18 | ): 19 | super().__init__() 20 | self.student_temp = student_temp 21 | self.center_momentum = center_momentum 22 | self.register_buffer("center", torch.zeros(1, out_dim)) 23 | self.updated = True 24 | self.reduce_handle = None 25 | self.len_teacher_output = None 26 | self.async_batch_center = None 27 | 28 | @torch.no_grad() 29 | def softmax_center_teacher(self, teacher_output, teacher_temp): 30 | self.apply_center_update() 31 | # teacher centering and sharpening 32 | return F.softmax((teacher_output - self.center) / teacher_temp, dim=-1) 33 | 34 | @torch.no_grad() 35 | def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3): 36 | teacher_output = teacher_output.float() 37 | world_size = dist.get_world_size() if dist.is_initialized() else 1 38 | Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper 39 | B = Q.shape[1] * world_size # number of samples to assign 40 | K = Q.shape[0] # how many prototypes 41 | 42 | # make the matrix sums to 1 43 | sum_Q = torch.sum(Q) 44 | if dist.is_initialized(): 45 | dist.all_reduce(sum_Q) 46 | Q /= sum_Q 47 | 48 | for it in range(n_iterations): 49 | # normalize each row: total weight per prototype must be 1/K 50 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True) 51 | if dist.is_initialized(): 52 | dist.all_reduce(sum_of_rows) 53 | Q /= sum_of_rows 54 | Q /= K 55 | 56 | # normalize each column: total weight per sample must be 1/B 57 | Q /= torch.sum(Q, dim=0, keepdim=True) 58 | Q /= B 59 | 60 | Q *= B # the columns must sum to 1 so that Q is an assignment 61 | return Q.t() 62 | 63 | def forward(self, student_output_list, teacher_out_softmaxed_centered_list): 64 | """ 65 | Cross-entropy between softmax outputs of the teacher and student networks. 66 | """ 67 | # TODO: Use cross_entropy_distribution here 68 | total_loss = 0 69 | for s in student_output_list: 70 | lsm = F.log_softmax(s / self.student_temp, dim=-1) 71 | for t in teacher_out_softmaxed_centered_list: 72 | loss = torch.sum(t * lsm, dim=-1) 73 | total_loss -= loss.mean() 74 | return total_loss 75 | 76 | @torch.no_grad() 77 | def update_center(self, teacher_output): 78 | self.reduce_center_update(teacher_output) 79 | 80 | @torch.no_grad() 81 | def reduce_center_update(self, teacher_output): 82 | self.updated = False 83 | self.len_teacher_output = len(teacher_output) 84 | self.async_batch_center = torch.sum(teacher_output, dim=0, keepdim=True) 85 | if dist.is_initialized(): 86 | self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True) 87 | 88 | @torch.no_grad() 89 | def apply_center_update(self): 90 | if self.updated is False: 91 | world_size = dist.get_world_size() if dist.is_initialized() else 1 92 | 93 | if self.reduce_handle is not None: 94 | self.reduce_handle.wait() 95 | _t = self.async_batch_center / (self.len_teacher_output * world_size) 96 | 97 | self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) 98 | 99 | self.updated = True 100 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/loss/koleo_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
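# Minimal single-process sketch of how DINOLoss above is typically driven
# (batch size, out_dim and temperature are hypothetical):
#
#   import torch
#   dino_loss = DINOLoss(out_dim=65536)
#   student_out = torch.randn(8, 65536)   # one global crop, batch of 8
#   teacher_out = torch.randn(8, 65536)
#   teacher_sm = dino_loss.softmax_center_teacher(teacher_out, teacher_temp=0.07)
#   loss = dino_loss([student_out], [teacher_sm])
#   dino_loss.update_center(teacher_out)  # EMA update of the centering buffer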
2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | # import torch.distributed as dist 13 | 14 | 15 | logger = logging.getLogger("dinov2") 16 | 17 | 18 | class KoLeoLoss(nn.Module): 19 | """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search""" 20 | 21 | def __init__(self): 22 | super().__init__() 23 | self.pdist = nn.PairwiseDistance(2, eps=1e-8) 24 | 25 | def pairwise_NNs_inner(self, x): 26 | """ 27 | Pairwise nearest neighbors for L2-normalized vectors. 28 | Uses Torch rather than Faiss to remain on GPU. 29 | """ 30 | # parwise dot products (= inverse distance) 31 | dots = torch.mm(x, x.t()) 32 | n = x.shape[0] 33 | dots.view(-1)[:: (n + 1)].fill_(-1) # Trick to fill diagonal with -1 34 | # max inner prod -> min distance 35 | _, I = torch.max(dots, dim=1) # noqa: E741 36 | return I 37 | 38 | def forward(self, student_output, eps=1e-8): 39 | """ 40 | Args: 41 | student_output (BxD): backbone output of student 42 | """ 43 | with torch.cuda.amp.autocast(enabled=False): 44 | student_output = F.normalize(student_output, eps=eps, p=2, dim=-1) 45 | I = self.pairwise_NNs_inner(student_output) # noqa: E741 46 | distances = self.pdist(student_output, student_output[I]) # BxD, BxD -> B 47 | loss = -torch.log(distances + eps).mean() 48 | return loss 49 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | 8 | from . import vision_transformer as vits 9 | 10 | 11 | logger = logging.getLogger("dinov2") 12 | 13 | 14 | def build_model(args, only_teacher=False, img_size=224): 15 | args.arch = args.arch.removesuffix("_memeff") 16 | if "vit" in args.arch: 17 | vit_kwargs = dict( 18 | img_size=img_size, 19 | patch_size=args.patch_size, 20 | init_values=args.layerscale, 21 | ffn_layer=args.ffn_layer, 22 | block_chunks=args.block_chunks, 23 | qkv_bias=args.qkv_bias, 24 | proj_bias=args.proj_bias, 25 | ffn_bias=args.ffn_bias, 26 | num_register_tokens=args.num_register_tokens, 27 | interpolate_offset=args.interpolate_offset, 28 | interpolate_antialias=args.interpolate_antialias, 29 | ) 30 | teacher = vits.__dict__[args.arch](**vit_kwargs) 31 | if only_teacher: 32 | return teacher, teacher.embed_dim 33 | student = vits.__dict__[args.arch]( 34 | **vit_kwargs, 35 | drop_path_rate=args.drop_path_rate, 36 | drop_path_uniform=args.drop_path_uniform, 37 | ) 38 | embed_dim = student.embed_dim 39 | return student, teacher, embed_dim 40 | 41 | 42 | def build_model_from_cfg(cfg, only_teacher=False): 43 | return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) 44 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/run/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
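# Minimal sketch of the KoLeo regularizer above (batch size and dimension are hypothetical):
#
#   import torch
#   koleo = KoLeoLoss()
#   embeddings = torch.randn(16, 256)   # backbone outputs for one batch
#   reg = koleo(embeddings)             # scalar; penalizes embeddings that collapse together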
2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/run/eval/knn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | import os 8 | import sys 9 | 10 | from dinov2.eval.knn import get_args_parser as get_knn_args_parser 11 | from dinov2.logging import setup_logging 12 | from dinov2.run.submit import get_args_parser, submit_jobs 13 | 14 | 15 | logger = logging.getLogger("dinov2") 16 | 17 | 18 | class Evaluator: 19 | def __init__(self, args): 20 | self.args = args 21 | 22 | def __call__(self): 23 | from dinov2.eval.knn import main as knn_main 24 | 25 | self._setup_args() 26 | knn_main(self.args) 27 | 28 | def checkpoint(self): 29 | import submitit 30 | 31 | logger.info(f"Requeuing {self.args}") 32 | empty = type(self)(self.args) 33 | return submitit.helpers.DelayedSubmission(empty) 34 | 35 | def _setup_args(self): 36 | import submitit 37 | 38 | job_env = submitit.JobEnvironment() 39 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) 40 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 41 | logger.info(f"Args: {self.args}") 42 | 43 | 44 | def main(): 45 | description = "Submitit launcher for DINOv2 k-NN evaluation" 46 | knn_args_parser = get_knn_args_parser(add_help=False) 47 | parents = [knn_args_parser] 48 | args_parser = get_args_parser(description=description, parents=parents) 49 | args = args_parser.parse_args() 50 | 51 | setup_logging() 52 | 53 | assert os.path.exists(args.config_file), "Configuration file does not exist!" 54 | submit_jobs(Evaluator, args, name="dinov2:knn") 55 | return 0 56 | 57 | 58 | if __name__ == "__main__": 59 | sys.exit(main()) 60 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/run/eval/linear.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
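# These run/eval entry points are thin submitit launchers; a typical invocation looks
# roughly like the following (paths are illustrative and dataset arguments are omitted,
# so treat this as a sketch rather than a verified command):
#
#   python -m dinov2.run.eval.knn \
#       --config-file dinov2/configs/eval/vitl14_pretrain.yaml \
#       --pretrained-weights /path/to/dinov2_vitl14_pretrain.pth \
#       --output-dir ./eval/knn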
5 | 6 | import logging 7 | import os 8 | import sys 9 | 10 | from dinov2.eval.linear import get_args_parser as get_linear_args_parser 11 | from dinov2.logging import setup_logging 12 | from dinov2.run.submit import get_args_parser, submit_jobs 13 | 14 | 15 | logger = logging.getLogger("dinov2") 16 | 17 | 18 | class Evaluator: 19 | def __init__(self, args): 20 | self.args = args 21 | 22 | def __call__(self): 23 | from dinov2.eval.linear import main as linear_main 24 | 25 | self._setup_args() 26 | linear_main(self.args) 27 | 28 | def checkpoint(self): 29 | import submitit 30 | 31 | logger.info(f"Requeuing {self.args}") 32 | empty = type(self)(self.args) 33 | return submitit.helpers.DelayedSubmission(empty) 34 | 35 | def _setup_args(self): 36 | import submitit 37 | 38 | job_env = submitit.JobEnvironment() 39 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) 40 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 41 | logger.info(f"Args: {self.args}") 42 | 43 | 44 | def main(): 45 | description = "Submitit launcher for DINOv2 linear evaluation" 46 | linear_args_parser = get_linear_args_parser(add_help=False) 47 | parents = [linear_args_parser] 48 | args_parser = get_args_parser(description=description, parents=parents) 49 | args = args_parser.parse_args() 50 | 51 | setup_logging() 52 | 53 | assert os.path.exists(args.config_file), "Configuration file does not exist!" 54 | submit_jobs(Evaluator, args, name="dinov2:linear") 55 | return 0 56 | 57 | 58 | if __name__ == "__main__": 59 | sys.exit(main()) 60 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/run/eval/log_regression.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | import os 8 | import sys 9 | 10 | from dinov2.eval.log_regression import get_args_parser as get_log_regression_args_parser 11 | from dinov2.logging import setup_logging 12 | from dinov2.run.submit import get_args_parser, submit_jobs 13 | 14 | 15 | logger = logging.getLogger("dinov2") 16 | 17 | 18 | class Evaluator: 19 | def __init__(self, args): 20 | self.args = args 21 | 22 | def __call__(self): 23 | from dinov2.eval.log_regression import main as log_regression_main 24 | 25 | self._setup_args() 26 | log_regression_main(self.args) 27 | 28 | def checkpoint(self): 29 | import submitit 30 | 31 | logger.info(f"Requeuing {self.args}") 32 | empty = type(self)(self.args) 33 | return submitit.helpers.DelayedSubmission(empty) 34 | 35 | def _setup_args(self): 36 | import submitit 37 | 38 | job_env = submitit.JobEnvironment() 39 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) 40 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 41 | logger.info(f"Args: {self.args}") 42 | 43 | 44 | def main(): 45 | description = "Submitit launcher for DINOv2 logistic evaluation" 46 | log_regression_args_parser = get_log_regression_args_parser(add_help=False) 47 | parents = [log_regression_args_parser] 48 | args_parser = get_args_parser(description=description, parents=parents) 49 | args = args_parser.parse_args() 50 | 51 | setup_logging() 52 | 53 | assert os.path.exists(args.config_file), "Configuration file does not exist!" 
54 | submit_jobs(Evaluator, args, name="dinov2:logreg") 55 | return 0 56 | 57 | 58 | if __name__ == "__main__": 59 | sys.exit(main()) 60 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/run/submit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import argparse 7 | import logging 8 | import os 9 | from pathlib import Path 10 | from typing import List, Optional 11 | 12 | import submitit 13 | 14 | from dinov2.utils.cluster import ( 15 | get_slurm_executor_parameters, 16 | get_slurm_partition, 17 | get_user_checkpoint_path, 18 | ) 19 | 20 | 21 | logger = logging.getLogger("dinov2") 22 | 23 | 24 | def get_args_parser( 25 | description: Optional[str] = None, 26 | parents: Optional[List[argparse.ArgumentParser]] = None, 27 | add_help: bool = True, 28 | ) -> argparse.ArgumentParser: 29 | parents = parents or [] 30 | slurm_partition = get_slurm_partition() 31 | parser = argparse.ArgumentParser( 32 | description=description, 33 | parents=parents, 34 | add_help=add_help, 35 | ) 36 | parser.add_argument( 37 | "--ngpus", 38 | "--gpus", 39 | "--gpus-per-node", 40 | default=8, 41 | type=int, 42 | help="Number of GPUs to request on each node", 43 | ) 44 | parser.add_argument( 45 | "--nodes", 46 | "--nnodes", 47 | default=1, 48 | type=int, 49 | help="Number of nodes to request", 50 | ) 51 | parser.add_argument( 52 | "--timeout", 53 | default=2800, 54 | type=int, 55 | help="Duration of the job", 56 | ) 57 | parser.add_argument( 58 | "--partition", 59 | default=slurm_partition, 60 | type=str, 61 | help="Partition where to submit", 62 | ) 63 | parser.add_argument( 64 | "--use-volta32", 65 | action="store_true", 66 | help="Request V100-32GB GPUs", 67 | ) 68 | parser.add_argument( 69 | "--comment", 70 | default="", 71 | type=str, 72 | help="Comment to pass to scheduler, e.g. 
priority message", 73 | ) 74 | parser.add_argument( 75 | "--exclude", 76 | default="", 77 | type=str, 78 | help="Nodes to exclude", 79 | ) 80 | return parser 81 | 82 | 83 | def get_shared_folder() -> Path: 84 | user_checkpoint_path = get_user_checkpoint_path() 85 | if user_checkpoint_path is None: 86 | raise RuntimeError("Path to user checkpoint cannot be determined") 87 | path = user_checkpoint_path / "experiments" 88 | path.mkdir(exist_ok=True) 89 | return path 90 | 91 | 92 | def submit_jobs(task_class, args, name: str): 93 | if not args.output_dir: 94 | args.output_dir = str(get_shared_folder() / "%j") 95 | 96 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 97 | executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30) 98 | 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs["slurm_constraint"] = "volta32gb" 102 | if args.comment: 103 | kwargs["slurm_comment"] = args.comment 104 | if args.exclude: 105 | kwargs["slurm_exclude"] = args.exclude 106 | 107 | executor_params = get_slurm_executor_parameters( 108 | nodes=args.nodes, 109 | num_gpus_per_node=args.ngpus, 110 | timeout_min=args.timeout, # max is 60 * 72 111 | slurm_signal_delay_s=120, 112 | slurm_partition=args.partition, 113 | **kwargs, 114 | ) 115 | executor.update_parameters(name=name, **executor_params) 116 | 117 | task = task_class(args) 118 | job = executor.submit(task) 119 | 120 | logger.info(f"Submitted job_id: {job.job_id}") 121 | str_output_dir = os.path.abspath(args.output_dir).replace("%j", str(job.job_id)) 122 | logger.info(f"Logs and checkpoints will be saved at: {str_output_dir}") 123 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/run/train/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | import os 8 | import sys 9 | 10 | from dinov2.logging import setup_logging 11 | from dinov2.train import get_args_parser as get_train_args_parser 12 | from dinov2.run.submit import get_args_parser, submit_jobs 13 | 14 | 15 | logger = logging.getLogger("dinov2") 16 | 17 | 18 | class Trainer(object): 19 | def __init__(self, args): 20 | self.args = args 21 | 22 | def __call__(self): 23 | from dinov2.train import main as train_main 24 | 25 | self._setup_args() 26 | train_main(self.args) 27 | 28 | def checkpoint(self): 29 | import submitit 30 | 31 | logger.info(f"Requeuing {self.args}") 32 | empty = type(self)(self.args) 33 | return submitit.helpers.DelayedSubmission(empty) 34 | 35 | def _setup_args(self): 36 | import submitit 37 | 38 | job_env = submitit.JobEnvironment() 39 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) 40 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 41 | logger.info(f"Args: {self.args}") 42 | 43 | 44 | def main(): 45 | description = "Submitit launcher for DINOv2 training" 46 | train_args_parser = get_train_args_parser(add_help=False) 47 | parents = [train_args_parser] 48 | args_parser = get_args_parser(description=description, parents=parents) 49 | args = args_parser.parse_args() 50 | 51 | setup_logging() 52 | 53 | assert os.path.exists(args.config_file), "Configuration file does not exist!" 
54 | submit_jobs(Trainer, args, name="dinov2:train") 55 | return 0 56 | 57 | 58 | if __name__ == "__main__": 59 | sys.exit(main()) 60 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/train/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .train import get_args_parser, main 7 | from .ssl_meta_arch import SSLMetaArch 8 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/utils/cluster.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from enum import Enum 7 | import os 8 | from pathlib import Path 9 | from typing import Any, Dict, Optional 10 | 11 | 12 | class ClusterType(Enum): 13 | AWS = "aws" 14 | FAIR = "fair" 15 | RSC = "rsc" 16 | 17 | 18 | def _guess_cluster_type() -> ClusterType: 19 | uname = os.uname() 20 | if uname.sysname == "Linux": 21 | if uname.release.endswith("-aws"): 22 | # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" 23 | return ClusterType.AWS 24 | elif uname.nodename.startswith("rsc"): 25 | # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" 26 | return ClusterType.RSC 27 | 28 | return ClusterType.FAIR 29 | 30 | 31 | def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]: 32 | if cluster_type is None: 33 | return _guess_cluster_type() 34 | 35 | return cluster_type 36 | 37 | 38 | def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: 39 | cluster_type = get_cluster_type(cluster_type) 40 | if cluster_type is None: 41 | return None 42 | 43 | CHECKPOINT_DIRNAMES = { 44 | ClusterType.AWS: "checkpoints", 45 | ClusterType.FAIR: "checkpoint", 46 | ClusterType.RSC: "checkpoint/dino", 47 | } 48 | return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] 49 | 50 | 51 | def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: 52 | checkpoint_path = get_checkpoint_path(cluster_type) 53 | if checkpoint_path is None: 54 | return None 55 | 56 | username = os.environ.get("USER") 57 | assert username is not None 58 | return checkpoint_path / username 59 | 60 | 61 | def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: 62 | cluster_type = get_cluster_type(cluster_type) 63 | if cluster_type is None: 64 | return None 65 | 66 | SLURM_PARTITIONS = { 67 | ClusterType.AWS: "learnlab", 68 | ClusterType.FAIR: "learnlab", 69 | ClusterType.RSC: "learn", 70 | } 71 | return SLURM_PARTITIONS[cluster_type] 72 | 73 | 74 | def get_slurm_executor_parameters( 75 | nodes: int, 
num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs 76 | ) -> Dict[str, Any]: 77 | # create default parameters 78 | params = { 79 | "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html 80 | "gpus_per_node": num_gpus_per_node, 81 | "tasks_per_node": num_gpus_per_node, # one task per GPU 82 | "cpus_per_task": 10, 83 | "nodes": nodes, 84 | "slurm_partition": get_slurm_partition(cluster_type), 85 | } 86 | # apply cluster-specific adjustments 87 | cluster_type = get_cluster_type(cluster_type) 88 | if cluster_type == ClusterType.AWS: 89 | params["cpus_per_task"] = 12 90 | del params["mem_gb"] 91 | elif cluster_type == ClusterType.RSC: 92 | params["cpus_per_task"] = 12 93 | # set additional parameters / apply overrides 94 | params.update(kwargs) 95 | return params 96 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import math 7 | import logging 8 | import os 9 | 10 | from omegaconf import OmegaConf 11 | 12 | import model_pack.dinoV2.dinov2.distributed as distributed 13 | from model_pack.dinoV2.dinov2.logging import setup_logging 14 | # from model_pack.dinoV2.dinov2.utils import utils 15 | from model_pack.dinoV2.dinov2.configs import dinov2_default_config 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | def apply_scaling_rules_to_cfg(cfg): # to fix 22 | if cfg.optim.scaling_rule == "sqrt_wrt_1024": 23 | base_lr = cfg.optim.base_lr 24 | cfg.optim.lr = base_lr 25 | cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) 26 | # logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") 27 | else: 28 | raise NotImplementedError 29 | return cfg 30 | 31 | 32 | def write_config(cfg, output_dir, name="config.yaml"): 33 | # logger.info(OmegaConf.to_yaml(cfg)) 34 | saved_cfg_path = os.path.join(output_dir, name) 35 | with open(saved_cfg_path, "w") as f: 36 | OmegaConf.save(config=cfg, f=f) 37 | return saved_cfg_path 38 | 39 | 40 | def get_cfg_from_args(args): 41 | args.output_dir = os.path.abspath(args.output_dir) 42 | args.opts += [f"train.output_dir={args.output_dir}"] 43 | default_cfg = OmegaConf.create(dinov2_default_config) 44 | cfg = OmegaConf.load(args.config_file) 45 | cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) 46 | return cfg 47 | 48 | 49 | def default_setup(args): 50 | # distributed.enable(overwrite=True) 51 | # seed = getattr(args, "seed", 0) 52 | # rank = distributed.get_global_rank() 53 | 54 | global logger 55 | setup_logging(output=args.output_dir, level=logging.INFO) 56 | logger = logging.getLogger("dinov2") 57 | 58 | # utils.fix_random_seeds(seed + rank) 59 | # logger.info("git:\n {}\n".format(utils.get_sha())) 60 | # logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 61 | 62 | def setup(args): 63 | """ 64 | Create configs and perform basic setups. 
65 | """ 66 | cfg = get_cfg_from_args(args) 67 | os.makedirs(args.output_dir, exist_ok=True) 68 | default_setup(args) 69 | apply_scaling_rules_to_cfg(cfg) 70 | write_config(cfg, args.output_dir) 71 | return cfg 72 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/utils/dtype.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | 7 | from typing import Dict, Union 8 | 9 | import numpy as np 10 | import torch 11 | 12 | 13 | TypeSpec = Union[str, np.dtype, torch.dtype] 14 | 15 | 16 | _NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { 17 | np.dtype("bool"): torch.bool, 18 | np.dtype("uint8"): torch.uint8, 19 | np.dtype("int8"): torch.int8, 20 | np.dtype("int16"): torch.int16, 21 | np.dtype("int32"): torch.int32, 22 | np.dtype("int64"): torch.int64, 23 | np.dtype("float16"): torch.float16, 24 | np.dtype("float32"): torch.float32, 25 | np.dtype("float64"): torch.float64, 26 | np.dtype("complex64"): torch.complex64, 27 | np.dtype("complex128"): torch.complex128, 28 | } 29 | 30 | 31 | def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: 32 | if isinstance(dtype, torch.dtype): 33 | return dtype 34 | if isinstance(dtype, str): 35 | dtype = np.dtype(dtype) 36 | assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}" 37 | return _NUMPY_TO_TORCH_DTYPE[dtype] 38 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/utils/param_groups.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from collections import defaultdict 7 | import logging 8 | 9 | 10 | logger = logging.getLogger("dinov2") 11 | 12 | 13 | def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): 14 | """ 15 | Calculate lr decay rate for different ViT blocks. 16 | Args: 17 | name (string): parameter name. 18 | lr_decay_rate (float): base lr decay rate. 19 | num_layers (int): number of ViT blocks. 20 | Returns: 21 | lr decay rate for the given parameter. 22 | """ 23 | layer_id = num_layers + 1 24 | if name.startswith("backbone") or force_is_backbone: 25 | if ( 26 | ".pos_embed" in name 27 | or ".patch_embed" in name 28 | or ".mask_token" in name 29 | or ".cls_token" in name 30 | or ".register_tokens" in name 31 | ): 32 | layer_id = 0 33 | elif force_is_backbone and ( 34 | "pos_embed" in name 35 | or "patch_embed" in name 36 | or "mask_token" in name 37 | or "cls_token" in name 38 | or "register_tokens" in name 39 | ): 40 | layer_id = 0 41 | elif ".blocks." in name and ".residual." not in name: 42 | layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 43 | elif chunked_blocks and "blocks." in name and "residual." not in name: 44 | layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 45 | elif "blocks." in name and "residual." 
not in name: 46 | layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 47 | 48 | return lr_decay_rate ** (num_layers + 1 - layer_id) 49 | 50 | 51 | def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): 52 | chunked_blocks = False 53 | if hasattr(model, "n_blocks"): 54 | logger.info("chunked fsdp") 55 | n_blocks = model.n_blocks 56 | chunked_blocks = model.chunked_blocks 57 | elif hasattr(model, "blocks"): 58 | logger.info("first code branch") 59 | n_blocks = len(model.blocks) 60 | elif hasattr(model, "backbone"): 61 | logger.info("second code branch") 62 | n_blocks = len(model.backbone.blocks) 63 | else: 64 | logger.info("else code branch") 65 | n_blocks = 0 66 | all_param_groups = [] 67 | 68 | for name, param in model.named_parameters(): 69 | name = name.replace("_fsdp_wrapped_module.", "") 70 | if not param.requires_grad: 71 | continue 72 | decay_rate = get_vit_lr_decay_rate( 73 | name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks 74 | ) 75 | d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} 76 | 77 | if "last_layer" in name: 78 | d.update({"is_last_layer": True}) 79 | 80 | if name.endswith(".bias") or "norm" in name or "gamma" in name: 81 | d.update({"wd_multiplier": 0.0}) 82 | 83 | if "patch_embed" in name: 84 | d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) 85 | 86 | all_param_groups.append(d) 87 | logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") 88 | 89 | return all_param_groups 90 | 91 | 92 | def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): 93 | fused_params_groups = defaultdict(lambda: {"params": []}) 94 | for d in all_params_groups: 95 | identifier = "" 96 | for k in keys: 97 | identifier += k + str(d[k]) + "_" 98 | 99 | for k in keys: 100 | fused_params_groups[identifier][k] = d[k] 101 | fused_params_groups[identifier]["params"].append(d["params"]) 102 | 103 | return fused_params_groups.values() 104 | -------------------------------------------------------------------------------- /model_pack/dinoV2/dinov2/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import logging 7 | import os 8 | import random 9 | import subprocess 10 | from urllib.parse import urlparse 11 | 12 | import numpy as np 13 | import torch 14 | from torch import nn 15 | 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key): 21 | 22 | 23 | 24 | if urlparse(pretrained_weights).scheme: # If it looks like an URL 25 | state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") 26 | else: 27 | state_dict = torch.load(pretrained_weights, map_location="cpu") 28 | 29 | # del_list = [] 30 | # for k, v in state_dict.items(): 31 | # if 'depth_head' in k: 32 | # del_list.append(k) 33 | # for i in del_list: 34 | # del state_dict[i] 35 | # state_dict = {k.replace("pretrained.", ""): v for k, v in state_dict.items()} 36 | # torch.save(state_dict,'../toolkit/models/DepthAny_body_vitl.pth') 37 | # exit() 38 | 39 | if checkpoint_key is not None and checkpoint_key in state_dict: 40 | logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") 41 | state_dict = state_dict[checkpoint_key] 42 | # remove `module.` prefix 43 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 44 | # remove `backbone.` prefix induced by multicrop wrapper 45 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 46 | msg = model.load_state_dict(state_dict, strict=False) 47 | logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) 48 | 49 | 50 | def fix_random_seeds(seed=31): 51 | """ 52 | Fix random seeds. 53 | """ 54 | torch.manual_seed(seed) 55 | torch.cuda.manual_seed_all(seed) 56 | np.random.seed(seed) 57 | random.seed(seed) 58 | 59 | 60 | def get_sha(): 61 | cwd = os.path.dirname(os.path.abspath(__file__)) 62 | 63 | def _run(command): 64 | return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() 65 | 66 | sha = "N/A" 67 | diff = "clean" 68 | branch = "N/A" 69 | try: 70 | sha = _run(["git", "rev-parse", "HEAD"]) 71 | subprocess.check_output(["git", "diff"], cwd=cwd) 72 | diff = _run(["git", "diff-index", "HEAD"]) 73 | diff = "has uncommitted changes" if diff else "clean" 74 | branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) 75 | except Exception: 76 | pass 77 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 78 | return message 79 | 80 | 81 | class CosineScheduler(object): 82 | def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): 83 | super().__init__() 84 | self.final_value = final_value 85 | self.total_iters = total_iters 86 | 87 | freeze_schedule = np.zeros((freeze_iters)) 88 | 89 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 90 | 91 | iters = np.arange(total_iters - warmup_iters - freeze_iters) 92 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 93 | self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) 94 | 95 | assert len(self.schedule) == self.total_iters 96 | 97 | def __getitem__(self, it): 98 | if it >= self.total_iters: 99 | return self.final_value 100 | else: 101 | return self.schedule[it] 102 | 103 | 104 | def has_batchnorms(model): 105 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) 106 | for name, module in model.named_modules(): 107 | if isinstance(module, bn_types): 108 | return True 109 | return False 110 | 
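# Usage sketch (illustrative only): the CosineScheduler above ramps linearly from
# start_warmup_value to base_value over warmup_iters, then cosine-anneals to final_value,
# and returns final_value for any index past total_iters. The numbers below are made up.
if __name__ == "__main__":
    _example_schedule = CosineScheduler(
        base_value=1e-3,        # peak value reached right after warmup
        final_value=1e-5,       # value the cosine anneals to
        total_iters=1000,
        warmup_iters=100,       # length of the linear warmup
        start_warmup_value=0.0,
    )
    for _it in (0, 50, 100, 500, 999, 2000):
        print(_it, float(_example_schedule[_it]))   # index 2000 falls back to final_value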
-------------------------------------------------------------------------------- /model_pack/dinoV2/hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | 7 | from dinov2.hub.backbones import dinov2_vitb14, dinov2_vitg14, dinov2_vitl14, dinov2_vits14 8 | from dinov2.hub.backbones import dinov2_vitb14_reg, dinov2_vitg14_reg, dinov2_vitl14_reg, dinov2_vits14_reg 9 | from dinov2.hub.classifiers import dinov2_vitb14_lc, dinov2_vitg14_lc, dinov2_vitl14_lc, dinov2_vits14_lc 10 | from dinov2.hub.classifiers import dinov2_vitb14_reg_lc, dinov2_vitg14_reg_lc, dinov2_vitl14_reg_lc, dinov2_vits14_reg_lc 11 | from dinov2.hub.depthers import dinov2_vitb14_ld, dinov2_vitg14_ld, dinov2_vitl14_ld, dinov2_vits14_ld 12 | from dinov2.hub.depthers import dinov2_vitb14_dd, dinov2_vitg14_dd, dinov2_vitl14_dd, dinov2_vits14_dd 13 | 14 | 15 | dependencies = ["torch"] 16 | -------------------------------------------------------------------------------- /model_pack/dinoV2/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | 4 | [tool.pylint.master] 5 | persistent = false 6 | score = false 7 | 8 | [tool.pylint.messages_control] 9 | disable = "all" 10 | enable = [ 11 | "miscellaneous", 12 | "similarities", 13 | ] 14 | 15 | [tool.pylint.similarities] 16 | ignore-comments = true 17 | ignore-docstrings = true 18 | ignore-imports = true 19 | min-similarity-lines = 8 20 | 21 | [tool.pylint.reports] 22 | reports = false 23 | 24 | [tool.pylint.miscellaneous] 25 | notes = [ 26 | "FIXME", 27 | "XXX", 28 | "TODO", 29 | ] 30 | -------------------------------------------------------------------------------- /model_pack/dinoV2/requirements-dev.txt: -------------------------------------------------------------------------------- 1 | black==22.6.0 2 | flake8==5.0.4 3 | pylint==2.15.0 4 | -------------------------------------------------------------------------------- /model_pack/dinoV2/requirements-extras.txt: -------------------------------------------------------------------------------- 1 | mmcv-full==1.5.0 2 | mmsegmentation==0.27.0 3 | -------------------------------------------------------------------------------- /model_pack/dinoV2/requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu117 2 | torch==2.0.0 3 | torchvision==0.15.0 4 | omegaconf 5 | torchmetrics==0.10.3 6 | fvcore 7 | iopath 8 | xformers==0.0.18 9 | submitit 10 | --extra-index-url https://pypi.nvidia.com 11 | cuml-cu11 12 | -------------------------------------------------------------------------------- /model_pack/dinoV2/scripts/lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ -n "$1" ]; then 4 | echo "linting \"$1\"" 5 | fi 6 | 7 | echo "running black" 8 | if [ -n "$1" ]; then 9 | black "$1" 10 | else 11 | black dinov2 12 | fi 13 | 14 | echo "running flake8" 15 | if [ -n "$1" ]; then 16 | flake8 "$1" 17 | else 18 | flake8 19 | fi 20 | 21 | echo "running pylint" 22 | if [ -n "$1" ]; then 23 | pylint "$1" 24 | else 25 | pylint dinov2 26 | fi 27 | 28 | exit 0 29 | -------------------------------------------------------------------------------- 
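# Usage sketch for the hubconf.py entrypoints dumped above. The local path, the
# entrypoint name, and pretrained=False are assumptions made for illustration; loading a
# local checkout this way requires the dinov2 package in that folder to be importable.
import torch

backbone = torch.hub.load(
    "model_pack/dinoV2",     # local directory containing hubconf.py (assumed)
    "dinov2_vits14",         # one of the backbones re-exported by hubconf.py
    source="local",
    pretrained=False,        # skip downloading weights for this sketch
)
print(sum(p.numel() for p in backbone.parameters()))  # rough parameter count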
/model_pack/dinoV2/setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E203,E501,W503 4 | per-file-ignores = 5 | __init__.py:F401 6 | hubconf.py:F401 7 | exclude = 8 | venv 9 | -------------------------------------------------------------------------------- /model_pack/dinoV2/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from pathlib import Path 7 | import re 8 | from typing import List, Tuple 9 | 10 | from setuptools import setup, find_packages 11 | 12 | 13 | NAME = "dinov2" 14 | DESCRIPTION = "PyTorch code and models for the DINOv2 self-supervised learning method." 15 | 16 | URL = "https://github.com/facebookresearch/dinov2" 17 | AUTHOR = "FAIR" 18 | REQUIRES_PYTHON = ">=3.9.0" 19 | HERE = Path(__file__).parent 20 | 21 | 22 | try: 23 | with open(HERE / "README.md", encoding="utf-8") as f: 24 | long_description = "\n" + f.read() 25 | except FileNotFoundError: 26 | long_description = DESCRIPTION 27 | 28 | 29 | def get_requirements(path: str = HERE / "requirements.txt") -> Tuple[List[str], List[str]]: 30 | requirements = [] 31 | extra_indices = [] 32 | with open(path) as f: 33 | for line in f.readlines(): 34 | line = line.rstrip("\r\n") 35 | if line.startswith("--extra-index-url "): 36 | extra_indices.append(line[18:]) 37 | continue 38 | requirements.append(line) 39 | return requirements, extra_indices 40 | 41 | 42 | def get_package_version() -> str: 43 | with open(HERE / "dinov2/__init__.py") as f: 44 | result = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M) 45 | if result: 46 | return result.group(1) 47 | raise RuntimeError("Can't get package version") 48 | 49 | 50 | requirements, extra_indices = get_requirements() 51 | version = get_package_version() 52 | dev_requirements, _ = get_requirements(HERE / "requirements-dev.txt") 53 | extras_requirements, _ = get_requirements(HERE / "requirements-extras.txt") 54 | 55 | 56 | setup( 57 | name=NAME, 58 | version=version, 59 | description=DESCRIPTION, 60 | long_description=long_description, 61 | long_description_content_type="text/markdown", 62 | author=AUTHOR, 63 | python_requires=REQUIRES_PYTHON, 64 | url=URL, 65 | packages=find_packages(), 66 | package_data={ 67 | "": ["*.yaml"], 68 | }, 69 | install_requires=requirements, 70 | extras_require={ 71 | "dev": dev_requirements, 72 | "extras": extras_requirements, 73 | }, 74 | dependency_links=extra_indices, 75 | install_package_data=True, 76 | license="Apache", 77 | license_files=("LICENSE",), 78 | classifiers=[ 79 | # Trove classifiers: https://github.com/pypa/trove-classifiers/blob/main/src/trove_classifiers/__init__.py 80 | "Development Status :: 3 - Alpha", 81 | "Intended Audience :: Developers", 82 | "Intended Audience :: Science/Research", 83 | "License :: OSI Approved :: Apache Software License", 84 | "Programming Language :: Python :: 3.9", 85 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 86 | "Topic :: Software Development :: Libraries :: Python Modules", 87 | ], 88 | ) 89 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | apex==0.9.10dev 2 | ConfigArgParse==1.7 3 | 
cuml==0.6.1.post1 4 | fvcore==0.1.5.post20221221 5 | imageio==2.33.0 6 | matplotlib==3.8.2 7 | mmcv==2.1.0 8 | mmcv_full==1.7.2 9 | mmdet==3.3.0 10 | mmsegmentation==1.2.2 11 | numpy==1.24.1 12 | omegaconf==2.3.0 13 | opencv_python==4.8.1.78 14 | opt_einsum==3.3.0 15 | # Pillow==10.1.0  # duplicate pin; Pillow is pinned to 10.3.0 below 16 | # Pillow==9.3.0  # duplicate pin; Pillow is pinned to 10.3.0 below 17 | Pillow==10.3.0 18 | pytorch_lightning==2.1.2 19 | PyYAML==6.0.1 20 | # PyYAML==6.0.1  # duplicate of the pin above 21 | scipy==1.13.0 22 | setuptools==60.2.0 23 | scikit-image==0.22.0  # PyPI package name for skimage 24 | submitit==1.5.1 25 | timm==0.5.4 26 | torch==2.1.0+cu118 27 | torchmetrics==1.2.1 28 | torchvision==0.16.0 29 | tqdm==4.65.2 30 | xformers==0.0.22.post7+cu118 31 | -------------------------------------------------------------------------------- /toolkit/args/args_default.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | 3 | def get_opts(): 4 | parser = configargparse.ArgumentParser() 5 | 6 | # dataset options 7 | parser.add_argument('--dataset', default = 'KITTI2015', type=str) 8 | parser.add_argument('--dataset_type', type=str, default='train', choices=['train','test'],help = 'use the train or test split of the dataset') 9 | parser.add_argument('--if_use_valid', type=bool, default=False, help = "whether to use a validation dataset; affects the generated file paths") 10 | parser.add_argument('--val_dataset', default = 'MiddEval3H', type=str) 11 | parser.add_argument('--val_dataset_type', type=str, default='train', choices=['train','test']) 12 | parser.add_argument('--keep_size', default=False, type=bool, help='do not resize the input image') 13 | parser.add_argument('--max_disp', type=int, default=192, help="max_disparity of the datasets") 14 | parser.add_argument('--max_disp_sr', type=int, default=192, help="max_disparity used when estimating the disparity") 15 | parser.add_argument('--save_name', type=str, default='Delete', help="name for saving the results") 16 | parser.add_argument('--resize', type=float, default=None, help="resize the image") 17 | 18 | # model options 19 | parser.add_argument('--network', type=str, default='ViTASIGEV') # , choices=['AANet','PSMNet','RAFT_Stereo','DFM','LacGwc','Unimatch','delete','BGNet','VGG','SGBM','SRP','FBS'] 20 | parser.add_argument('--ckpt_path', type=str, default='models/ViTASIGEV/KITTI.pth', help='pretrained checkpoint path to load from YoYo') 21 | parser.add_argument('--pre_trained', type=bool, default=False, help='whether to load the pretrained checkpoint released on GitHub') 22 | 23 | # training options 24 | parser.add_argument('--inference_type', type=str, default='evaluate', help='main inference mode; also determines the aug_config', choices=['evaluate','train']) 25 | parser.add_argument('--lr', type=float, default=2e-4, help='learning rate') 26 | parser.add_argument('--min_lr', type=float, default=2e-5, help='minimum learning rate') 27 | # parser.add_argument('--freeze_bn', default=True, type=bool) 28 | parser.add_argument('--batch_size', type=int, default=1, help='batch size') 29 | parser.add_argument('--num_steps', type=int, default=2000, help='number of training steps') 30 | parser.add_argument('--epoch_steps', type=int, default=800, help='number of training steps of each epoch') 31 | parser.add_argument('--num_workers', type=int, default=16, help='number of workers') 32 | parser.add_argument('--epoch_size', type=int, default=400, help='number of training epochs') 33 | parser.add_argument('--devices', type=int, nargs='+', default = [0], help='GPU devices used') 34 | parser.add_argument('--schedule', type=str, default='Cycle', help='select lr schedule') 35 |
parser.add_argument('--this_epoch', type=int, default=-1, help='if resume training, use this epoch (-1 starts a fresh schedule)') 36 | parser.add_argument('--resume', type=bool, default=False, help='if resume the full training state') 37 | parser.add_argument('--resume_model', type=bool, default=False, help='if only resume the model weights and restart the training') 38 | parser.add_argument('--hparams_dir', type=str, default=None, help="path to the hparams.yaml used when resuming") 39 | 40 | # ViTAS options 41 | parser.add_argument('--ViTAS_dic', type=dict, default = {'VFM_type':None,'ViTAS_model':None,'ViTAS_unfreeze':'5','ViTAS_hooks':None,'wo_SDM':None,'wo_fuse':None,'ViTAS_fuse':'PAFM','ViTAS_fuse_patch':None,'ViTAS_fuse_weight':None,'CA_layers':None}, help='args for ViTAS') 42 | 43 | return parser.parse_args() -------------------------------------------------------------------------------- /toolkit/data_loader/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataloader import StereoDataset 2 | -------------------------------------------------------------------------------- /toolkit/function/evaluator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def calcu_EPE(pre,gt): 5 | ans = torch.mean(torch.abs(gt - pre)) 6 | return ans 7 | 8 | def calcu_PEP(pre,gt,thr=1): 9 | abs_diff = torch.abs(gt - pre) 10 | num_base = torch.sum(abs_diff>=0) # total number of pixels (the mask is always true) 11 | num_error = torch.sum(abs_diff>=thr) 12 | ans = num_error/num_base 13 | return ans 14 | 15 | def calcu_D1all(pre,gt): 16 | abs_diff = torch.abs(gt - pre) 17 | num_base = torch.sum(abs_diff>=0) 18 | error = torch.ones_like(abs_diff) # keep the error mask on the same device/dtype as the inputs 19 | error[abs_diff < gt*0.05] = 0 20 | error[abs_diff < 3] = 0 21 | num_error = torch.sum(error) 22 | ans = num_error/num_base 23 | return ans -------------------------------------------------------------------------------- /toolkit/function/models.py: -------------------------------------------------------------------------------- 1 | from ModelCreated.ViTASIGEV.ViTASIGEV import ViTASIGEVModel 2 | 3 | def prepare_model(hparams,specific_model=None): 4 | if hparams.network == 'ViTASIGEV': 5 | model = load_ViTASIGEV_model(hparams) 6 | return model 7 | 8 | def load_ViTASIGEV_model(hparams): 9 | model = ViTASIGEVModel(hparams.ViTAS_dic) 10 | return model 11 | -------------------------------------------------------------------------------- /toolkit/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | torch.set_float32_matmul_precision('high') # highest, high, medium 5 | torch.backends.cudnn.benchmark = True # accelerate training 6 | from torch.utils.data import DataLoader 7 | 8 | import warnings 9 | warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*") 10 | warnings.filterwarnings("ignore", ".*exists and is not empty.") 11 | warnings.filterwarnings("ignore", ".*logging on epoch level in distributed setting*") 12 | warnings.filterwarnings("ignore", ".*RoPE2D, using a slow pytorch version instead") 13 | 14 | import sys 15 | sys.path.append('..') 16 | from toolkit.data_loader.dataset_function import generate_file_lists 17 | from toolkit.data_loader.dataloader import prepare_dataset,dataloader_customization,optimizer_customization 18 | from toolkit.args.args_default import get_opts 19 | from toolkit.torch_lightning.pl_modules.train import trainer_func 20 | from toolkit.torch_lightning.pl_modules.evaluate import evaluate_func 21 | 22 | # For reproducibility
23 | # torch.manual_seed(192) 24 | # torch.cuda.manual_seed(192) 25 | # np.random.seed(192) 26 | # random.seed(192) 27 | 28 | def inference(hparams): 29 | 30 | ###################### 31 | # prepare dataloader # 32 | ###################### 33 | hparams,aug_config,valid_aug_config = dataloader_customization(hparams) 34 | file_path_dic = generate_file_lists(dataset = hparams.dataset,if_train=hparams.dataset_type=='train',method='gt',save_method=hparams.save_name) 35 | dataset,n_img = prepare_dataset(file_path_dic,aug_config=aug_config) 36 | hparams = optimizer_customization(hparams,n_img) 37 | if hparams.if_use_valid: 38 | valid_file_path_dic = generate_file_lists(dataset = hparams.val_dataset,if_train=hparams.val_dataset_type=='train',method='gt',save_method='delete') 39 | valid_dataset,_ = prepare_dataset(valid_file_path_dic,aug_config=valid_aug_config) 40 | else: 41 | valid_dataset = None 42 | 43 | ############################################ 44 | # load model and select inference function # 45 | ############################################ 46 | if hparams.inference_type == 'train': 47 | inference = trainer_func 48 | elif hparams.inference_type == 'evaluate': 49 | inference = evaluate_func 50 | ########################## 51 | # run inference function # 52 | ########################## 53 | if 'train' in hparams.inference_type: 54 | inference(hparams,dataset,valid_dataset) 55 | elif 'evaluate' in hparams.inference_type: 56 | test_dataloader = DataLoader(dataset, batch_size= 1, shuffle= False, num_workers= 1, drop_last=False) 57 | inference(hparams,test_dataloader) 58 | 59 | if __name__ == '__main__': 60 | 61 | hparams = get_opts() 62 | hparams.devices = [0] 63 | 64 | # ViTASIGEV evaluate 65 | hparams.inference_type = 'evaluate' 66 | hparams.dataset = 'KITTI2015' 67 | hparams.network = 'ViTASIGEV' 68 | hparams.save_name = 'ViTASIGEV_benchmark' 69 | hparams.ckpt_path = 'models/ViTASIGEV/KITTI.pth' 70 | hparams.ViTAS_dic['VFM_type'] = 'DINOv2' 71 | inference(hparams) 72 | 73 | # ViTASIGEV train 74 | hparams.inference_type = 'train' 75 | hparams.if_use_valid = True 76 | hparams.dataset = 'KITTI2015' 77 | hparams.val_dataset = 'KITTI2012' 78 | hparams.network = 'ViTASIGEV' 79 | hparams.save_name = 'ViTASIGEV_benchmark' 80 | hparams.ckpt_path = 'models/ViTASIGEV/KITTI.pth' 81 | hparams.ViTAS_dic['ViTAS_fuse'] = 'PAFM' # ['SDFA','PAFM','VFM'] 82 | hparams.batch_size = 2 83 | hparams.num_workers = 2 84 | hparams.epoch_size = 100 85 | inference(hparams) 86 | 87 | 88 | -------------------------------------------------------------------------------- /toolkit/torch_lightning/data_modules/custom.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import LightningDataModule 2 | from torch.utils.data import DataLoader, RandomSampler 3 | 4 | class CustomDataModule(LightningDataModule): 5 | 6 | def __init__(self, hparams,train_dataset,valid_dataset=None): 7 | super().__init__() 8 | self.save_hyperparameters(hparams) 9 | self.train_dataset = train_dataset 10 | self.valid_dataset = valid_dataset 11 | 12 | def prepare_data(self): 13 | pass 14 | 15 | def setup(self, stage=None): 16 | print('{} samples found for training'.format(len(self.train_dataset))) 17 | if not self.valid_dataset is None: 18 | print('{} samples found for validation'.format(len(self.valid_dataset))) 19 | 20 | def train_dataloader(self): 21 | # sampler = RandomSampler(self.train_dataset, 22 | # replacement=True, 23 | # num_samples=self.hparams['batch_size'] * 
self.hparams['epoch_size']) 24 | # sampler = RandomSampler(self.train_dataset, 25 | # replacement=False) 26 | # return DataLoader(self.train_dataset, 27 | # sampler=sampler, 28 | # num_workers=self.hparams['num_workers'], 29 | # batch_size=self.hparams['batch_size'], 30 | # pin_memory=True, 31 | # persistent_workers=True) 32 | return DataLoader(self.train_dataset, 33 | shuffle=True, 34 | num_workers=self.hparams['num_workers'], 35 | batch_size=self.hparams['batch_size'], 36 | pin_memory=False, 37 | persistent_workers=False) 38 | 39 | def val_dataloader(self): 40 | if self.valid_dataset is None: 41 | return None 42 | return DataLoader(self.valid_dataset, 43 | shuffle=False, 44 | num_workers=self.hparams['num_workers'], 45 | # batch_size=self.hparams['batch_size'], 46 | batch_size=1, 47 | # num_workers=1, 48 | pin_memory=True, 49 | persistent_workers=True) -------------------------------------------------------------------------------- /toolkit/torch_lightning/lightning_function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml 3 | import argparse 4 | 5 | 6 | 7 | def schedule_select(optimizer,hparams): 8 | if hparams.schedule == 'Cycle': # for large dataset 9 | scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=hparams.min_lr, max_lr=hparams.lr, step_size_up=int(hparams.epoch_steps/500 + 5), step_size_down=int(hparams.epoch_steps*0.4), cycle_momentum=False,mode='triangular2', last_epoch = hparams.this_epoch) # 20,6000;2,200 10 | print('lr_schedule Cycle: base_lr',hparams.min_lr,'; max_lr',hparams.lr) 11 | elif hparams.schedule == 'OneCycle': # for small dataset 12 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 13 | optimizer, max_lr=hparams.lr, 14 | total_steps=hparams.num_steps + 50, 15 | pct_start=0.03, 16 | cycle_momentum=False, 17 | anneal_strategy='cos', 18 | last_epoch = hparams.this_epoch, 19 | # initial_lr = hparams.lr/25, 20 | ) 21 | print('lr_schedule OneCycle: max_lr',hparams.lr) 22 | lr_scheduler = { 23 | 'scheduler': scheduler, 24 | 'name': 'my_logging_lr', 25 | 'interval':'step' 26 | } 27 | return lr_scheduler 28 | 29 | def hparam_resume(hparams,evaluate = False): 30 | # if hparams.hparams_dir is None: # search for hparam file 31 | hparams.hparams_dir = '/'.join(hparams.ckpt_path.split('/')[:-1])+'/hparams.yaml' 32 | # print(hparams.hparams_dir) 33 | with open(hparams.hparams_dir, 'r') as f: 34 | hparams_ = yaml.unsafe_load(f)['hparams'] 35 | 36 | # set exception 37 | if evaluate: 38 | exception_ = ['resume','resume_model','ViTAS_dic','devices','ckpt_path','inference_type','num_workers','dataset','dataset_type','if_use_valid','val_dataset','val_dataset_type','save_name'] 39 | else: 40 | if not hparams.resume_model: # resume all 41 | exception_ = ['resume','resume_model','ViTAS_dic','devices','ckpt_path','inference_type','num_workers','dataset','dataset_type','if_use_valid','val_dataset','val_dataset_type','save_name','batch_size','num_steps','epoch_steps','epoch_size','schedule'] 42 | else: # only resume the model 43 | exception_ = ['resume','resume_model','ViTAS_dic','devices','ckpt_path','inference_type','num_workers','batch_size','num_steps','epoch_steps','epoch_size','schedule'] 44 | 45 | # merge the hparams 46 | hparams_ = vars(hparams_) 47 | hparams = vars(hparams) 48 | for k in hparams.keys(): 49 | if k in hparams_.keys() and not k in exception_: 50 | hparams[k] = hparams_[k] 51 | hparams['ViTAS_dic'].update(hparams_['ViTAS_dic']) 52 | if not hparams['resume_model']: # resume all, 
continue training 53 | last_epoch = int(hparams['ckpt_path'].split('epoch=')[1].split('-')[0]) 54 | hparams['this_epoch'] = last_epoch*hparams['epoch_steps'] 55 | print('resume training from epoch {}, step {}'.format(last_epoch,hparams['this_epoch'])) 56 | else: # resume model and re-start the training 57 | hparams['this_epoch'] = -1 58 | return argparse.Namespace(**hparams) 59 | -------------------------------------------------------------------------------- /toolkit/torch_lightning/pl_modules/evaluate.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import Trainer 2 | from toolkit.torch_lightning.pl_modules.ViTASIGEV import ViTASIGEV 3 | from toolkit.torch_lightning.lightning_function import hparam_resume 4 | 5 | def evaluate_func(hparams,test_dataloader): 6 | 7 | if hparams.network == 'ViTASIGEV': 8 | system = ViTASIGEV 9 | else: 10 | raise ValueError('Invalid network type') 11 | 12 | # if not hparams.hparams_dir is None: 13 | # with open(hparams.hparams_dir, 'r') as f: 14 | # hparams = yaml.unsafe_load(f)['hparams'] 15 | 16 | 17 | if hparams.ckpt_path is not None: 18 | if hparams.resume or hparams.resume_model: 19 | hparams = hparam_resume(hparams,evaluate=True) 20 | print('load lightning pre-trained model from {}'.format(hparams.ckpt_path)) 21 | # system = system.load_from_checkpoint(hparams.ckpt_path,**{'hparams':hparams}) 22 | system = system.load_from_checkpoint(hparams.ckpt_path,**{'hparams':hparams},map_location='cpu') # map_location={'cuda:1': 'cuda:0'} 23 | else: 24 | system = system(hparams = hparams) 25 | 26 | # set up trainer 27 | trainer = Trainer(accelerator='gpu', 28 | devices = hparams.devices,) 29 | 30 | # time0=time.time() 31 | trainer.test(system, test_dataloader) 32 | # print((time.time()-time0)/97) 33 | -------------------------------------------------------------------------------- /toolkit/torch_lightning/pl_modules/train.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import Trainer 2 | from pytorch_lightning.callbacks import ModelCheckpoint 3 | from pytorch_lightning.loggers import TensorBoardLogger 4 | from pytorch_lightning.callbacks import LearningRateMonitor 5 | # from pytorch_lightning.strategies import DDPStrategy 6 | from toolkit.torch_lightning.pl_modules.ViTASIGEV import ViTASIGEV 7 | from toolkit.torch_lightning.data_modules.custom import CustomDataModule 8 | from toolkit.torch_lightning.lightning_function import hparam_resume 9 | import torch 10 | 11 | def trainer_func(hparams,train_dataset,valid_dataset=None): 12 | 13 | if hparams.network == 'ViTASIGEV': 14 | system = ViTASIGEV 15 | else: 16 | raise ValueError('Invalid network type') 17 | 18 | # pl data module 19 | dm = CustomDataModule(hparams,train_dataset,valid_dataset) 20 | 21 | # pl logger 22 | logger = TensorBoardLogger( 23 | save_dir="torch_lightning/ckpts", 24 | name=hparams.save_name, 25 | default_hp_metric=False 26 | ) 27 | print('=> log save in: ckpts/{}/version_{:d}'.format(hparams.save_name, logger.version)) 28 | 29 | # save checkpoints 30 | ckpt_dir = 'torch_lightning/ckpts/{}/version_{:d}'.format( 31 | hparams.save_name, logger.version) 32 | checkpoint_callback = ModelCheckpoint(dirpath=ckpt_dir, 33 | # filename='{epoch}-{loss_epoch:.4f}', 34 | filename='{epoch}-{loss_epoch:.4f}-{valid_epoch_EPE:.4f}-{valid_epoch_D1:.4f}', 35 | monitor = 'valid_epoch_EPE', # loss_epoch,valid_epoch_EPE 36 | mode='min', 37 | save_last=False, 38 | save_weights_only=True, 39 | 
every_n_epochs = 1, 40 | save_top_k=8, # -1: save for every epoch 41 | ) 42 | 43 | 44 | # restore from previous checkpoints 45 | if hparams.ckpt_path is not None: 46 | if hparams.resume or hparams.resume_model: 47 | hparams = hparam_resume(hparams) 48 | hparams.this_epoch = (int(hparams.ckpt_path.split('epoch=')[1].split('-')[0]))*hparams.epoch_steps 49 | print('load lightning pre-trained model from {}'.format(hparams.ckpt_path)) 50 | # system = system.load_from_checkpoint(hparams.ckpt_path,**{'hparams':hparams}) 51 | # system = system.load_from_checkpoint(hparams.ckpt_path,**{'hparams':hparams},map_location='cuda:0') 52 | system = system.load_from_checkpoint(hparams.ckpt_path,**{'hparams':hparams},map_location='cpu') 53 | else: 54 | system = system(hparams = hparams) 55 | 56 | lr_monitor = LearningRateMonitor(logging_interval='step') 57 | 58 | gpu_count = torch.cuda.device_count() 59 | if len(hparams.devices) == 1: 60 | strategy = 'auto' # auto 61 | elif len(hparams.devices) > 1: 62 | strategy = 'ddp_find_unused_parameters_true' 63 | else: 64 | raise NotImplementedError('none GPU detected') 65 | print('{} GPUs detected, use {} strategy and {} devices'.format(gpu_count,strategy,hparams.devices)) 66 | 67 | # set up trainer 68 | trainer = Trainer( 69 | accelerator='gpu', 70 | max_epochs=hparams.epoch_size, 71 | # limit_val_batches=200 if hparams.val_mode == 'photo' else 1.0, 72 | # limit_val_batches=1, 73 | num_sanity_val_steps = 2, 74 | log_every_n_steps = 50, # default = 50 75 | val_check_interval = 1.0, # validate every n training epoch, default=1.0 76 | callbacks=[checkpoint_callback,lr_monitor], 77 | devices = hparams.devices, 78 | strategy = strategy, 79 | logger=logger, 80 | benchmark=True, 81 | # reload_dataloaders_every_n_epochs = 5, 82 | ) 83 | 84 | trainer.fit(system, dm) 85 | 86 | --------------------------------------------------------------------------------
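# Worked example of the checkpoint-name convention that hparam_resume() and trainer_func()
# above rely on: ModelCheckpoint writes '{epoch}-{loss_epoch:.4f}-...' filenames, and the
# epoch index is parsed back out of the path and converted to a global step for the lr
# scheduler. The filename and the epoch_steps value below are made up for illustration.
ckpt_path = 'torch_lightning/ckpts/ViTASIGEV_benchmark/version_0/epoch=37-loss_epoch=0.5123-valid_epoch_EPE=0.6420-valid_epoch_D1=0.0210.ckpt'
epoch_steps = 800                                             # --epoch_steps default in args_default.py
last_epoch = int(ckpt_path.split('epoch=')[1].split('-')[0])  # -> 37
this_epoch = last_epoch * epoch_steps                         # global step handed to the scheduler as last_epoch
print(last_epoch, this_epoch)                                 # 37 29600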