├── datasets
│   ├── __init__.py
│   ├── svd
│   │   ├── __init__.py
│   │   ├── core.py
│   │   └── eval.py
│   ├── gld.py
│   └── oxford_paris.py
├── models
│   ├── __init__.py
│   ├── delg
│   │   ├── __init__.py
│   │   ├── r50_delg_config.yaml
│   │   ├── r101_delg_config.yaml
│   │   ├── model.py
│   │   ├── config.py
│   │   ├── net.py
│   │   └── resnet.py
│   ├── dolg
│   │   ├── __init__.py
│   │   ├── dolg_config.yaml
│   │   ├── config.py
│   │   ├── net.py
│   │   ├── model.py
│   │   └── resnet.py
│   ├── solar
│   │   ├── __init__.py
│   │   └── networks.py
│   ├── gem_pooling.py
│   ├── pca_layer.py
│   ├── distill_model.py
│   └── teachers.py
├── assets
│   └── teaser.png
├── requirements.txt
├── config
│   └── svd.yaml
├── README.md
├── .gitignore
├── utils.py
├── loss.py
├── svd_eval.py
├── metric.py
├── gld_pca_learn.py
├── svd_pca_learn.py
├── LICENSE
├── oxford_paris_eval.py
├── svd_distill.py
└── gld_distill.py

/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/models/delg/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/models/dolg/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/datasets/svd/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/models/solar/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Maryeon/whiten_mtd/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/models/dolg/dolg_config.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 |   TYPE: resnet
3 |   DEPTH: 101
4 |   NUM_CLASSES: 81313
5 |   HEADS:
6 |     IN_FEAT: 2048
7 |     REDUCTION_DIM: 512
8 |     MARGIN: 0.15
9 |     SCALE: 30
10 | RESNET:
11 |   TRANS_FUN: bottleneck_transform
12 |   NUM_GROUPS: 1
13 |   WIDTH_PER_GROUP: 64
14 |   STRIDE_1X1: False
15 | BN:
16 |   ZERO_INIT_FINAL_GAMMA: True
--------------------------------------------------------------------------------
/models/gem_pooling.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | class GeneralizedMeanPooling(nn.Module):
5 |     def __init__(self, p):
6 |         super().__init__()
7 |         self.p = p
8 | 
9 |     def forward(self, x):
10 |         if self.p != 1.:
11 |             mean = x.clamp(min=1e-6).pow(self.p).mean(dim=(2,3))
12 |             return mean.pow(1./self.p)
13 |         else:
14 |             return x.mean(dim=(2,3))
--------------------------------------------------------------------------------
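GeM pooling above computes, per channel, f_c = (mean over the H x W positions of x^p)^(1/p); p = 1 recovers average pooling and large p approaches max pooling, with the clamp guarding the fractional power against zeros. A minimal usage sketch (illustrative, assuming the repository root is on `sys.path`):

```python
import torch
from models.gem_pooling import GeneralizedMeanPooling

pool = GeneralizedMeanPooling(p=3.)      # p=1. falls back to plain average pooling
feature_map = torch.rand(2, 512, 7, 7)   # B x C x H x W backbone activations
descriptor = pool(feature_map)           # -> B x C global descriptor
assert descriptor.shape == (2, 512)
```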
/models/delg/r50_delg_config.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 |   TYPE: resnet
3 |   DEPTH: 50
4 |   NUM_CLASSES: 96264
5 |   HEADS:
6 |     IN_FEAT: 2048
7 |     REDUCTION_DIM: 512
8 |     MARGIN: 0.15
9 |     SCALE: 30
10 | RESNET:
11 |   TRANS_FUN: bottleneck_transform
12 |   NUM_GROUPS: 1
13 |   WIDTH_PER_GROUP: 64
14 |   STRIDE_1X1: False
15 | BN:
16 |   ZERO_INIT_FINAL_GAMMA: True
--------------------------------------------------------------------------------
/models/delg/r101_delg_config.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 |   TYPE: resnet
3 |   DEPTH: 101
4 |   NUM_CLASSES: 96264
5 |   HEADS:
6 |     IN_FEAT: 2048
7 |     REDUCTION_DIM: 512
8 |     MARGIN: 0.15
9 |     SCALE: 30
10 | RESNET:
11 |   TRANS_FUN: bottleneck_transform
12 |   NUM_GROUPS: 1
13 |   WIDTH_PER_GROUP: 64
14 |   STRIDE_1X1: False
15 | BN:
16 |   ZERO_INIT_FINAL_GAMMA: True
17 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | certifi==2023.7.22
2 | charset-normalizer==3.2.0
3 | idna==3.4
4 | numpy==1.24.4
5 | pandas==2.0.3
6 | Pillow==10.0.0
7 | python-dateutil==2.8.2
8 | pytz==2023.3
9 | PyYAML==6.0.1
10 | requests==2.31.0
11 | six==1.16.0
12 | torch==1.11.0+cu113
13 | torchaudio==0.11.0+cu113
14 | torchvision==0.12.0+cu113
15 | typing_extensions==4.7.1
16 | tzdata==2023.3
17 | urllib3==2.0.4
18 | yacs==0.1.8
19 | 
--------------------------------------------------------------------------------
/datasets/gld.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from PIL import Image
4 | 
5 | 
6 | class ImageDataset(torch.utils.data.dataset.Dataset):
7 |     def __init__(self, root_path, img_id_list, transform=None):
8 |         super().__init__()
9 |         self.root_path = root_path
10 |         self.img_id_list = img_id_list
11 | 
12 |         self.t = transform
13 | 
14 |     def __getitem__(self, i):
15 |         img_id = self.img_id_list[i]
16 |         img_path = os.path.join(
17 |             self.root_path,
18 |             img_id[0], img_id[1], img_id[2],
19 |             img_id+".jpg"
20 |         )
21 | 
22 |         img = Image.open(img_path)
23 |         img = img.convert("RGB")
24 |         if self.t is not None:
25 |             img = self.t(img)
26 | 
27 |         return img, img_id
28 | 
29 |     def __len__(self):
30 |         return len(self.img_id_list)
--------------------------------------------------------------------------------
/models/delg/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.nn as nn
3 | from .config import cfg
4 | from .resnet import ResNet
5 | 
6 | 
7 | class R101_DELG(nn.Module):
8 |     def __init__(self):
9 |         super().__init__()
10 |         cfg.merge_from_file(os.path.join(os.path.dirname(__file__), "r101_delg_config.yaml"))
11 |         cfg.freeze()
12 |         self.globalmodel = ResNet()
13 |         self.embed_dim = cfg.MODEL.HEADS.REDUCTION_DIM
14 | 
15 |     def forward(self, x):
16 |         global_feature, _ = self.globalmodel(x)
17 | 
18 |         return global_feature
19 | 
20 | 
21 | class R50_DELG(nn.Module):
22 |     def __init__(self):
23 |         super().__init__()
24 |         cfg.merge_from_file(os.path.join(os.path.dirname(__file__), "r50_delg_config.yaml"))
25 |         cfg.freeze()
26 |         self.globalmodel = ResNet()
27 |         self.embed_dim = cfg.MODEL.HEADS.REDUCTION_DIM
28 | 
29 |     def forward(self, x):
30 |         global_feature, _ = self.globalmodel(x)
31 | 
32 |         return global_feature
--------------------------------------------------------------------------------
/config/svd.yaml:
--------------------------------------------------------------------------------
1 | all_video_id: /path/to/svd/metadata-release/all-video-id
2 | labeled_id: /path/to/svd/metadata-release/labeled-data-id
3 | unlabeled_id: /path/to/svd/metadata-release/unlabeled-data-id
4 | query_id: /path/to/svd/metadata-release/query-id
5 | groundtruth: /path/to/svd/metadata-release/groundtruth
6 | train_groundtruth: /path/to/svd/metadata-release/train_groundtruth
7 | test_groundtruth: /path/to/svd/metadata-release/test_groundtruth
8 | 
9 | video_root_paths:
10 |   - /path/to/svd/query
11 |   - /path/to/svd/labeled-0
12 |   - /path/to/svd/labeled-1
13 |   - /path/to/svd/unlabeled-0
14 |   - /path/to/svd/unlabeled-1
15 |   - /path/to/svd/unlabeled-2
16 |   - /path/to/svd/unlabeled-3
17 |   - /path/to/svd/unlabeled-4
18 |   - /path/to/svd/unlabeled-5
19 |   - /path/to/svd/unlabeled-6
20 |   - /path/to/svd/unlabeled-7
21 |   - /path/to/svd/unlabeled-8
22 |   - /path/to/svd/unlabeled-9
23 |   - /path/to/svd/unlabeled-10
24 |   - /path/to/svd/unlabeled-11
25 |   - /path/to/svd/unlabeled-12
26 |   - /path/to/svd/unlabeled-13
27 |   - /path/to/svd/unlabeled-14
28 |   - /path/to/svd/unlabeled-15
29 |   - /path/to/svd/unlabeled-16
30 |   - /path/to/svd/unlabeled-17
31 |   - /path/to/svd/unlabeled-18
32 |   - /path/to/svd/unlabeled-19
33 |   - /path/to/svd/unlabeled-20
34 |   - /path/to/svd/unlabeled-21
35 |   - /path/to/svd/unlabeled-22
36 |   - /path/to/svd/unlabeled-23
37 |   - /path/to/svd/unlabeled-24
38 |   - /path/to/svd/unlabeled-25
39 |   - /path/to/svd/unlabeled-26
40 | frame_root_path: /path/to/svd/frames
41 | frame_count_file: /path/to/svd/num_frames
42 | feature_root_path: /path/to/svd/features
43 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Whiten_MTD
2 | Official repository of the paper "Let All be Whitened: Multi-teacher Distillation for Efficient Visual Retrieval", accepted at AAAI 2024.
3 | ![teaser](assets/teaser.png)
4 | 
5 | ## Prepare environment
6 | Create a conda virtual environment and install the required packages:
7 | ```shell
8 | conda create -n whiten_mtd python=3.8
9 | conda activate whiten_mtd
10 | pip install -r requirements.txt
11 | ```
12 | 
13 | ## Prepare dataset
14 | We use [Google Landmark V2 (GLDv2)](https://github.com/cvdfoundation/google-landmark) and [SVD](https://svdbase.github.io/) as training datasets; both can be downloaded by following their official repositories. To train on GLDv2, pass its root path to the corresponding argument of the scripts ```gld_pca_learn.py``` and ```gld_distill.py```. To train on SVD, update the paths in ```config/svd.yaml``` accordingly.
15 | 
16 | The other two datasets, [ROxford5k and RParis6k](http://cmp.felk.cvut.cz/revisitop/), should also be downloaded for evaluation.
17 | 
18 | ## Evaluation
19 | Pretrained weights of the teacher models and their PCA-whitening layers can be downloaded from [here](https://drive.google.com/drive/folders/1-9BOzGBCNY6FrGCmpefSfispWHfiVMOd?usp=sharing).
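Judging from the paths hard-coded in `models/pca_layer.py`, the downloaded weights are expected under a single root (the pretrained-weights path passed to the scripts); for the default `embed_dim=512` the layout would be:

```
PATH_TO_PRETRAINED_WEIGHTS/
└── pca_weights/
    ├── mocov3_pca_512d_svd_224x224_p1.pt
    ├── barlowtwins_pca_512d_svd_224x224_p1.pt
    ├── resnet101_gem_pca_512d_gldv2_512x512_p3_randcrop.pt
    ├── resnet101_ap_gem_pca_512d_gldv2_512x512_p3_randcrop.pt
    ├── resnet101_solar_pca_512d_gldv2_512x512_p3_randcrop.pt
    ├── resnet101_delg_pca_512d_gldv2_512x512_p3_randcrop.pt
    └── resnet101_dolg_pca_512d_gldv2_512x512_p3_global_randcrop.pt
```

The teacher backbone checkpoints themselves are resolved by `models/teachers.py`.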
20 | ### Instance image retrieval
21 | Pretrained student model checkpoints can be downloaded from the links below:
22 | 
23 | | Teachers | Student | Links |
24 | | :-: | :-: | :-: |
25 | | GeM, AP-GeM, SOLAR | R18 | [rg_rag_rs_to_r18_ep200](https://drive.google.com/file/d/1qLp_AoI5SRNs9AV3o8SzJchyXwBFSsDL/view?usp=sharing) |
26 | | GeM, AP-GeM, SOLAR | R34 | [rg_rag_rs_to_r34_ep200](https://drive.google.com/file/d/1wsPIgGnXw6TPmVDtyFLYXSFF1YzRePCE/view?usp=sharing) |
27 | | DOLG, DELG | R18 | [ro_re_to_r18_ep3k](https://drive.google.com/file/d/1TDi9WelEu7Ks5fAOIMQftSroZzDXzKaM/view?usp=sharing) |
28 | | DOLG, DELG | R34 | [ro_re_to_r34_ep3k](https://drive.google.com/file/d/1XVnURGdqdmJ1GiMswNEgHLji_NCpdBCV/view?usp=sharing) |
29 | 
30 | To perform evaluation using our pretrained weights:
31 | ```shell
32 | python oxford_paris_eval.py -a resnet18/34 -r PATH_TO_CHECKPOINT -dp PATH_TO_DATASET --embed_dim 512 -ms -p 3
33 | ```
34 | 
35 | ### Video retrieval
36 | Pretrained student model checkpoints can be downloaded from the links below:
37 | 
38 | | Teachers | Student | Links |
39 | | :-: | :-: | :-: |
40 | | MoCoV3, BarlowTwins | R18 | [mc_bt_to_r18_ep3k](https://drive.google.com/file/d/1yKv2-TGHwaAlQOugLpiOjo64TpHaM176/view?usp=sharing) |
41 | | MoCoV3, BarlowTwins | R34 | [mc_bt_to_r34_ep3k](https://drive.google.com/file/d/1GrkzoeT8QUAqY6B7Jio6TqkDa0rsb7PN/view?usp=sharing) |
42 | 
43 | To perform evaluation using our pretrained weights:
44 | ```shell
45 | python svd_eval.py -a resnet18/34 -dm config/svd.yaml --sim_fn cf -r PATH_TO_CHECKPOINT --embed_dim 512
46 | ```
47 | 
48 | ## Training
49 | We train all models on a server with 8 NVIDIA V100 (16 GB) GPUs and a batch size of 256. Run the following commands with our default settings to train your own models:
50 | 
51 | - On GLDv2:
52 | ```shell
53 | python gld_distill.py -a resnet18/34 -ts resnet101_delg resnet101_dolg -c PATH_TO_SAVE_CHECKPOINTS --gld_root_path PATH_TO_DATASET
54 | ```
55 | - On SVD:
56 | ```shell
57 | python svd_distill.py -a resnet18/34 -ts mocov3 barlowtwins -c PATH_TO_SAVE_CHECKPOINTS -dm config/svd.yaml
58 | ```
--------------------------------------------------------------------------------
/models/delg/config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | 6 | """Configuration file (powered by YACS).""" 7 | 8 | import argparse 9 | import os 10 | import sys 11 | 12 | from yacs.config import CfgNode as CfgNode 13 | 14 | 15 | # Global config object 16 | _C = CfgNode() 17 | 18 | # Example usage: 19 | # from core.config import cfg 20 | cfg = _C 21 | 22 | 23 | # ------------------------------------------------------------------------------------ # 24 | # Model options 25 | # ------------------------------------------------------------------------------------ # 26 | _C.MODEL = CfgNode() 27 | 28 | # Model type 29 | _C.MODEL.TYPE = "" 30 | 31 | # Number of weight layers 32 | _C.MODEL.DEPTH = 0 33 | 34 | # Number of classes 35 | _C.MODEL.NUM_CLASSES = 10 36 | 37 | # Loss function (see pycls/models/loss.py for options) 38 | _C.MODEL.LOSSES = CfgNode() 39 | _C.MODEL.LOSSES.NAME = "cross_entropy" 40 | 41 | # ------------------------------------------------------------------------------------ # 42 | # Heads options 43 | # ------------------------------------------------------------------------------------ 44 | _C.MODEL.HEADS = CfgNode() 45 | _C.MODEL.HEADS.NAME = "LinearHead" 46 | # Normalization method for the convolution layers. 47 | # Number of identity 48 | _C.MODEL.HEADS.NUM_CLASSES = 1000 49 | # Input feature dimension 50 | _C.MODEL.HEADS.IN_FEAT = 2048 51 | # Reduction dimension in head 52 | _C.MODEL.HEADS.REDUCTION_DIM = 512 53 | # Pooling layer type 54 | _C.MODEL.HEADS.POOL_LAYER = "avgpool" 55 | # Classification layer type 56 | _C.MODEL.HEADS.CLS_LAYER = "linear" 57 | # Margin and Scale for margin-based classification layer 58 | _C.MODEL.HEADS.MARGIN = 0.15 59 | _C.MODEL.HEADS.SCALE = 128 60 | 61 | 62 | # ------------------------------------------------------------------------------------ # 63 | # ResNet options 64 | # ------------------------------------------------------------------------------------ # 65 | _C.RESNET = CfgNode() 66 | 67 | # Transformation function (see pycls/models/resnet.py for options) 68 | _C.RESNET.TRANS_FUN = "basic_transform" 69 | 70 | # Number of groups to use (1 -> ResNet; > 1 -> ResNeXt) 71 | _C.RESNET.NUM_GROUPS = 1 72 | 73 | # Width of each group (64 -> ResNet; 4 -> ResNeXt) 74 | _C.RESNET.WIDTH_PER_GROUP = 64 75 | 76 | # Apply stride to 1x1 conv (True -> MSRA; False -> fb.torch) 77 | _C.RESNET.STRIDE_1X1 = True 78 | 79 | 80 | # ------------------------------------------------------------------------------------ # 81 | # Batch norm options 82 | # ------------------------------------------------------------------------------------ # 83 | _C.BN = CfgNode() 84 | 85 | # BN epsilon 86 | _C.BN.EPS = 1e-5 87 | 88 | # BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2) 89 | _C.BN.MOM = 0.1 90 | 91 | # Precise BN stats 92 | _C.BN.USE_PRECISE_STATS = False 93 | _C.BN.NUM_SAMPLES_PRECISE = 1024 94 | 95 | # Initialize the gamma of the final BN of each block to zero 96 | _C.BN.ZERO_INIT_FINAL_GAMMA = False 97 | 98 | # Use a different weight decay for BN layers 99 | _C.BN.USE_CUSTOM_WEIGHT_DECAY = False 100 | _C.BN.CUSTOM_WEIGHT_DECAY = 0.0 101 | 102 | 103 | # ------------------------------------------------------------------------------------ # 104 | # Memory options 105 | # ------------------------------------------------------------------------------------ # 106 | _C.MEM = CfgNode() 107 | 108 | # Perform ReLU inplace 109 | _C.MEM.RELU_INPLACE = True -------------------------------------------------------------------------------- /models/pca_layer.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from utils import PCA 5 | 6 | 7 | def mocov3_pca_layer(path_to_pretrained_weights, *args, **kwargs): 8 | pretrained_weights = os.path.join(path_to_pretrained_weights, "pca_weights/mocov3_pca_512d_svd_224x224_p1.pt") 9 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 10 | pca_layer = PCA(dim1=checkpoint["dim1"].item(), dim2=checkpoint["dim2"].item()) 11 | pca_layer.load_state_dict(checkpoint) 12 | 13 | return pca_layer 14 | 15 | 16 | def barlowtwins_pca_layer(path_to_pretrained_weights, *args, **kwargs): 17 | pretrained_weights = os.path.join(path_to_pretrained_weights, "pca_weights/barlowtwins_pca_512d_svd_224x224_p1.pt") 18 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 19 | pca_layer = PCA(dim1=checkpoint["dim1"].item(), dim2=checkpoint["dim2"].item()) 20 | pca_layer.load_state_dict(checkpoint) 21 | 22 | return pca_layer 23 | 24 | 25 | def resnet101_gem_pca_layer(path_to_pretrained_weights, embed_dim=512): 26 | pretrained_weights = os.path.join(path_to_pretrained_weights, f"pca_weights/resnet101_gem_pca_{embed_dim}d_gldv2_512x512_p3_randcrop.pt") 27 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 28 | pca_layer = PCA(dim1=checkpoint["dim1"].item(), dim2=checkpoint["dim2"].item()) 29 | pca_layer.load_state_dict(checkpoint) 30 | 31 | return pca_layer 32 | 33 | 34 | def resnet101_ap_gem_pca_layer(path_to_pretrained_weights, embed_dim=512): 35 | pretrained_weights = os.path.join(path_to_pretrained_weights, f"pca_weights/resnet101_ap_gem_pca_{embed_dim}d_gldv2_512x512_p3_randcrop.pt") 36 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 37 | pca_layer = PCA(dim1=checkpoint["dim1"].item(), dim2=checkpoint["dim2"].item()) 38 | pca_layer.load_state_dict(checkpoint) 39 | 40 | return pca_layer 41 | 42 | 43 | def resnet101_solar_pca_layer(path_to_pretrained_weights, embed_dim=512): 44 | pretrained_weights = os.path.join(path_to_pretrained_weights, f"pca_weights/resnet101_solar_pca_{embed_dim}d_gldv2_512x512_p3_randcrop.pt") 45 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 46 | pca_layer = PCA(dim1=checkpoint["dim1"].item(), dim2=checkpoint["dim2"].item()) 47 | pca_layer.load_state_dict(checkpoint) 48 | 49 | return pca_layer 50 | 51 | 52 | def resnet101_delg_pca_layer(path_to_pretrained_weights, embed_dim=512): 53 | pretrained_weights = os.path.join(path_to_pretrained_weights, f"pca_weights/resnet101_delg_pca_{embed_dim}d_gldv2_512x512_p3_randcrop.pt") 54 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 55 | pca_layer = PCA(dim1=checkpoint["dim1"].item(), dim2=checkpoint["dim2"].item()) 56 | pca_layer.load_state_dict(checkpoint) 57 | 58 | return pca_layer 59 | 60 | 61 | def resnet101_dolg_pca_layer(path_to_pretrained_weights, embed_dim=512): 62 | pretrained_weights = os.path.join(path_to_pretrained_weights, f"pca_weights/resnet101_dolg_pca_{embed_dim}d_gldv2_512x512_p3_global_randcrop.pt") 63 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 64 | pca_layer = PCA(dim1=checkpoint["dim1"].item(), dim2=checkpoint["dim2"].item()) 65 | pca_layer.load_state_dict(checkpoint) 66 | 67 | return pca_layer 68 | 69 | 70 | pca_layers = { 71 | "mocov3": mocov3_pca_layer, 72 | "barlowtwins": barlowtwins_pca_layer, 73 | "resnet101_gem": resnet101_gem_pca_layer, 74 | "resnet101_ap_gem": resnet101_ap_gem_pca_layer, 75 | "resnet101_solar": 
resnet101_solar_pca_layer, 76 | "resnet101_delg": resnet101_delg_pca_layer, 77 | "resnet101_dolg": resnet101_dolg_pca_layer 78 | } -------------------------------------------------------------------------------- /models/dolg/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Configuration file (powered by YACS).""" 9 | 10 | import argparse 11 | import os 12 | import sys 13 | 14 | from yacs.config import CfgNode as CfgNode 15 | 16 | 17 | # Global config object 18 | _C = CfgNode() 19 | 20 | # Example usage: 21 | # from core.config import cfg 22 | cfg = _C 23 | 24 | 25 | # ------------------------------------------------------------------------------------ # 26 | # Model options 27 | # ------------------------------------------------------------------------------------ # 28 | _C.MODEL = CfgNode() 29 | 30 | # Model type 31 | _C.MODEL.TYPE = "" 32 | 33 | # Number of weight layers 34 | _C.MODEL.DEPTH = 0 35 | 36 | # Channels 37 | _C.MODEL.S4_DIM = 2048 38 | _C.MODEL.S3_DIM = 1024 39 | _C.MODEL.S2_DIM = 512 40 | 41 | # ASPP 42 | _C.MODEL.WITH_MA = False 43 | 44 | # Number of classes 45 | _C.MODEL.NUM_CLASSES = 10 46 | 47 | # Loss function (see pycls/models/loss.py for options) 48 | _C.MODEL.LOSSES = CfgNode() 49 | _C.MODEL.LOSSES.NAME = "cross_entropy" 50 | 51 | # ------------------------------------------------------------------------------------ # 52 | # Heads options 53 | # ------------------------------------------------------------------------------------ 54 | _C.MODEL.HEADS = CfgNode() 55 | _C.MODEL.HEADS.NAME = "LinearHead" 56 | # Normalization method for the convolution layers. 
57 | # Number of identity 58 | _C.MODEL.HEADS.NUM_CLASSES = 1000 59 | # Input feature dimension 60 | _C.MODEL.HEADS.IN_FEAT = 2048 61 | # Reduction dimension in head 62 | _C.MODEL.HEADS.REDUCTION_DIM = 512 63 | # Pooling layer type 64 | _C.MODEL.HEADS.POOL_LAYER = "avgpool" 65 | # Classification layer type 66 | _C.MODEL.HEADS.CLS_LAYER = "linear" 67 | # Margin and Scale for margin-based classification layer 68 | _C.MODEL.HEADS.MARGIN = 0.15 69 | _C.MODEL.HEADS.SCALE = 128 70 | 71 | 72 | # ------------------------------------------------------------------------------------ # 73 | # ResNet options 74 | # ------------------------------------------------------------------------------------ # 75 | _C.RESNET = CfgNode() 76 | 77 | # Transformation function (see pycls/models/resnet.py for options) 78 | _C.RESNET.TRANS_FUN = "basic_transform" 79 | 80 | # Number of groups to use (1 -> ResNet; > 1 -> ResNeXt) 81 | _C.RESNET.NUM_GROUPS = 1 82 | 83 | # Width of each group (64 -> ResNet; 4 -> ResNeXt) 84 | _C.RESNET.WIDTH_PER_GROUP = 64 85 | 86 | # Apply stride to 1x1 conv (True -> MSRA; False -> fb.torch) 87 | _C.RESNET.STRIDE_1X1 = True 88 | 89 | # ------------------------------------------------------------------------------------ # 90 | # Batch norm options 91 | # ------------------------------------------------------------------------------------ # 92 | _C.BN = CfgNode() 93 | 94 | # BN epsilon 95 | _C.BN.EPS = 1e-5 96 | 97 | # BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2) 98 | _C.BN.MOM = 0.1 99 | 100 | # Precise BN stats 101 | _C.BN.USE_PRECISE_STATS = False 102 | _C.BN.NUM_SAMPLES_PRECISE = 1024 103 | 104 | # Initialize the gamma of the final BN of each block to zero 105 | _C.BN.ZERO_INIT_FINAL_GAMMA = False 106 | 107 | # Use a different weight decay for BN layers 108 | _C.BN.USE_CUSTOM_WEIGHT_DECAY = False 109 | _C.BN.CUSTOM_WEIGHT_DECAY = 0.0 110 | 111 | 112 | # ------------------------------------------------------------------------------------ # 113 | # Memory options 114 | # ------------------------------------------------------------------------------------ # 115 | _C.MEM = CfgNode() 116 | 117 | # Perform ReLU inplace 118 | _C.MEM.RELU_INPLACE = True -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | /checkpoints 3 | /pretrained_models 4 | .gitignore 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
165 | #.idea/
166 | 
--------------------------------------------------------------------------------
/models/distill_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | 
4 | import torch
5 | import torch.nn as nn
6 | from collections import OrderedDict
7 | 
8 | from .teachers import teacher_models
9 | from .pca_layer import pca_layers
10 | from .resnet import resnet18, resnet34
11 | from .gem_pooling import GeneralizedMeanPooling
12 | 
13 | 
14 | __all__ = [
15 |     "DistillModel",
16 |     "MultiTeacher"
17 | ]
18 | 
19 | archs = {
20 |     "resnet18": resnet18,
21 |     "resnet34": resnet34
22 | }
23 | 
24 | 
25 | class MultiTeacher(nn.Module):
26 |     def __init__(self, path_to_pretrained_weights, *teachers, p=3., embed_dim=512):
27 |         super().__init__()
28 |         self.teachers = teachers
29 |         encoders = list()
30 |         for teacher in self.teachers:
31 |             encoders.append(teacher_models[teacher](path_to_pretrained_weights, pretrained=True, gem_p=p))
32 |         self.encoders = nn.ModuleList(encoders)
33 |         self.embed_dims = [encoder.embed_dim for encoder in self.encoders]
34 | 
35 |         norm_layers = list()
36 |         for teacher in self.teachers:
37 |             norm_layers.append(pca_layers[teacher](path_to_pretrained_weights, embed_dim=embed_dim))
38 |         self.norm_layers = nn.ModuleList(norm_layers)
39 |         self.embed_dims = [norm_layer.dim2 for norm_layer in self.norm_layers]  # final dims after whitening, overriding the raw encoder dims above
40 | 
41 |     def forward(self, x):
42 |         out = list()
43 |         for teacher, encoder in zip(self.teachers, self.encoders):
44 |             if teacher.endswith("delg") or teacher.endswith("dolg"):
45 |                 out.append(encoder(x[:, (2,1,0)]))  # DELG/DOLG teachers expect BGR input, so reverse the RGB channels
46 |             else:
47 |                 out.append(encoder(x))
48 | 
49 |         out = [nn.functional.normalize(o, p=2, dim=-1) for o in out]
50 |         out = [norm_layer(o) for norm_layer, o in zip(self.norm_layers, out)]
51 | 
52 |         return out
53 | 
54 | 
55 | class MultiTeacherDistillModel(nn.Module):
56 |     def __init__(self, args):
57 |         super().__init__()
58 |         self.base_encoder = archs[args.arch](pretrained=True, num_classes=args.embed_dim)
59 |         self.base_encoder.avgpool = GeneralizedMeanPooling(args.p)
60 |         self.base_encoder.fc = nn.Linear(self.base_encoder.embed_dim, args.embed_dim)
61 |         self.embed_dim = self.base_encoder.embed_dim
62 |         self.teacher_encoders = MultiTeacher(args.path_to_pretrained_weights, *args.teachers, p=args.p, embed_dim=args.embed_dim)
63 |         for param in self.teacher_encoders.parameters():
64 |             param.requires_grad = False
65 | 
66 |     def forward(self, x1, x2):
67 |         stu1 = self.base_encoder(x1)
68 |         stu2 = self.base_encoder(x2)
69 | 
70 |         with torch.no_grad():
71 |             tch1 = self.teacher_encoders(x1)
72 |             tch2 = self.teacher_encoders(x2)
73 | 
74 |         return stu1, stu2, tch1, tch2
75 | 
76 |     # teacher models' weights are never trained
77 |     def train(self, mode=True):
78 |         self.teacher_encoders.train(False)
79 |         self.base_encoder.train(mode)
80 |         return self
81 | 
82 |     # override of the built-in state_dict, so checkpoints exclude teacher weights
83 |     def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
84 |         if len(args) > 0:
85 |             if destination is None:
86 |                 destination = args[0]
87 |             if len(args) > 1 and prefix == '':
88 |                 prefix = args[1]
89 |             if len(args) > 2 and keep_vars is False:
90 |                 keep_vars = args[2]
91 | 
92 |         if destination is None:
93 |             destination = OrderedDict()
94 |             destination._metadata = OrderedDict()
95 | 
96 |         local_metadata = dict(version=self._version)
97 |         if hasattr(destination, "_metadata"):
98 |             destination._metadata[prefix[:-1]] = local_metadata
99 | 
100 |         self._save_to_state_dict(destination, prefix, keep_vars)
101 |         for name, module in self._modules.items():
102 |             # only save weights of student model
103 |             if module is not None and name not in ["teacher_encoders"]:
104 |                 module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
105 |         for hook in self._state_dict_hooks.values():
106 |             hook_result = hook(self, destination, prefix, local_metadata)
107 |             if hook_result is not None:
108 |                 destination = hook_result
109 | 
110 |         return destination
111 | 
112 | 
113 | if __name__ == "__main__":
114 |     pass
115 | 
--------------------------------------------------------------------------------
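Because of the `state_dict` override above, saved checkpoints contain only student weights; the frozen teachers are re-instantiated from their own pretrained files at construction time. A quick sanity check (illustrative; running it requires the downloaded teacher weights, and the `args` fields simply mirror what the constructor reads):

```python
import argparse
from models.distill_model import MultiTeacherDistillModel

args = argparse.Namespace(arch="resnet18", embed_dim=512, p=3.,
                          teachers=["mocov3", "barlowtwins"],
                          path_to_pretrained_weights="pretrained_models")
model = MultiTeacherDistillModel(args)
assert all(not k.startswith("teacher_encoders") for k in model.state_dict())
```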
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | 
4 | import os
5 | import yaml
6 | import logging
7 | import shutil
8 | 
9 | import torch
10 | import torch.nn as nn
11 | import torch.distributed as dist
12 | 
13 | 
14 | def load_config(config_file):
15 |     assert os.path.exists(config_file), f"Config file {config_file} not found!"
16 | 
17 |     with open(config_file, "r") as f:
18 |         cfg = yaml.load(f, Loader=yaml.SafeLoader)
19 | 
20 |     return cfg
21 | 
22 | 
23 | def is_dist_avail_and_initialized():
24 |     if not dist.is_available():
25 |         return False
26 |     if not dist.is_initialized():
27 |         return False
28 |     return True
29 | 
30 | 
31 | def get_world_size():
32 |     if not is_dist_avail_and_initialized():
33 |         return 1
34 |     return dist.get_world_size()
35 | 
36 | 
37 | def get_rank():
38 |     if not is_dist_avail_and_initialized():
39 |         return 0
40 |     return dist.get_rank()
41 | 
42 | 
43 | def has_batchnorms(model):
44 |     bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
45 |     for name, module in model.named_modules():
46 |         if isinstance(module, bn_types):
47 |             return True
48 |     return False
49 | 
50 | 
51 | def setup_logger(log_path=None, log_level=logging.INFO):
52 |     logger = logging.root
53 |     logger.setLevel(log_level)
54 | 
55 |     log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
56 |     formatter = logging.Formatter(log_format)
57 | 
58 |     if log_path is not None:
59 |         log_file = os.path.join(log_path, "log.log")
60 |         os.makedirs(log_path, exist_ok=True)
61 |         fh = logging.FileHandler(log_file, mode="w")
62 |         fh.setLevel(log_level)
63 |         fh.setFormatter(formatter)
64 |         logger.addHandler(fh)
65 | 
66 |     ch = logging.StreamHandler()
67 |     ch.setLevel(log_level)
68 |     ch.setFormatter(formatter)
69 |     logger.addHandler(ch)
70 | 
71 |     return
72 | 
73 | 
74 | def save_checkpoint(state, is_best, path, filename='checkpoint.pt'):
75 |     assert path is not None, "Checkpoint save path should not be None type."
76 |     os.makedirs(path, exist_ok=True)
77 |     torch.save(state, os.path.join(path, filename))
78 |     if is_best:
79 |         shutil.copyfile(os.path.join(path, filename), os.path.join(path, 'model_best.pt'))
80 | 
81 | 
82 | class RGB2BGR(object):
83 |     def __call__(self, x):
84 |         return x[(2,1,0),]
85 | 
86 | 
87 | class PCA(nn.Module):
88 |     """
89 |     Class to compute and apply PCA-whitening.
90 |     """
91 |     def __init__(self, dim1=0, dim2=0, whit=0.5):
92 |         super().__init__()
93 |         self.register_buffer("dim1", torch.tensor(dim1, dtype=torch.long))
94 |         self.register_buffer("dim2", torch.tensor(dim2, dtype=torch.long))
95 |         self.register_buffer("whit", torch.tensor(whit, dtype=torch.float32))
96 |         self.register_buffer("d", torch.zeros(self.dim1, dtype=torch.float32))
97 |         self.register_buffer("v", torch.zeros(self.dim1, self.dim1, dtype=torch.float32))
98 |         self.register_buffer("n_0", torch.tensor(0, dtype=torch.long))
99 |         self.register_buffer("mean", torch.zeros(1, self.dim1, dtype=torch.float32))
100 |         self.register_buffer("dvt", torch.zeros(self.dim2, self.dim1, dtype=torch.float32))
101 | 
102 |     def train_pca(self, x):
103 |         """
104 |         Takes an N x D feature matrix (torch.Tensor) as input and fits the
105 |         whitened principal components; the covariance matrix is built
106 |         internally. Note that x is centered in place.
107 |         """
108 |         x_mean = x.mean(dim=0, keepdim=True)
109 |         self.mean = x_mean
110 |         x -= x_mean
111 |         cov = x.t().mm(x) / x.size(0)
112 | 
113 |         d, v = torch.linalg.eigh(cov)
114 | 
115 |         self.d.copy_(d)
116 |         self.v.copy_(v)
117 | 
118 |         eps = d.max() * 1e-5
119 |         n_0 = (d < eps).sum()
120 |         if n_0 > 0:
121 |             d[d < eps] = eps
122 | 
123 |         self.n_0 = n_0
124 | 
125 |         # total energy
126 |         totenergy = d.sum()
127 | 
128 |         # sort eigenvectors with eigenvalues order
129 |         idx = torch.argsort(d, descending=True)[:self.dim2]
130 |         d = d[idx]
131 |         v = v[:, idx]
132 | 
133 |         logger = logging.getLogger("pca")
134 |         logger.info(f"keeping {d.sum() / totenergy * 100.0:.2f} % of the energy")
135 | 
136 |         # for the whitening
137 |         d = torch.diag(1. / d**self.whit)
138 | 
139 |         # principal components
140 |         self.dvt = d @ v.T
141 | 
142 |     def forward(self, x):
143 |         x -= self.mean  # in-place mean subtraction: clone x beforehand if it is reused
144 |         return torch.mm(self.dvt, x.transpose(0, 1)).transpose(0, 1)
--------------------------------------------------------------------------------
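A self-contained sketch of fitting and applying the `PCA` module above (illustrative; note that both `train_pca` and `forward` modify their input in place, hence the `clone()`):

```python
import torch
from utils import PCA

feats = torch.randn(1000, 2048)      # N x D features (e.g. L2-normalized teacher embeddings)
pca = PCA(dim1=2048, dim2=512)       # learn a 2048-d -> 512-d whitening projection
pca.train_pca(feats.clone())         # centers its input in place, hence the clone
out = pca(torch.randn(8, 2048))      # -> 8 x 512 whitened descriptors
assert out.shape == (8, 512)
```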
/models/solar/networks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import time
4 | import numpy as np
5 | 
6 | from collections import OrderedDict
7 | 
8 | import torch
9 | import torch.nn as nn
10 | from torchvision import models
11 | 
12 | 
13 | ####################################################################################################
14 | ########################################## Functions ###############################################
15 | ####################################################################################################
16 | 
17 | 
18 | ## Kaiming weight initialisation
19 | def weights_init(module):
20 |     if isinstance(module, nn.ReLU):
21 |         pass
22 |     if isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d):
23 |         nn.init.kaiming_normal_(module.weight.data)
24 |         nn.init.constant_(module.bias.data, 0.0)
25 |     elif isinstance(module, nn.BatchNorm2d):
26 |         pass
27 |         #nn.init.kaiming_normal_(module.weight.data)
28 |         #nn.init.constant_(module.bias.data, 0.0)
29 | 
30 | def constant_init(module):
31 |     if isinstance(module, nn.ReLU):
32 |         pass
33 |     if isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d):
34 |         nn.init.constant_(module.weight.data, 0.0)
35 |         nn.init.constant_(module.bias.data, 0.0)
36 |     elif isinstance(module, nn.BatchNorm2d):
37 |         pass
38 |         #nn.init.kaiming_normal_(module.weight.data)
39 | 
40 | 
41 | 
42 | ####################################################################################################
43 | ########################################## Networks ###############################################
####################################################################################################
45 | 
46 | class SOABlock(nn.Module):
47 |     def __init__(self, in_ch, k):
48 |         super(SOABlock, self).__init__()
49 | 
50 |         self.in_ch = in_ch
51 |         self.out_ch = in_ch
52 |         self.mid_ch = in_ch // k
53 | 
54 |         self.f = nn.Sequential(
55 |             nn.Conv2d(self.in_ch, self.mid_ch, (1, 1), (1, 1)),
56 |             nn.BatchNorm2d(self.mid_ch),
57 |             nn.ReLU())
58 |         self.g = nn.Sequential(
59 |             nn.Conv2d(self.in_ch, self.mid_ch, (1, 1), (1, 1)),
60 |             nn.BatchNorm2d(self.mid_ch),
61 |             nn.ReLU())
62 |         self.h = nn.Conv2d(self.in_ch, self.mid_ch, (1, 1), (1, 1))
63 |         self.v = nn.Conv2d(self.mid_ch, self.out_ch, (1, 1), (1, 1))
64 | 
65 |         self.softmax = nn.Softmax(dim=-1)
66 | 
67 |         for conv in [self.f, self.g, self.h]:  #, self.v]:
68 |             conv.apply(weights_init)
69 | 
70 |         self.v.apply(constant_init)
71 | 
72 | 
73 |     def forward(self, x, vis_mode=False):
74 |         B, C, H, W = x.shape
75 | 
76 |         f_x = self.f(x).view(B, self.mid_ch, H * W)  # B * mid_ch * N, where N = H*W
77 |         g_x = self.g(x).view(B, self.mid_ch, H * W)  # B * mid_ch * N, where N = H*W
78 |         h_x = self.h(x).view(B, self.mid_ch, H * W)  # B * mid_ch * N, where N = H*W
79 | 
80 |         z = torch.bmm(f_x.permute(0, 2, 1), g_x)  # B * N * N, where N = H*W
81 | 
82 |         if vis_mode:
83 |             # for visualisation only
84 |             attn = self.softmax((self.mid_ch ** -.75) * z)
85 |         else:
86 |             attn = self.softmax((self.mid_ch ** -.50) * z)
87 | 
88 |         z = torch.bmm(attn, h_x.permute(0, 2, 1))  # B * N * mid_ch, where N = H*W
89 |         z = z.permute(0, 2, 1).view(B, self.mid_ch, H, W)  # B * mid_ch * H * W
90 | 
91 |         z = self.v(z)
92 |         z = z + x
93 | 
94 |         return z, attn
95 | 
96 | 
97 | class ResNetSOAs(nn.Module):
98 |     def __init__(self, architecture='resnet101', soa_layers='45'):
99 |         super(ResNetSOAs, self).__init__()
100 | 
101 |         base_model = vars(models)[architecture](pretrained=False)
102 |         last_feat_in = base_model.inplanes
103 |         base_model = nn.Sequential(*list(base_model.children())[:-2])
104 | 
105 |         res_blocks = list(base_model.children())
106 | 
107 |         self.conv1 = nn.Sequential(*res_blocks[0:2])
108 |         self.conv2_x = nn.Sequential(*res_blocks[2:5])
109 |         self.conv3_x = res_blocks[5]
110 |         self.conv4_x = res_blocks[6]
111 |         self.conv5_x = res_blocks[7]
112 | 
113 |         self.soa_layers = soa_layers
114 |         if '4' in self.soa_layers:
115 |             self.soa4 = SOABlock(in_ch=last_feat_in // 2, k=4)
116 |         if '5' in self.soa_layers:
117 |             self.soa5 = SOABlock(in_ch=last_feat_in, k=2)
118 | 
119 |         self.embed_dim = last_feat_in
120 | 
121 |     def forward(self, x, mode='test'):
122 |         with torch.no_grad():
123 |             x = self.conv1(x)
124 |             x = self.conv2_x(x)
125 |             x = self.conv3_x(x)
126 |             x = self.conv4_x(x)
127 | 
128 |         # start SOA blocks
129 |         if '4' in self.soa_layers:
130 |             x, soa_m2 = self.soa4(x, mode == 'draw')
131 | 
132 |         x = self.conv5_x(x)
133 |         if '5' in self.soa_layers:
134 |             x, soa_m1 = self.soa5(x, mode == 'draw')
135 | 
136 |         if mode == 'draw':
137 |             return x, soa_m2, soa_m1
138 | 
139 |         return x
140 | 
--------------------------------------------------------------------------------
/models/delg/net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | """Functions for manipulating networks."""
9 | 
10 | import itertools
11 | import math
12 | 
13 | import torch
14 | import torch.nn as nn
15 | from .config import cfg
16 | 
17 | 
18 | def init_weights(m):
19 |     """Performs ResNet-style weight initialization."""
20 |     if isinstance(m, nn.Conv2d):
21 |         # Note that there is no bias due to BN
22 |         fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
23 |         m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
24 |     elif isinstance(m, nn.BatchNorm2d):
25 |         zero_init_gamma = cfg.BN.ZERO_INIT_FINAL_GAMMA
26 |         zero_init_gamma = hasattr(m, "final_bn") and m.final_bn and zero_init_gamma
27 |         m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
28 |         m.bias.data.zero_()
29 |     elif isinstance(m, nn.Linear):
30 |         m.weight.data.normal_(mean=0.0, std=0.01)
31 |         m.bias.data.zero_()
32 | 
33 | 
34 | def init_weights_classifier(m):
35 |     classname = m.__class__.__name__
36 |     if classname.find('Linear') != -1:
37 |         nn.init.normal_(m.weight, std=0.001)
38 |         if m.bias is not None:
39 |             nn.init.constant_(m.bias, 0.0)
40 |     elif classname.find("Arcface") != -1 or classname.find("Circle") != -1:
41 |         nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
42 | 
43 | 
44 | @torch.no_grad()
45 | def compute_precise_bn_stats(model, loader):
46 |     """Computes precise BN stats on training data."""
47 |     # Compute the number of minibatches to use
48 |     num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader))
49 |     # Retrieve the BN layers
50 |     bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
51 |     # Initialize stats storage
52 |     mus = [torch.zeros_like(bn.running_mean) for bn in bns]
53 |     sqs = [torch.zeros_like(bn.running_var) for bn in bns]
54 |     # Remember momentum values
55 |     moms = [bn.momentum for bn in bns]
56 |     # Disable momentum
57 |     for bn in bns:
58 |         bn.momentum = 1.0
59 |     # Accumulate the stats across the data samples
60 |     for inputs, _labels in itertools.islice(loader, num_iter):
61 |         model(inputs.cuda())
62 |         # Accumulate the stats for each BN layer
63 |         for i, bn in enumerate(bns):
64 |             m, v = bn.running_mean, bn.running_var
65 |             sqs[i] += (v + m * m) / num_iter
66 |             mus[i] += m / num_iter
67 |     # Set the stats and restore momentum values
68 |     for i, bn in enumerate(bns):
69 |         bn.running_var = sqs[i] - mus[i] * mus[i]
70 |         bn.running_mean = mus[i]
71 |         bn.momentum = moms[i]
72 | 
73 | 
74 | def reset_bn_stats(model):
75 |     """Resets running BN stats."""
76 |     for m in model.modules():
77 |         if isinstance(m, torch.nn.BatchNorm2d):
78 |             m.reset_running_stats()
79 | 
80 | 
81 | def complexity_conv2d(cx, w_in, w_out, k, stride, padding, groups=1, bias=False):
82 |     """Accumulates complexity of Conv2D into cx = (h, w, flops, params, acts)."""
83 |     h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
84 |     h = (h + 2 * padding - k) // stride + 1
85 |     w = (w + 2 * padding - k) // stride + 1
86 |     flops += k * k * w_in * w_out * h * w // groups
87 |     params += k * k * w_in * w_out // groups
88 |     flops += w_out if bias else 0
89 |     params += w_out if bias else 0
90 |     acts += w_out * h * w
91 |     return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
92 | 
93 | 
94 | def complexity_batchnorm2d(cx, w_in):
95 |     """Accumulates complexity of BatchNorm2D into cx = (h, w, flops, params, acts)."""
96 |     h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
97 |     params += 2 * w_in
98 |     return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
99 | 
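# Worked example of the complexity bookkeeping above (illustrative, not part of the
# original file): a 3x3 conv with stride 1 and padding 1 mapping 64 -> 64 channels on a
# 56x56 input keeps h = w = (56 + 2*1 - 3) // 1 + 1 = 56 and accumulates
# 3*3*64*64 = 36,864 params and 3*3*64*64*56*56 = 115,605,504 flops (groups=1, bias=False).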
100 | 
101 | def complexity_maxpool2d(cx, k, stride, padding):
102 |     """Accumulates complexity of MaxPool2d into cx = (h, w, flops, params, acts)."""
103 |     h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
104 |     h = (h + 2 * padding - k) // stride + 1
105 |     w = (w + 2 * padding - k) // stride + 1
106 |     return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
107 | 
108 | 
109 | def complexity(model):
110 |     """Compute model complexity (model can be model instance or model class)."""
111 |     size = cfg.TRAIN.IM_SIZE
112 |     cx = {"h": size, "w": size, "flops": 0, "params": 0, "acts": 0}
113 |     cx = model.complexity(cx)
114 |     return {"flops": cx["flops"], "params": cx["params"], "acts": cx["acts"]}
115 | 
116 | 
117 | def drop_connect(x, drop_ratio):
118 |     """Drop connect (adapted from DARTS)."""
119 |     keep_ratio = 1.0 - drop_ratio
120 |     mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
121 |     mask.bernoulli_(keep_ratio)
122 |     x.div_(keep_ratio)
123 |     x.mul_(mask)
124 |     return x
125 | 
126 | 
127 | def get_flat_weights(model):
128 |     """Gets all model weights as a single flat vector."""
129 |     return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0)
130 | 
131 | 
132 | def set_flat_weights(model, flat_weights):
133 |     """Sets all model weights from a single flat vector."""
134 |     k = 0
135 |     for p in model.parameters():
136 |         n = p.data.numel()
137 |         p.data.copy_(flat_weights[k : (k + n)].view_as(p.data))
138 |         k += n
139 |     assert k == flat_weights.numel()
140 | 
141 | 
142 | def freeze_weights(model, freeze=[]):
143 |     for name, child in model.module.named_children():
144 |         if name in freeze:
145 |             for param in child.parameters():
146 |                 param.requires_grad = False
147 | 
148 | 
149 | def unfreeze_weights(model, freeze=[]):
150 |     for name, child in model.module.named_children():
151 |         if name in freeze:
152 |             for param in child.parameters():
153 |                 param.requires_grad = True
154 | 
--------------------------------------------------------------------------------
/models/dolg/net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 8 | """Functions for manipulating networks.""" 9 | 10 | import itertools 11 | import math 12 | 13 | import torch 14 | import torch.nn as nn 15 | from .config import cfg 16 | 17 | 18 | def init_weights(m): 19 | """Performs ResNet-style weight initialization.""" 20 | if isinstance(m, nn.Conv2d): 21 | # Note that there is no bias due to BN 22 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 23 | m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out)) 24 | elif isinstance(m, nn.BatchNorm2d): 25 | zero_init_gamma = cfg.BN.ZERO_INIT_FINAL_GAMMA 26 | zero_init_gamma = hasattr(m, "final_bn") and m.final_bn and zero_init_gamma 27 | m.weight.data.fill_(0.0 if zero_init_gamma else 1.0) 28 | m.bias.data.zero_() 29 | elif isinstance(m, nn.Linear): 30 | m.weight.data.normal_(mean=0.0, std=0.01) 31 | m.bias.data.zero_() 32 | 33 | 34 | def init_weights_classifier(m): 35 | classname = m.__class__.__name__ 36 | if classname.find('Linear') != -1: 37 | nn.init.normal_(m.weight, std=0.001) 38 | if m.bias is not None: 39 | nn.init.constant_(m.bias, 0.0) 40 | elif classname.find("Arcface") != -1 or classname.find("Circle") != -1: 41 | nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5)) 42 | 43 | 44 | @torch.no_grad() 45 | def compute_precise_bn_stats(model, loader): 46 | """Computes precise BN stats on training data.""" 47 | # Compute the number of minibatches to use 48 | num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader)) 49 | # Retrieve the BN layers 50 | bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)] 51 | # Initialize stats storage 52 | mus = [torch.zeros_like(bn.running_mean) for bn in bns] 53 | sqs = [torch.zeros_like(bn.running_var) for bn in bns] 54 | # Remember momentum values 55 | moms = [bn.momentum for bn in bns] 56 | # Disable momentum 57 | for bn in bns: 58 | bn.momentum = 1.0 59 | # Accumulate the stats across the data samples 60 | for inputs, _labels in itertools.islice(loader, num_iter): 61 | model(inputs.cuda()) 62 | # Accumulate the stats for each BN layer 63 | for i, bn in enumerate(bns): 64 | m, v = bn.running_mean, bn.running_var 65 | sqs[i] += (v + m * m) / num_iter 66 | mus[i] += m / num_iter 67 | # Set the stats and restore momentum values 68 | for i, bn in enumerate(bns): 69 | bn.running_var = sqs[i] - mus[i] * mus[i] 70 | bn.running_mean = mus[i] 71 | bn.momentum = moms[i] 72 | 73 | 74 | def reset_bn_stats(model): 75 | """Resets running BN stats.""" 76 | for m in model.modules(): 77 | if isinstance(m, torch.nn.BatchNorm2d): 78 | m.reset_running_stats() 79 | 80 | 81 | def complexity_conv2d(cx, w_in, w_out, k, stride, padding, groups=1, bias=False): 82 | """Accumulates complexity of Conv2D into cx = (h, w, flops, params, acts).""" 83 | h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"] 84 | h = (h + 2 * padding - k) // stride + 1 85 | w = (w + 2 * padding - k) // stride + 1 86 | flops += k * k * w_in * w_out * h * w // groups 87 | params += k * k * w_in * w_out // groups 88 | flops += w_out if bias else 0 89 | params += w_out if bias else 0 90 | acts += w_out * h * w 91 | return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts} 92 | 93 | 94 | def complexity_batchnorm2d(cx, w_in): 95 | """Accumulates complexity of BatchNorm2D into cx = (h, w, flops, params, acts).""" 96 | h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"] 97 | params += 2 * w_in 98 | return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts} 99 | 100 | 
101 | def complexity_maxpool2d(cx, k, stride, padding):
102 |     """Accumulates complexity of MaxPool2d into cx = (h, w, flops, params, acts)."""
103 |     h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
104 |     h = (h + 2 * padding - k) // stride + 1
105 |     w = (w + 2 * padding - k) // stride + 1
106 |     return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
107 | 
108 | 
109 | def complexity(model):
110 |     """Compute model complexity (model can be model instance or model class)."""
111 |     size = cfg.TRAIN.IM_SIZE
112 |     cx = {"h": size, "w": size, "flops": 0, "params": 0, "acts": 0}
113 |     cx = model.complexity(cx)
114 |     return {"flops": cx["flops"], "params": cx["params"], "acts": cx["acts"]}
115 | 
116 | 
117 | def drop_connect(x, drop_ratio):
118 |     """Drop connect (adapted from DARTS)."""
119 |     keep_ratio = 1.0 - drop_ratio
120 |     mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
121 |     mask.bernoulli_(keep_ratio)
122 |     x.div_(keep_ratio)
123 |     x.mul_(mask)
124 |     return x
125 | 
126 | 
127 | def get_flat_weights(model):
128 |     """Gets all model weights as a single flat vector."""
129 |     return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0)
130 | 
131 | 
132 | def set_flat_weights(model, flat_weights):
133 |     """Sets all model weights from a single flat vector."""
134 |     k = 0
135 |     for p in model.parameters():
136 |         n = p.data.numel()
137 |         p.data.copy_(flat_weights[k : (k + n)].view_as(p.data))
138 |         k += n
139 |     assert k == flat_weights.numel()
140 | 
141 | 
142 | def freeze_weights(model, freeze=[]):
143 |     for name, child in model.module.named_children():
144 |         if name in freeze:
145 |             for param in child.parameters():
146 |                 param.requires_grad = False
147 | 
148 | 
149 | def unfreeze_weights(model, freeze=[]):
150 |     for name, child in model.module.named_children():
151 |         if name in freeze:
152 |             for param in child.parameters():
153 |                 param.requires_grad = True
154 | 
--------------------------------------------------------------------------------
/datasets/oxford_paris.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pickle
4 | import numpy as np
5 | from PIL import Image, ImageFile
6 | 
7 | 
8 | class OxfordParisDataset(torch.utils.data.Dataset):
9 |     def __init__(self, dir_main, dataset, split, transform=None, imsize=None):
10 |         if dataset not in ['roxford5k', 'rparis6k', 'revisitop1m']:
11 |             raise ValueError('Unknown dataset: {}!'.format(dataset))
12 | 
13 |         if dataset == 'roxford5k' or dataset == 'rparis6k':
14 |             # loading imlist, qimlist, and gnd, in cfg as a dict
15 |             gnd_fname = os.path.join(dir_main, dataset, 'gnd_{}.pkl'.format(dataset))
16 |             with open(gnd_fname, 'rb') as f:
17 |                 cfg = pickle.load(f)
18 |             cfg['gnd_fname'] = gnd_fname
19 |             cfg['ext'] = '.jpg'
20 |             cfg['qext'] = '.jpg'
21 |         elif dataset == 'revisitop1m':
22 |             # loading imlist from a .txt file
23 |             cfg = {}
24 |             cfg['imlist_fname'] = os.path.join(dir_main, dataset, '{}.txt'.format(dataset))
25 |             cfg['imlist'] = read_imlist(cfg['imlist_fname'])
26 |             cfg['qimlist'] = []
27 |             cfg['ext'] = ''
28 |             cfg['qext'] = ''
29 | 
30 | 
31 |         cfg['dir_data'] = os.path.join(dir_main, dataset)
32 |         cfg['dir_images'] = os.path.join(cfg['dir_data'], 'jpg')
33 |         cfg['n'] = len(cfg['imlist'])
34 |         cfg['nq'] = len(cfg['qimlist'])
35 |         cfg['im_fname'] = config_imname
36 |         cfg['qim_fname'] = config_qimname
37 |         cfg['dataset'] = dataset
38 |         self.cfg = cfg
39 | 
40 |         self.n_samples = len(cfg["qimlist"]) if split == "query" else len(cfg["imlist"])
41 |         self.transform = transform
42 |         self.imsize = imsize
43 |         self.split = split
44 | 
45 |     def __len__(self):
46 |         return self.n_samples
47 | 
48 |     def __getitem__(self, index):
49 |         if self.split == "query":
50 |             path = config_qimname(self.cfg, index)
51 |         else:
52 |             path = config_imname(self.cfg, index)
53 |         ImageFile.LOAD_TRUNCATED_IMAGES = True
54 |         with open(path, 'rb') as f:
55 |             img = Image.open(f)
56 |             img = img.convert('RGB')
57 |         if self.imsize is not None:
58 |             img.thumbnail((self.imsize, self.imsize), Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10
59 |         if self.transform is not None:
60 |             img = self.transform(img)
61 |         return img, index
62 | 
63 | 
64 | def config_imname(cfg, i):
65 |     return os.path.join(cfg['dir_images'], cfg['imlist'][i] + cfg['ext'])
66 | 
67 | 
68 | def config_qimname(cfg, i):
69 |     return os.path.join(cfg['dir_images'], cfg['qimlist'][i] + cfg['qext'])
70 | 
71 | 
72 | def read_imlist(imlist_fn):
73 |     with open(imlist_fn, "r") as file:
74 |         imlist = file.read().splitlines()
75 |     return imlist
76 | 
77 | 
78 | def compute_ap(ranks, nres):
79 |     """
80 |     Computes average precision for given ranked indexes.
81 | 
82 |     Arguments
83 |     ---------
84 |     ranks : zero-based ranks of positive images
85 |     nres : number of positive images
86 | 
87 |     Returns
88 |     -------
89 |     ap : average precision
90 |     """
91 | 
92 |     # number of images ranked by the system
93 |     nimgranks = len(ranks)
94 | 
95 |     # accumulate trapezoids in PR-plot
96 |     ap = 0
97 | 
98 |     recall_step = 1. / nres
99 | 
100 |     for j in np.arange(nimgranks):
101 |         rank = ranks[j]
102 | 
103 |         if rank == 0:
104 |             precision_0 = 1.
105 |         else:
106 |             precision_0 = float(j) / rank
107 | 
108 |         precision_1 = float(j + 1) / (rank + 1)
109 | 
110 |         ap += (precision_0 + precision_1) * recall_step / 2.
111 | 
112 |     return ap
113 | 
114 | 
115 | def compute_map(ranks, gnd, kappas=[]):
116 |     """
117 |     Computes the mAP for a given set of returned results.
118 |     Usage:
119 |         map = compute_map(ranks, gnd)
120 |             computes mean average precision (map) only
121 | 
122 |         map, aps, pr, prs = compute_map(ranks, gnd, kappas)
123 |             computes mean average precision (map), average precision (aps) for each query
124 |             computes mean precision at kappas (pr), precision at kappas (prs) for each query
125 | 
126 |     Notes:
127 |     1) ranks starts from 0, ranks.shape = db_size X #queries
128 |     2) The junk results (e.g., the query itself) should be declared in the gnd struct array
129 |     3) If there are no positive images for some query, that query is excluded from the evaluation
130 |     """
131 | 
132 |     map = 0.
133 |     nq = len(gnd)  # number of queries
134 |     aps = np.zeros(nq)
135 |     pr = np.zeros(len(kappas))
136 |     prs = np.zeros((nq, len(kappas)))
137 |     nempty = 0
138 | 
139 |     for i in np.arange(nq):
140 |         qgnd = np.array(gnd[i]['ok'])
141 | 
142 |         # no positive images, skip from the average
143 |         if qgnd.shape[0] == 0:
144 |             aps[i] = float('nan')
145 |             prs[i, :] = float('nan')
146 |             nempty += 1
147 |             continue
148 | 
149 |         try:
150 |             qgndj = np.array(gnd[i]['junk'])
151 |         except KeyError:
152 |             qgndj = np.empty(0)
153 | 
154 |         # sorted positions of positive and junk images (0 based)
155 |         pos = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgnd)]
156 |         junk = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgndj)]
157 | 
158 |         k = 0
159 |         ij = 0
160 |         if len(junk):
161 |             # decrease positions of positives based on the number of
162 |             # junk images appearing before them
163 |             ip = 0
164 |             while (ip < len(pos)):
165 |                 while (ij < len(junk) and pos[ip] > junk[ij]):
166 |                     k += 1
167 |                     ij += 1
168 |                 pos[ip] = pos[ip] - k
169 |                 ip += 1
170 | 
171 |         # compute ap
172 |         ap = compute_ap(pos, len(qgnd))
173 |         map = map + ap
174 |         aps[i] = ap
175 | 
176 |         # compute precision @ k
177 |         pos += 1  # get it to 1-based
178 |         for j in np.arange(len(kappas)):
179 |             kq = min(max(pos), kappas[j])
180 |             prs[i, j] = (pos <= kq).sum() / kq
181 |         pr = pr + prs[i, :]
182 | 
183 |     map = map / (nq - nempty)
184 |     pr = pr / (nq - nempty)
185 | 
186 |     return map, aps, pr, prs
187 | 
188 | 
189 | if __name__ == "__main__":
190 |     dataset = OxfordParisDataset(
191 |         "/home/zju/jieche.mz/revisitop/data/datasets",
192 |         "rparis6k", "query", transform=None, imsize=1024
193 |     )
194 |     print(dataset.cfg["gnd"][0]["bbx"])
--------------------------------------------------------------------------------
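`compute_ap` integrates the precision-recall curve with one trapezoid per positive. A small worked example (illustrative):

```python
import numpy as np
from datasets.oxford_paris import compute_ap

# two relevant images, retrieved at zero-based ranks 0 and 2
ap = compute_ap(np.array([0, 2]), 2)
# rank 0: (1 + 1/1) / 2 * 0.5 = 0.5 ; rank 2: (1/2 + 2/3) / 2 * 0.5 ~= 0.2917
print(round(ap, 4))  # 0.7917
```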
115 | def compute_map(ranks, gnd, kappas=[]):
116 |     """
117 |     Computes the mAP for a given set of returned results.
118 |     Usage:
119 |       map = compute_map (ranks, gnd)
120 |             computes mean average precision (map) only
121 | 
122 |       map, aps, pr, prs = compute_map (ranks, gnd, kappas)
123 |             computes mean average precision (map), average precision (aps) for each query
124 |             computes mean precision at kappas (pr), precision at kappas (prs) for each query
125 | 
126 |     Notes:
127 |     1) ranks starts from 0, ranks.shape = db_size X #queries
128 |     2) The junk results (e.g., the query itself) should be declared in the gnd struct array
129 |     3) If there are no positive images for some query, that query is excluded from the evaluation
130 |     """
131 | 
132 |     map = 0.
133 |     nq = len(gnd)  # number of queries
134 |     aps = np.zeros(nq)
135 |     pr = np.zeros(len(kappas))
136 |     prs = np.zeros((nq, len(kappas)))
137 |     nempty = 0
138 | 
139 |     for i in np.arange(nq):
140 |         qgnd = np.array(gnd[i]['ok'])
141 | 
142 |         # no positive images, skip from the average
143 |         if qgnd.shape[0] == 0:
144 |             aps[i] = float('nan')
145 |             prs[i, :] = float('nan')
146 |             nempty += 1
147 |             continue
148 | 
149 |         try:
150 |             qgndj = np.array(gnd[i]['junk'])
151 |         except KeyError:
152 |             qgndj = np.empty(0)
153 | 
154 |         # sorted positions of positive and junk images (0 based)
155 |         pos = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgnd)]
156 |         junk = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgndj)]
157 | 
158 |         k = 0
159 |         ij = 0
160 |         if len(junk):
161 |             # decrease positions of positives based on the number of
162 |             # junk images appearing before them
163 |             ip = 0
164 |             while (ip < len(pos)):
165 |                 while (ij < len(junk) and pos[ip] > junk[ij]):
166 |                     k += 1
167 |                     ij += 1
168 |                 pos[ip] = pos[ip] - k
169 |                 ip += 1
170 | 
171 |         # compute ap
172 |         ap = compute_ap(pos, len(qgnd))
173 |         map = map + ap
174 |         aps[i] = ap
175 | 
176 |         # compute precision @ k
177 |         pos += 1  # get it to 1-based
178 |         for j in np.arange(len(kappas)):
179 |             kq = min(max(pos), kappas[j])
180 |             prs[i, j] = (pos <= kq).sum() / kq
181 |         pr = pr + prs[i, :]
182 | 
183 |     map = map / (nq - nempty)
184 |     pr = pr / (nq - nempty)
185 | 
186 |     return map, aps, pr, prs
187 | 
188 | 
189 | if __name__ == "__main__":
190 |     dataset = OxfordParisDataset(
191 |         "/home/zju/jieche.mz/revisitop/data/datasets",
192 |         "rparis6k", "query", transform=None, imsize=1024
193 |     )
194 |     print(dataset.cfg["gnd"][0]["bbx"])
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | 
4 | import torch
5 | import numpy as np
6 | import torch.nn as nn
7 | import torch.distributed as dist
8 | 
9 | from utils import get_world_size, get_rank
10 | 
11 | 
12 | class MultiTeacherDistillLoss(nn.Module):
13 |     def __init__(self, st=0.05, tt=0.05, s=None, teachers=None):
14 |         """
15 |         st: student temperature
16 |         tt: teacher temperature
17 |         s: strategy
18 |         teachers: list of teacher models
19 |         """
20 |         super().__init__()
21 |         self.st = st
22 |         self.tt = tt
23 |         assert s is not None and teachers is not None
24 |         self.s = s
25 |         self.teachers = teachers
26 |         self.distill_loss_fn = nn.KLDivLoss(reduction="batchmean")
27 |         self.register_buffer("pos_win_count", torch.zeros(len(self.teachers), dtype=torch.long))
28 | 
29 |     def forward(self, stu1, stu2, tch1, tch2):
30 |         """
31 |         stu1: B x D, representations of one view
32 |         stu2: B x D, representations of another view
33 |         tch1: list of B x D, representations of one view of all teachers
34 |         tch2: list of B x D, representations of another view of all teachers
35 |         """
36 |         stu1 = nn.functional.normalize(stu1, p=2, dim=-1)
37 |         stu2 = nn.functional.normalize(stu2, p=2, dim=-1)
38 |         tch1 = [nn.functional.normalize(t1, p=2, dim=-1) for t1 in tch1]
39 |         tch2 = [nn.functional.normalize(t2, p=2, dim=-1) for t2 in tch2]
40 | 
41 |         # gather features from other devices without loss of gradients
42 |         stu1 = self.gather_with_grad(stu1)
43 |         stu2 = self.gather_with_grad(stu2)
44 |         tch1 = [self.gather_with_grad(t1.contiguous()) for t1 in tch1]
45 |         tch2 = [self.gather_with_grad(t2.contiguous()) for t2 in tch2]
46 | 
47 |         stu_sim_mat = stu1.mm(stu2.t())
48 | 
49 |         tch_sim_mats = 
[t1.mm(t2.t()) for t1, t2 in zip(tch1, tch2)] 50 | 51 | distill_loss = \ 52 | self.get_distill_loss( 53 | stu_sim_mat, 54 | tch_sim_mats, 55 | tt=self.tt, 56 | st=self.st 57 | ) + \ 58 | self.get_distill_loss( 59 | stu_sim_mat.t(), 60 | [tch_sim_mat.t() for tch_sim_mat in tch_sim_mats], 61 | tt=self.tt, 62 | st=self.st 63 | ) 64 | 65 | # scale loss value due to the gradient mean reduction mechanism in distributed training 66 | # see https://pytorch.org/docs/master/notes/ddp.html 67 | distill_loss *= get_world_size() 68 | 69 | return {"distill loss": distill_loss} 70 | 71 | def get_distill_loss(self, stu_sim, tch_sims, tt=0.05, st=0.05): 72 | # T x N 73 | tch_sims = torch.stack(tch_sims, dim=0) 74 | 75 | if self.s == "mean": 76 | tch_sim = self.s_mean(tch_sims) 77 | elif self.s == "maxmin": 78 | tch_sim = self.s_maxmin(tch_sims) 79 | elif self.s == "maxmean": 80 | tch_sim = self.s_maxmean(tch_sims) 81 | elif self.s == "maxrand": 82 | tch_sim = self.s_maxrand(tch_sims) 83 | elif self.s == "rand": 84 | tch_sim = self.s_rand(tch_sims) 85 | 86 | t = nn.functional.softmax(tch_sim.div(tt), dim=-1) 87 | 88 | s = nn.functional.log_softmax(stu_sim.div(st), dim=-1) 89 | 90 | return self.distill_loss_fn(s, t) 91 | 92 | def s_mean(self, sims): 93 | return sims.mean(dim=0) 94 | 95 | def s_maxmin(self, sims): 96 | sim_diag, max_indices = sims.max(dim=0) 97 | max_indices = max_indices.diagonal() 98 | for i in range(self.pos_win_count.size(0)): 99 | self.pos_win_count[i] += (max_indices == i).sum() 100 | 101 | sim_off_diag = sims.min(dim=0)[0] 102 | mask = torch.eye(sims.size(1), m=sims.size(2), dtype=torch.bool, device=sims.device) 103 | fusion_sim = sim_diag * mask + sim_off_diag * mask.logical_not() 104 | return fusion_sim 105 | 106 | def s_maxmean(self, sims): 107 | sim_diag, max_indices = sims.max(dim=0) 108 | max_indices = max_indices.diagonal() 109 | for i in range(self.pos_win_count.size(0)): 110 | self.pos_win_count[i] += (max_indices == i).sum() 111 | 112 | sim_off_diag = sims.mean(dim=0) 113 | mask = torch.eye(sims.size(1), m=sims.size(2), dtype=torch.bool, device=sims.device) 114 | fusion_sim = sim_diag * mask + sim_off_diag * mask.logical_not() 115 | return fusion_sim 116 | 117 | def s_maxfix(self, sims): 118 | sim_diag, max_indices = sims.max(dim=0) 119 | max_indices = max_indices.diagonal() 120 | for i in range(self.pos_win_count.size(0)): 121 | self.pos_win_count[i] += (max_indices == i).sum() 122 | 123 | sim_off_diag = sims[-1] 124 | mask = torch.eye(sims.size(1), m=sims.size(2), dtype=torch.bool, device=sims.device) 125 | fusion_sim = sim_diag * mask + sim_off_diag * mask.logical_not() 126 | return fusion_sim 127 | 128 | def s_maxrand(self, sims): 129 | sim_diag, max_indices = sims.max(dim=0) 130 | max_indices = max_indices.diagonal() 131 | for i in range(self.pos_win_count.size(0)): 132 | self.pos_win_count[i] += (max_indices == i).sum() 133 | 134 | mask = torch.randint(0, len(sims), sims[0].size(), device=sims.device) 135 | mask = nn.functional.one_hot(mask, num_classes=len(sims)).permute(2, 0, 1) 136 | sim_off_diag = (sims * mask).sum(dim=0) 137 | mask = torch.eye(sims.size(1), m=sims.size(2), dtype=torch.bool, device=sims.device) 138 | fusion_sim = sim_diag * mask + sim_off_diag * mask.logical_not() 139 | return fusion_sim 140 | 141 | def s_rand(self, sims): 142 | mask = torch.randint(0, len(sims), sims[0].size(), device=sims.device) 143 | mask = nn.functional.one_hot(mask, num_classes=len(sims)).permute(2, 0, 1) 144 | fusion_sim = (sims * mask).sum(dim=0) 145 | return fusion_sim 146 
|
147 |     def get_ctr_loss(self, sim):
148 |         target = torch.arange(sim.size(0), dtype=torch.long, device=sim.device)
149 | 
150 |         return self.ctr_loss_fn(sim.div(self.ct), target)  # NOTE: expects self.ctr_loss_fn and self.ct, which are not defined in __init__ above; unused by the distillation path
151 | 
152 |     def gather_with_grad(self, x):
153 |         world_size = get_world_size()
154 |         rank = get_rank()
155 | 
156 |         gather_x = [torch.zeros_like(x) for _ in range(world_size)]
157 |         dist.all_gather(gather_x, x)
158 |         # tensor returned by all_gather does not have gradients
159 |         # reassign x to recover gradients
160 |         gather_x[rank] = x
161 | 
162 |         gather_x = torch.cat(gather_x, dim=0)
163 | 
164 |         return gather_x
165 | 
166 |     def __repr__(self):
167 |         return "positive count: " + \
168 |             " | ".join([f"{self.teachers[i]}: {self.pos_win_count[i]/self.pos_win_count.sum()*100:.2f}%" for i in range(len(self.teachers))])
--------------------------------------------------------------------------------
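MultiTeacherDistillLoss gathers across ranks internally (dist.all_gather), so even a toy run needs an initialized process group. A minimal single-process sketch; the batch size, feature dimensions, and teacher name list are arbitrary choices for illustration:

import torch
import torch.distributed as dist

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                        world_size=1, rank=0)  # 1-rank group just for the sketch

B = 8
criterion = MultiTeacherDistillLoss(st=0.05, tt=0.05, s="maxmean",
                                    teachers=["resnet101_gem", "resnet101_delg"])
stu1 = torch.randn(B, 512, requires_grad=True)      # student embeddings, view 1
stu2 = torch.randn(B, 512, requires_grad=True)      # student embeddings, view 2
tch1 = [torch.randn(B, 2048), torch.randn(B, 512)]  # one entry per teacher, view 1
tch2 = [torch.randn(B, 2048), torch.randn(B, 512)]  # one entry per teacher, view 2
loss = criterion(stu1, stu2, tch1, tch2)["distill loss"]
loss.backward()
dist.destroy_process_group()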
/svd_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import logging
5 | 
6 | import torch
7 | import torch.distributed as dist
8 | import torch.multiprocessing as mp
9 | import torch.nn as nn
10 | import torchvision.transforms as T
11 | 
12 | from models.distill_model import archs
13 | from models.teachers import teacher_models
14 | from models.gem_pooling import GeneralizedMeanPooling
15 | from datasets.svd.core import Video, MetaData
16 | from datasets.svd.loader import DistributedTestLoader
17 | from datasets.svd.eval import DistEvaluator
18 | import utils
19 | 
20 | 
21 | def parse_args():
22 |     """
23 |     Parse input arguments
24 |     """
25 |     parser = argparse.ArgumentParser(description="Knowledge Distill Evaluation on SVD dataset.")
26 |     parser.add_argument(
27 |         "-a", "--arch", default=None, type=str, metavar="ARCH",
28 |         choices=list(archs.keys()), help="name of backbone model"
29 |     )
30 |     parser.add_argument(
31 |         "-t", "--teacher", default=None, type=str, metavar="ARCH",
32 |         choices=list(teacher_models.keys()), help="name of teacher model"
33 |     )
34 |     parser.add_argument(
35 |         "-dm", "--dataset_meta", default="config/svd.yaml", type=str, metavar="FILE",
36 |         help="dataset meta file"
37 |     )
38 |     parser.add_argument(
39 |         "--max_frames", default=60, type=int, metavar="N",
40 |         help="max number of frames to truncate"
41 |     )
42 |     parser.add_argument(
43 |         "--stride", default=1, type=int, metavar="N",
44 |         help="stride to sample frames",
45 |     )
46 |     parser.add_argument(
47 |         "-b", "--batch_size", default=16, type=int, metavar="N",
48 |         help="test batch size"
49 |     )
50 |     parser.add_argument(
51 |         "--num_workers", default=8, type=int, metavar="N",
52 |         help="number of data loader workers"
53 |     )
54 |     parser.add_argument(
55 |         "--topk", default=[100], type=int, nargs="+", metavar="N",
56 |         help="to calculate topk metric"
57 |     )
58 |     parser.add_argument(
59 |         "--sim_fn", default="fmx", type=str,
60 |         help="similarity function to use"
61 |     )
62 |     parser.add_argument(
63 |         "--subset_eval", action="store_true",
64 |         help="eval full or subset"
65 |     )
66 |     parser.add_argument(
67 |         "-r", "--resume", default=None, type=str, metavar="DIR",
68 |         help="checkpoint model to resume"
69 |     )
70 |     parser.add_argument(
71 |         "--world_size", default=8, type=int,
72 |         help="total number of distributed processes"
73 |     )
74 |     parser.add_argument(
75 |         "--dist_url", default="tcp://localhost:1234", type=str,
76 |         help='url used to set up distributed training'
77 |     )
78 |     parser.add_argument(
79 |         '--rank', default=0, type=int,
80 |         help='node rank for distributed training'
81 |     )
82 |     parser.add_argument(
83 |         '--embed_dim', default=512, type=int,
84 |         help='embedding dimension'
85 |     )
86 |     parser.add_argument(
87 |         '-p', default=1., type=float,
88 |         help='power rate'
89 |     )
90 |     return parser.parse_args()
91 | 
92 | 
93 | def main():
94 |     args = parse_args()
95 | 
96 |     utils.setup_logger()
97 |     logger = logging.getLogger("svd_distill_eval")
98 | 
99 |     logger.info(vars(args))
100 | 
101 |     ngpus_per_node = torch.cuda.device_count()
102 | 
103 |     mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))  # spawn one worker per visible GPU instead of a hardcoded 8
104 | 
105 | 
106 | def main_worker(gpu, ngpus_per_node, args):
107 |     args.gpu = gpu
108 | 
109 |     if args.dist_url == "env://" and args.rank == -1:
110 |         args.rank = int(os.environ["RANK"])
111 |     args.rank = args.rank * ngpus_per_node + gpu
112 | 
113 |     utils.setup_logger(log_path=None)
114 |     logger = logging.getLogger("dist_worker " + str(args.rank))
115 | 
116 |     dist.init_process_group(
117 |         backend="nccl",
118 |         init_method=args.dist_url,
119 |         world_size=args.world_size,
120 |         rank=args.rank
121 |     )
122 | 
123 |     torch.cuda.set_device(args.gpu)
124 |     device = torch.device("cuda:"+str(torch.cuda.current_device()))
125 |     logger.info(f"Using device cuda:{torch.cuda.current_device()}")
126 | 
127 |     assert args.arch is None or args.teacher is None, "pass either --arch (student) or --teacher, not both"
128 | 
129 |     if args.arch is not None:
130 |         model = archs[args.arch](pretrained=False, num_classes=args.embed_dim)
131 |         model.avgpool = GeneralizedMeanPooling(args.p)
132 | 
133 |         if args.resume is not None:
134 |             checkpoint_file = args.resume
135 |             if os.path.isfile(checkpoint_file):
136 |                 logger.info(f"Loading checkpoint \"{checkpoint_file}\"...")
137 |                 checkpoint = torch.load(checkpoint_file, map_location='cpu')
138 |             else:
139 |                 logger.error(f"=> No checkpoint found at '{checkpoint_file}'.")
140 |                 sys.exit()
141 | 
142 |             state_dict = checkpoint["state_dict"]
143 |             for k in list(state_dict.keys()):
144 |                 if k.startswith('module.base_encoder'):
145 |                     state_dict[k[len("module.base_encoder."):]] = state_dict.pop(k)
146 |                 else:
147 |                     state_dict.pop(k)
148 |             model.load_state_dict(state_dict, strict=True)
149 |             logger.info("Loaded checkpoint.")
150 | 
151 |     if args.teacher is not None:
152 |         model = teacher_models[args.teacher](gem_p=args.p)  # the factories in models/teachers.py also expect path_to_pretrained_weights as their first argument
153 | 
154 |     model.cuda(device)
155 | 
156 |     if utils.has_batchnorms(model):
157 |         model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
158 | 
159 |     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device])
160 | 
161 |     dataset_meta_cfg = utils.load_config(args.dataset_meta)
162 |     dataset_meta = MetaData(dataset_meta_cfg)
163 | 
164 |     test_dataset = Video(
165 |         dataset_meta.frm_root_path,
166 |         dataset_meta.frm_cnt,
167 |         T.Compose([
168 |             T.Resize(256),
169 |             T.CenterCrop(224),
170 |             T.ToTensor()
171 |         ]),
172 |         args.max_frames,
173 |         args.stride
174 |     )
175 |     test_loader = DistributedTestLoader.build(
176 |         test_dataset, dataset_meta,
177 |         batch_size=args.batch_size, num_workers=args.num_workers
178 |     )
179 | 
180 |     test(model, test_loader, dataset_meta, device, args, not args.subset_eval)
181 | 
182 | 
183 | def test(model, test_loader, dataset_meta, device, args, full_eval):
184 |     model.eval()
185 | 
186 |     evaluator = DistEvaluator(
187 |         device, dataset_meta.test_groundtruth, args.topk, args.max_frames
188 |     )
189 |     return evaluator(
190 |         model,
191 |         test_loader.query_loader,
192 |         test_loader.labeled_loader,
193 |         test_loader.unlabeled_loader,
194 |         full_eval,
195 |         args.sim_fn
196 |     )
197 | 
198 | 
199 | 
200 | if __name__ == "__main__":
201 |     main()
202 | 
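The --sim_fn choices are implemented in datasets/svd/eval.py, which is not included in this excerpt. Assuming "fmx" denotes frame-level max similarity (best-matching reference frame per query frame, averaged over query frames, a common aggregation for frame-based video retrieval), a hypothetical sketch of such a video-level score:

import torch

def fmx_similarity(q_feats, r_feats):
    """Hypothetical frame-max similarity for L2-normalized frame features:
    q_feats (Nq x D) and r_feats (Nr x D) -> scalar video-level score."""
    sim = q_feats @ r_feats.t()          # Nq x Nr pairwise cosine similarities
    return sim.max(dim=1).values.mean()  # best reference frame per query frame, then average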
-------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import time, datetime 4 | from collections import defaultdict, deque 5 | import utils 6 | import torch.distributed as dist 7 | from itertools import islice 8 | 9 | 10 | class SmoothedValue(object): 11 | """Track a series of values and provide access to smoothed values over a 12 | window or the global series average. 13 | """ 14 | 15 | def __init__(self, window_size=None, fmt=None): 16 | if fmt is None: 17 | fmt = "{median:.4f} ({global_avg:.4f})" 18 | self.deque = deque(maxlen=window_size) 19 | self.total = 0.0 20 | self.count = 0 21 | self.fmt = fmt 22 | 23 | def update(self, value, n=1): 24 | self.deque.append(value) 25 | self.count += n 26 | self.total += value * n 27 | 28 | def sync(self): 29 | if not utils.is_dist_avail_and_initialized(): 30 | return 31 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device=torch.device("cuda:"+str(torch.cuda.current_device()))) 32 | dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=False) 33 | t = t.tolist() 34 | self.count = int(t[0]) 35 | self.total = t[1] 36 | 37 | @property 38 | def median(self): 39 | d = torch.tensor(list(self.deque), dtype=torch.float32) 40 | return d.median().item() 41 | 42 | @property 43 | def avg(self): 44 | d = torch.tensor(list(self.deque), dtype=torch.float32) 45 | return d.mean().item() 46 | 47 | @property 48 | def global_avg(self): 49 | return self.total / self.count 50 | 51 | @property 52 | def max(self): 53 | return max(self.deque) 54 | 55 | @property 56 | def value(self): 57 | return self.deque[-1] 58 | 59 | def __str__(self): 60 | return self.fmt.format( 61 | median=self.median, 62 | avg=self.avg, 63 | global_avg=self.global_avg, 64 | max=self.max, 65 | value=self.value) 66 | 67 | 68 | class MetricLogger(object): 69 | def __init__(self, logger, delimiter="\t"): 70 | self.meters = defaultdict(SmoothedValue) 71 | self.delimiter = delimiter 72 | self.logger = logger 73 | 74 | def update(self, **kwargs): 75 | for k, v in kwargs.items(): 76 | if isinstance(v, torch.Tensor): 77 | v = v.item() 78 | assert isinstance(v, (float, int)) 79 | self.meters[k].update(v) 80 | 81 | def __getattr__(self, attr): 82 | if attr in self.meters: 83 | return self.meters[attr] 84 | if attr in self.__dict__: 85 | return self.__dict__[attr] 86 | raise AttributeError("'{}' object has no attribute '{}'".format( 87 | type(self).__name__, attr)) 88 | 89 | def __str__(self): 90 | loss_str = [] 91 | for name, meter in self.meters.items(): 92 | loss_str.append( 93 | "{}: {}".format(name, str(meter)) 94 | ) 95 | return self.delimiter.join(loss_str) 96 | 97 | def add_meter(self, name, meter): 98 | self.meters[name] = meter 99 | 100 | def sync(self): 101 | for meter in self.meters.values(): 102 | meter.sync() 103 | 104 | def log_every(self, iterable, log_freq, header=None, iterations=None): 105 | iterations = len(iterable) if iterations is None else iterations 106 | if self.logger is None: 107 | for i, obj in enumerate(islice(iterable, 0, iterations)): 108 | yield i, obj 109 | return 110 | 111 | header = '' if header is None else header 112 | 113 | start_time = time.time() 114 | end = time.time() 115 | iter_time = SmoothedValue(fmt='{avg:.4f}') 116 | data_time = SmoothedValue(fmt='{avg:.4f}') 117 | space_fmt = ':' + str(len(str(iterations))) + 'd' 118 | if torch.cuda.is_available(): 119 | log_msg = self.delimiter.join([ 120 | 
header, 121 | 'Iter: [{0' + space_fmt + '}/{1}]', 122 | 'eta: {eta}', 123 | '{meters}', 124 | 'iter time: {time}', 125 | 'data time: {data}', 126 | 'gpu mem: {memory:.0f}MB' 127 | ]) 128 | else: 129 | log_msg = self.delimiter.join([ 130 | header, 131 | '[{0' + space_fmt + '}/{1}]', 132 | 'eta: {eta}', 133 | '{meters}', 134 | 'iter time: {time}', 135 | 'data time: {data}' 136 | ]) 137 | MB = 1024.0 * 1024.0 138 | for i, obj in enumerate(islice(iterable, 0, iterations)): 139 | data_time.update(time.time() - end) 140 | yield i, obj 141 | iter_time.update(time.time() - end) 142 | if i == iterations - 1 or i % log_freq == 0: 143 | eta_seconds = iter_time.global_avg * (iterations - i) 144 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 145 | if torch.cuda.is_available(): 146 | self.logger.info(log_msg.format( 147 | i+1, iterations, eta=eta_string, 148 | meters=str(self), 149 | time=str(iter_time), data=str(data_time), 150 | memory=torch.cuda.memory_reserved() / MB)) 151 | else: 152 | self.logger.info(log_msg.format( 153 | i+1, iterations, eta=eta_string, 154 | meters=str(self), 155 | time=str(iter_time), data=str(data_time))) 156 | end = time.time() 157 | 158 | total_time = time.time() - start_time 159 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 160 | end_msg = self.delimiter.join([ 161 | header, 162 | 'Total time: {0} ({1:.4f} s / it)' 163 | ]) 164 | self.logger.info(end_msg.format(total_time_str, total_time / iterations)) 165 | 166 | 167 | class MetricScorer: 168 | 169 | def __init__(self, k=0): 170 | self.k = k 171 | 172 | def score(self, sorted_labels): 173 | return 0.0 174 | 175 | def getLength(self, sorted_labels): 176 | length = self.k 177 | if length > len(sorted_labels) or length <= 0: 178 | length = len(sorted_labels) 179 | return length 180 | 181 | def name(self): 182 | if self.k > 0: 183 | return "%s@%d" % (self.__class__.__name__.replace("Scorer",""), self.k) 184 | return self.__class__.__name__.replace("Scorer","") 185 | 186 | def setLength(self, k): 187 | self.k = k; 188 | 189 | 190 | class APScorer(MetricScorer): 191 | 192 | def __init__(self, k=0): 193 | MetricScorer.__init__(self, k) 194 | 195 | def score(self, sorted_labels): 196 | length = self.getLength(sorted_labels) 197 | nr_relevant = len([x for x in sorted_labels[:length] if x > 0]) 198 | if nr_relevant == 0: 199 | return 0.0 200 | 201 | ap = 0.0 202 | rel = 0 203 | 204 | for i in range(length): 205 | lab = sorted_labels[i] 206 | if lab > 0: 207 | rel += 1 208 | ap += float(rel) / (i+1.0) 209 | ap /= nr_relevant 210 | 211 | return ap 212 | -------------------------------------------------------------------------------- /gld_pca_learn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import argparse 5 | import os 6 | import logging 7 | import random 8 | 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | import torch.distributed as dist 12 | import torch.multiprocessing as mp 13 | import torch.nn as nn 14 | import torchvision.transforms as T 15 | from PIL import Image 16 | 17 | from models.teachers import teacher_models 18 | import utils 19 | import metric 20 | 21 | 22 | def parse_args(): 23 | """ 24 | Parse input arguments 25 | """ 26 | parser = argparse.ArgumentParser(description="PCA training and saving.") 27 | parser.add_argument( 28 | "-t", "--teacher", default=None, type=str, metavar="ARCH", 29 | choices=list(teacher_models.keys()), help="name of teacher model" 30 | ) 31 | 
parser.add_argument( 32 | "--imsize", default=None, type=int, 33 | help="input image shape" 34 | ) 35 | parser.add_argument( 36 | "--num_workers", default=8, type=int, metavar="N", 37 | help="number of data loader workers" 38 | ) 39 | parser.add_argument( 40 | "--world_size", default=8, type=int, 41 | help="number of workers" 42 | ) 43 | parser.add_argument( 44 | "--dist_url", default="tcp://localhost:1234", type=str, 45 | help='url used to set up distributed training' 46 | ) 47 | parser.add_argument( 48 | '--rank', default=0, type=int, 49 | help='node rank for distributed training' 50 | ) 51 | parser.add_argument( 52 | '--embed_dim', default=512, type=int, 53 | help='embedding dimension' 54 | ) 55 | parser.add_argument( 56 | "-b", "--batch_size", default=256, type=int, 57 | help="batch size" 58 | ) 59 | parser.add_argument( 60 | '-p', default=3, type=float, 61 | help='power rate' 62 | ) 63 | parser.add_argument( 64 | '--num_samples', default=10000, type=int, 65 | help='number of samples to train PCA transform' 66 | ) 67 | parser.add_argument( 68 | "--gld_root_path", default="/path/to/gldv2", type=str, metavar="PATH", 69 | help="frame root path of GLDv2 dataset" 70 | ) 71 | parser.add_argument( 72 | "--dump_to", required=True, type=str, 73 | help="dump learned pca transformer to" 74 | ) 75 | return parser.parse_args() 76 | 77 | 78 | def main(): 79 | random.seed(0) 80 | 81 | args = parse_args() 82 | 83 | utils.setup_logger() 84 | logger = logging.getLogger("pca training") 85 | 86 | logger.info(vars(args)) 87 | 88 | ngpus_per_node = torch.cuda.device_count() 89 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 90 | 91 | 92 | def main_worker(gpu, ngpus_per_node, args): 93 | cudnn.benchmark = True 94 | 95 | args.gpu = gpu 96 | 97 | if args.dist_url == "env://" and args.rank == -1: 98 | args.rank = int(os.environ["RANK"]) 99 | args.rank = args.rank * ngpus_per_node + gpu 100 | 101 | utils.setup_logger(log_path=None) 102 | logger = logging.getLogger("worker " + str(args.rank)) 103 | 104 | dist.init_process_group( 105 | backend="nccl", 106 | init_method=args.dist_url, 107 | world_size=args.world_size, 108 | rank=args.rank 109 | ) 110 | 111 | torch.cuda.set_device(args.gpu) 112 | device = torch.device("cuda:"+str(torch.cuda.current_device())) 113 | logger.info(f"Using device cuda:{torch.cuda.current_device()}") 114 | 115 | model = teacher_models[args.teacher](gem_p=args.p) 116 | model.cuda(device) 117 | if utils.has_batchnorms(model): 118 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 119 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device]) 120 | 121 | pca_wtn = learn_pca_whitening(args, model) 122 | 123 | if args.rank == 0: 124 | torch.save(pca_wtn.state_dict(), args.dump_to) 125 | 126 | 127 | @torch.no_grad() 128 | def learn_pca_whitening(args, model): 129 | model.eval() 130 | logger = logging.getLogger("pca training") 131 | 132 | image_id_list = list() 133 | with open(os.path.join(args.gld_root_path, "meta/train_clean_ids.txt"), "r") as f: 134 | for l in f: 135 | image_id_list.append(l.strip()) 136 | 137 | logger.info(f"totally {len(image_id_list)} images.") 138 | 139 | num_samples = args.num_samples 140 | if num_samples > 0: 141 | logger.info(f"select {num_samples} samples to train pca whitening.") 142 | random.shuffle(image_id_list) 143 | image_id_list = image_id_list[:num_samples] 144 | else: 145 | logger.info(f"select all samples to train pca whitening.") 146 | image_id_list = image_id_list 147 | 148 | transform_list = [ 149 | 
T.RandomResizedCrop(args.imsize, scale=(0.4, 1.)),  # --imsize must be given explicitly; its default of None is not a valid crop size
150 |         T.RandomHorizontalFlip(p=0.5),
151 |         T.ToTensor(),
152 |         T.Normalize(
153 |             mean=[0.485, 0.456, 0.406],
154 |             std=[0.229, 0.224, 0.225]
155 |         )
156 |     ]
157 |     if args.teacher.endswith("delg") or args.teacher.endswith("dolg"):
158 |         transform_list.append(utils.RGB2BGR())
159 |     transform = T.Compose(transform_list)
160 |     train_dataset = GLDv2(
161 |         os.path.join(args.gld_root_path, "train"), image_id_list,
162 |         transform=transform
163 |     )
164 |     train_sampler = torch.utils.data.distributed.DistributedSampler(
165 |         train_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=False, drop_last=True
166 |     )
167 |     train_loader = torch.utils.data.DataLoader(
168 |         train_dataset, batch_size=args.batch_size, sampler=train_sampler,
169 |         num_workers=args.num_workers, pin_memory=True
170 |     )
171 | 
172 |     rank = utils.get_rank()
173 |     if rank == 0:
174 |         pca_wtn = utils.PCA(dim1=model.module.embed_dim, dim2=args.embed_dim)
175 | 
176 |     logger.info("extracting pca training features")
177 |     features = extract_features(model, train_loader)
178 | 
179 |     if rank == 0:
180 |         pca_wtn.train_pca(features)
181 |         return pca_wtn
182 |     else:
183 |         return None
184 | 
185 | 
186 | @torch.no_grad()
187 | def extract_features(model, data_loader):
188 |     rank = utils.get_rank()
189 |     if rank == 0:
190 |         logger = logging.getLogger("extract_feature")
191 |     else:
192 |         logger = None
193 |     metric_logger = metric.MetricLogger(logger, delimiter=" ")
194 |     log_freq = len(data_loader) // 16 if len(data_loader) >= 16 else len(data_loader)
195 |     if rank == 0:
196 |         features = torch.zeros(len(data_loader.dataset), model.module.embed_dim)
197 |     for batch_idx, (samples, index) in metric_logger.log_every(data_loader, log_freq):
198 |         samples = samples.cuda(non_blocking=True)
199 |         index = index.cuda(non_blocking=True)
200 | 
201 |         feats = model(samples)
202 | 
203 |         feats = nn.functional.normalize(feats, p=2, dim=-1)
204 | 
205 |         index_all = gather_to_main(index)
206 |         feats_all = gather_to_main(feats)
207 | 
208 |         if rank == 0:
209 |             features.index_copy_(0, index_all.cpu(), feats_all.cpu())
210 | 
211 |     if rank == 0:
212 |         return features
213 |     else:
214 |         return None
215 | 
216 | 
217 | def gather_to_main(x):
218 |     world_size = utils.get_world_size()
219 |     rank = utils.get_rank()
220 | 
221 |     if rank == 0:
222 |         gather_x = [torch.zeros_like(x) for _ in range(world_size)]
223 |     else:
224 |         gather_x = None
225 |     dist.gather(x, gather_x if rank == 0 else None, dst=0)
226 | 
227 |     if rank == 0:
228 |         gather_x = torch.cat(gather_x, dim=0)
229 | 
230 |     return gather_x
231 | 
232 | 
233 | class GLDv2(torch.utils.data.dataset.Dataset):
234 |     def __init__(self, root_path, img_id_list, transform=None):
235 |         super().__init__()
236 |         self.root_path = root_path
237 |         self.img_id_list = img_id_list
238 | 
239 |         self.t = transform
240 | 
241 |     def __getitem__(self, i):
242 |         img_id = self.img_id_list[i]
243 |         img_path = os.path.join(
244 |             self.root_path,
245 |             img_id[0], img_id[1], img_id[2],
246 |             img_id+".jpg"
247 |         )
248 | 
249 |         img = Image.open(img_path)
250 |         img = img.convert("RGB")
251 |         if self.t is not None:
252 |             img = self.t(img)
253 | 
254 |         return img, i
255 | 
256 |     def __len__(self):
257 |         return len(self.img_id_list)
258 | 
259 | 
260 | if __name__ == "__main__":
261 |     main()
262 | 
--------------------------------------------------------------------------------
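Both PCA scripts depend on utils.PCA(dim1, dim2) exposing a train_pca(features) method; that class is not part of this excerpt. A minimal sketch of what a PCA-whitening fit could look like under the same interface (an assumption for illustration, not the repo's actual implementation): center the features, eigendecompose the covariance, keep the top dim2 components, and scale each by the inverse square root of its eigenvalue.

import torch

class PCAWhitening:
    """Sketch of a PCA-whitening transform (interface assumed from utils.PCA)."""
    def __init__(self, dim1, dim2):
        self.dim1, self.dim2 = dim1, dim2

    def train_pca(self, features):                   # features: N x dim1
        self.mean = features.mean(dim=0)
        cov = torch.cov(features.t().double())       # dim1 x dim1 covariance
        eigvals, eigvecs = torch.linalg.eigh(cov)    # ascending eigenvalues
        top = eigvals.argsort(descending=True)[: self.dim2]
        # project onto the top components, scaled by 1/sqrt(eigenvalue) to whiten
        self.proj = (eigvecs[:, top] / eigvals[top].clamp(min=1e-12).sqrt()).float()

    def apply(self, x):                              # x: N x dim1 -> N x dim2
        return (x - self.mean) @ self.proj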
/models/dolg/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | # Written by feymanpriv
8 | 
9 | import math
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.nn import Parameter
14 | 
15 | from .net import init_weights
16 | from .config import cfg
17 | from .resnet import ResNet, ResHead
18 | from .resnet import GeneralizedMeanPoolingP
19 | 
20 | """ Dolg models """
21 | 
22 | class DOLG(nn.Module):
23 |     """ DOLG model """
24 |     def __init__(self):
25 |         super(DOLG, self).__init__()
26 |         self.pool_l = nn.AdaptiveAvgPool2d((1, 1))
27 |         self.pool_g = GeneralizedMeanPoolingP(norm=3.0)
28 |         self.fc_t = nn.Linear(cfg.MODEL.S4_DIM, cfg.MODEL.S3_DIM, bias=True)
29 |         self.fc = nn.Linear(cfg.MODEL.S4_DIM, cfg.MODEL.HEADS.REDUCTION_DIM, bias=True)
30 |         self.globalmodel = ResNet()
31 |         self.localmodel = SpatialAttention2d(cfg.MODEL.S3_DIM)
32 |         self.desc_cls = Arcface(cfg.MODEL.HEADS.REDUCTION_DIM, cfg.MODEL.NUM_CLASSES)
33 |         # self.embed_dim = cfg.MODEL.HEADS.REDUCTION_DIM
34 |         self.embed_dim = cfg.MODEL.S3_DIM
35 | 
36 |     def forward(self, x):
37 |         """ Global branch only: the local branch and orthogonal fusion are disabled below """
38 |         f3, f4 = self.globalmodel(x)
39 |         # fl, _ = self.localmodel(f3)
40 | 
41 |         fg_o = self.pool_g(f4)
42 |         fg_o = fg_o.view(fg_o.size(0), cfg.MODEL.S4_DIM)
43 | 
44 |         fg = self.fc_t(fg_o)
45 |         # fg_norm = torch.norm(fg, p=2, dim=1)
46 | 
47 |         # proj = torch.bmm(fg.unsqueeze(1), torch.flatten(fl, start_dim=2))
48 |         # proj = torch.bmm(fg.unsqueeze(2), proj).view(fl.size())
49 |         # proj = proj / (fg_norm * fg_norm).view(-1, 1, 1, 1)
50 |         # orth_comp = fl - proj
51 | 
52 |         # fo = self.pool_l(orth_comp)
53 |         # fo = fo.view(fo.size(0), cfg.MODEL.S3_DIM)
54 | 
55 |         # final_feat=torch.cat((fg, fo), 1)
56 |         # global_feature = self.fc(final_feat)
57 | 
58 |         # global_logits = self.desc_cls(global_feature, targets)
59 |         return fg
60 | 
61 |     '''
62 |     def forward(self, x, targets):
63 |         """ Global and local orthogonal fusion """
64 |         feamap3, feamap4 = self.globalmodel(x)
65 | 
66 |         g_f = self.pool_g(feamap4)
67 |         b, c, h, w = g_f.size(0), g_f.size(1), g_f.size(2), g_f.size(3)
68 |         x = g_f.view(b, -1)
69 |         x = self.fc_t(x)
70 |         g_f = x.view(b, c // 2, h, w)
71 |         e_f = g_f.expand_as(feamap3)
72 | 
73 |         local_feamap3, _ = self.localmodel(feamap3)
74 |         proj = torch.sum(e_f * local_feamap3, dim=1) /
75 |                torch.sum(e_f * e_f, dim=1).unsqueeze(1) * e_f
76 | 
77 |         orth_comp = local_feamap3 - proj
78 |         p_f = self.pool_l(orth_feamap3)
79 |         p_f = p_f.view(p_f.size(0), -1)
80 |         g_f = g_f.view(g_f.size(0), -1)
81 | 
82 |         global_feature=torch.cat((g_f, p_f), 1)
83 |         global_feature = self.fc(global_feature)
84 | 
85 |         global_logits = self.desc_cls(global_feature, targets)
86 |         return global_feature, global_logits
87 |     '''
88 | 
89 | 
90 | class SpatialAttention2d(nn.Module):
91 |     '''
92 |     SpatialAttention2d
93 |     2-layer 1x1 conv network with softplus activation. 
94 | ''' 95 | def __init__(self, in_c, act_fn='relu', with_aspp=cfg.MODEL.WITH_MA): 96 | super(SpatialAttention2d, self).__init__() 97 | 98 | self.with_aspp = with_aspp 99 | if self.with_aspp: 100 | self.aspp = ASPP(cfg.MODEL.S3_DIM) 101 | self.conv1 = nn.Conv2d(in_c, cfg.MODEL.S3_DIM, 1, 1) 102 | self.bn = nn.BatchNorm2d(cfg.MODEL.S3_DIM, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 103 | if act_fn.lower() in ['relu']: 104 | self.act1 = nn.ReLU() 105 | elif act_fn.lower() in ['leakyrelu', 'leaky', 'leaky_relu']: 106 | self.act1 = nn.LeakyReLU() 107 | self.conv2 = nn.Conv2d(cfg.MODEL.S3_DIM, 1, 1, 1) 108 | self.softplus = nn.Softplus(beta=1, threshold=20) # use default setting. 109 | 110 | for conv in [self.conv1, self.conv2]: 111 | conv.apply(init_weights) 112 | 113 | def forward(self, x): 114 | ''' 115 | x : spatial feature map. (b x c x w x h) 116 | att : softplus attention score 117 | ''' 118 | if self.with_aspp: 119 | x = self.aspp(x) 120 | x = self.conv1(x) 121 | x = self.bn(x) 122 | 123 | feature_map_norm = F.normalize(x, p=2, dim=1) 124 | 125 | x = self.act1(x) 126 | x = self.conv2(x) 127 | 128 | att_score = self.softplus(x) 129 | att = att_score.expand_as(feature_map_norm) 130 | x = att * feature_map_norm 131 | return x, att_score 132 | 133 | def __repr__(self): 134 | return self.__class__.__name__ 135 | 136 | 137 | class ASPP(nn.Module): 138 | ''' 139 | Atrous Spatial Pyramid Pooling Module 140 | ''' 141 | def __init__(self, in_c): 142 | super(ASPP, self).__init__() 143 | 144 | self.aspp = [] 145 | self.aspp.append(nn.Conv2d(in_c, 512, 1, 1)) 146 | 147 | for dilation in [6, 12, 18]: 148 | _padding = (dilation * 3 - dilation) // 2 149 | self.aspp.append(nn.Conv2d(in_c, 512, 3, 1, padding=_padding, dilation=dilation)) 150 | self.aspp = nn.ModuleList(self.aspp) 151 | 152 | self.im_pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), 153 | nn.Conv2d(in_c, 512, 1, 1), 154 | nn.ReLU()) 155 | conv_after_dim = 512 * (len(self.aspp)+1) 156 | self.conv_after = nn.Sequential(nn.Conv2d(conv_after_dim, 1024, 1, 1), nn.ReLU()) 157 | 158 | for dilation_conv in self.aspp: 159 | dilation_conv.apply(init_weights) 160 | for model in self.im_pool: 161 | if isinstance(model, nn.Conv2d): 162 | model.apply(init_weights) 163 | for model in self.conv_after: 164 | if isinstance(model, nn.Conv2d): 165 | model.apply(init_weights) 166 | 167 | def forward(self, x): 168 | h, w = x.size(2), x.size(3) 169 | aspp_out = [F.interpolate(self.im_pool(x), scale_factor=(h,w), mode="bilinear", align_corners=False)] 170 | for i in range(len(self.aspp)): 171 | aspp_out.append(self.aspp[i](x)) 172 | aspp_out = torch.cat(aspp_out, 1) 173 | x = self.conv_after(aspp_out) 174 | return x 175 | 176 | 177 | class Arcface(nn.Module): 178 | """ Additive Angular Margin Loss """ 179 | def __init__(self, in_feat, num_classes): 180 | super().__init__() 181 | self.in_feat = in_feat 182 | self._num_classes = num_classes 183 | self._s = cfg.MODEL.HEADS.SCALE 184 | self._m = cfg.MODEL.HEADS.MARGIN 185 | 186 | self.cos_m = math.cos(self._m) 187 | self.sin_m = math.sin(self._m) 188 | self.threshold = math.cos(math.pi - self._m) 189 | self.mm = math.sin(math.pi - self._m) * self._m 190 | 191 | self.weight = Parameter(torch.Tensor(num_classes, in_feat)) 192 | self.register_buffer('t', torch.zeros(1)) 193 | 194 | def forward(self, features, targets): 195 | # get cos(theta) 196 | cos_theta = F.linear(F.normalize(features), F.normalize(self.weight)) 197 | cos_theta = cos_theta.clamp(-1, 1) # for numerical stability 198 | 199 | target_logit = 
cos_theta[torch.arange(0, features.size(0)), targets].view(-1, 1) 200 | 201 | sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2)) 202 | cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m # cos(target+margin) 203 | mask = cos_theta > cos_theta_m 204 | final_target_logit = torch.where(target_logit > self.threshold, cos_theta_m, target_logit - self.mm) 205 | 206 | hard_example = cos_theta[mask] 207 | with torch.no_grad(): 208 | self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t 209 | cos_theta[mask] = hard_example * (self.t + hard_example) 210 | cos_theta.scatter_(1, targets.view(-1, 1).long(), final_target_logit) 211 | pred_class_logits = cos_theta * self._s 212 | return pred_class_logits 213 | 214 | def extra_repr(self): 215 | return 'in_features={}, num_classes={}, scale={}, margin={}'.format( 216 | self.in_feat, self._num_classes, self._s, self._m 217 | ) 218 | 219 | -------------------------------------------------------------------------------- /models/teachers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import os 5 | import torch 6 | import torch.nn as nn 7 | from .resnet import resnet50, resnet101 8 | from .solar.networks import ResNetSOAs 9 | from .delg.model import R101_DELG 10 | from .dolg.model import DOLG 11 | 12 | from .gem_pooling import GeneralizedMeanPooling 13 | 14 | 15 | def resnet101_gem(path_to_pretrained_weights, gem_p=3., **kwargs): 16 | 17 | class ResNet101_GeM(nn.Module): 18 | def __init__(self, backbone, fc): 19 | super().__init__() 20 | self.backbone = backbone 21 | self.fc = fc 22 | 23 | def forward(self, x): 24 | x = self.backbone(x) 25 | x = nn.functional.normalize(x, p=2, dim=-1) 26 | x = self.fc(x) 27 | return x 28 | 29 | pretrained_weights = os.path.join(path_to_pretrained_weights, "gl18-tl-resnet101-gem-w-a4d43db.pt") 30 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 31 | backbone = resnet101(pretrained=False, num_classes=0) 32 | backbone.avgpool = GeneralizedMeanPooling(gem_p) 33 | model = ResNet101_GeM( 34 | backbone, 35 | nn.Linear(backbone.embed_dim, checkpoint["meta"]["outputdim"]) 36 | ) 37 | model.embed_dim = checkpoint["meta"]["outputdim"] 38 | 39 | state_dict = checkpoint["state_dict"] 40 | fc_weight = state_dict.pop("fc.weight") 41 | fc_bias = state_dict.pop("fc.bias") 42 | 43 | model.backbone.load_state_dict(state_dict, strict=True) 44 | model.fc.load_state_dict({"weight": fc_weight, "bias": fc_bias}, strict=True) 45 | 46 | return model 47 | 48 | 49 | def resnet101_ap_gem(path_to_pretrained_weights, gem_p=3., **kwargs): 50 | pretrained_weights = os.path.join(path_to_pretrained_weights, "Resnet101-AP-GeM-LM18.pt") 51 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 52 | model = resnet101(pretrained=False, num_classes=checkpoint["model_options"]["out_dim"]) 53 | model.avgpool = GeneralizedMeanPooling(gem_p) 54 | 55 | state_dict = checkpoint['state_dict'] 56 | for k in list(state_dict.keys()): 57 | if k.startswith("module"): 58 | if k != "module.adpool.p": 59 | state_dict[k[len("module."):]] = state_dict[k] 60 | del state_dict[k] 61 | 62 | model.load_state_dict(state_dict, strict=True) 63 | model.embed_dim = checkpoint["model_options"]["out_dim"] 64 | 65 | return model 66 | 67 | 68 | def resnet101_solar(path_to_pretrained_weights, gem_p=3., **kwargs): 69 | 70 | class ResNet101_SOLAR(nn.Module): 71 | def __init__(self, backbone, fc, gem_p): 72 | super().__init__() 73 | self.backbone = backbone 74 | 
self.fc = fc 75 | self.pool = GeneralizedMeanPooling(gem_p) 76 | 77 | def forward(self, x): 78 | x = self.backbone(x) 79 | x = self.pool(x) 80 | x = nn.functional.normalize(x, p=2, dim=-1) 81 | x = self.fc(x) 82 | return x 83 | 84 | pretrained_weights = os.path.join(path_to_pretrained_weights, "resnet101-solar-best.pth") 85 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 86 | backbone = ResNetSOAs() 87 | model = ResNet101_SOLAR( 88 | backbone, 89 | nn.Linear(backbone.embed_dim, checkpoint["meta"]["outputdim"]), 90 | gem_p 91 | ) 92 | model.embed_dim = checkpoint["meta"]["outputdim"] 93 | 94 | state_dict = checkpoint["state_dict"] 95 | state_dict.pop("pool.p") 96 | fc_weight = state_dict.pop("whiten.weight") 97 | fc_bias = state_dict.pop("whiten.bias") 98 | 99 | for k in list(state_dict.keys()): 100 | state_dict[k[len("features."):]] = state_dict.pop(k) 101 | 102 | model.backbone.load_state_dict(state_dict, strict=True) 103 | model.fc.load_state_dict({"weight": fc_weight, "bias": fc_bias}, strict=True) 104 | 105 | return model 106 | 107 | 108 | def resnet101_delg(path_to_pretrained_weights, pretrained=True, gem_p=3.): 109 | pretrained_weights = os.path.join(path_to_pretrained_weights, "r101_delg_s512.pyth") 110 | model = R101_DELG() 111 | 112 | if pretrained: 113 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 114 | state_dict = checkpoint['model_state'] 115 | state_dict["globalmodel.head.pool.p"] = torch.tensor([gem_p], dtype=torch.float32) 116 | for k in list(state_dict.keys()): 117 | if not k.startswith("globalmodel"): 118 | state_dict.pop(k) 119 | model.load_state_dict(state_dict, strict=True) 120 | 121 | return model 122 | 123 | 124 | def resnet101_dolg(path_to_pretrained_weights, pretrained=True, gem_p=3.): 125 | pretrained_weights = os.path.join(path_to_pretrained_weights, "r101_dolg_512.pyth") 126 | model = DOLG() 127 | 128 | if pretrained: 129 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 130 | state_dict = checkpoint['model_state'] 131 | state_dict.pop("pool1.p") 132 | state_dict["pool_g.p"] = torch.tensor([gem_p], dtype=torch.float32) 133 | model.load_state_dict(state_dict, strict=True) 134 | 135 | return model 136 | 137 | 138 | def mocov3(path_to_pretrained_weights, pretrained=True, with_head=False, gem_p=3.): 139 | pretrained_moco = os.path.join(path_to_pretrained_weights, 'mocov3_1000ep.pth.tar') 140 | 141 | model = resnet50(pretrained=False, num_classes=0) 142 | model.avgpool = GeneralizedMeanPooling(gem_p) 143 | 144 | if with_head: 145 | model.fc = nn.Sequential( 146 | # projector 147 | nn.Linear(2048, 4096, bias=False), 148 | nn.BatchNorm1d(4096), 149 | nn.ReLU(inplace=True), 150 | nn.Linear(4096, 256, bias=False), 151 | nn.BatchNorm1d(256, affine=False), 152 | # predictor 153 | nn.Linear(256, 4096, bias=False), 154 | nn.BatchNorm1d(4096), 155 | nn.ReLU(inplace=True), 156 | nn.Linear(4096, 256, bias=False) 157 | ) 158 | model.embed_dim = 256 159 | 160 | if pretrained: 161 | checkpoint = torch.load(pretrained_moco, map_location="cpu") 162 | state_dict = checkpoint['state_dict'] 163 | for k in list(state_dict.keys()): 164 | if k.startswith("module.base_encoder"): 165 | if k.startswith("module.base_encoder.fc") and not with_head: 166 | state_dict.pop(k) 167 | continue 168 | state_dict[k[len("module.base_encoder."):]] = state_dict.pop(k) 169 | elif k.startswith("module.predictor") and with_head: 170 | i = int(k[len("module.predictor."):][0]) 171 | new_k = "fc."+str(i + 5)+k[len("module.predictor.0"):] 172 | 
state_dict[new_k] = state_dict.pop(k)
173 |             else:
174 |                 state_dict.pop(k)
175 |         model.load_state_dict(state_dict, strict=True)
176 | 
177 |     return model
178 | 
179 | 
180 | def barlowtwins(path_to_pretrained_weights, pretrained=True, with_head=False, gem_p=3.):
181 |     model = resnet50(pretrained=False, num_classes=0)
182 |     model.avgpool = GeneralizedMeanPooling(gem_p)
183 | 
184 |     if with_head:
185 |         model.fc = nn.Sequential(
186 |             nn.Linear(2048, 8192, bias=False),
187 |             nn.BatchNorm1d(8192),
188 |             nn.ReLU(inplace=True),
189 |             nn.Linear(8192, 8192, bias=False),
190 |             nn.BatchNorm1d(8192),
191 |             nn.ReLU(inplace=True),
192 |             nn.Linear(8192, 8192, bias=False),
193 |             nn.BatchNorm1d(8192, affine=False)
194 |         )
195 |         model.embed_dim = 8192
196 | 
197 |     if pretrained:
198 |         checkpoint_file = os.path.join(path_to_pretrained_weights, 'barlowtwins_full.pth.tar')
199 |         checkpoint = torch.load(checkpoint_file, map_location='cpu')
200 |         state_dict = checkpoint["model"]
201 |         for k in list(state_dict.keys()):
202 |             if k.startswith("module.backbone"):
203 |                 state_dict[k[len("module.backbone."):]] = state_dict.pop(k)
204 |             elif k.startswith("module.projector") and with_head:
205 |                 state_dict["fc."+k[len("module.projector."):]] = state_dict.pop(k)
206 |             elif k.startswith("module.bn") and with_head:
207 |                 state_dict["fc.7."+k[len("module.bn."):]] = state_dict.pop(k)
208 |             else:
209 |                 state_dict.pop(k)
210 |         model.load_state_dict(state_dict, strict=True)
211 | 
212 |     return model
213 | 
214 | 
215 | teacher_models = {
216 |     "mocov3": mocov3,
217 |     "barlowtwins": barlowtwins,
218 |     "resnet101_gem": resnet101_gem,
219 |     "resnet101_ap_gem": resnet101_ap_gem,
220 |     "resnet101_solar": resnet101_solar,
221 |     "resnet101_delg": resnet101_delg,
222 |     "resnet101_dolg": resnet101_dolg,
223 | }
224 | 
225 | 
226 | if __name__ == "__main__":
227 |     model = resnet101_delg("/path/to/pretrained_weights", pretrained=False)  # smoke test: build the model without loading weights
--------------------------------------------------------------------------------
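Every factory in teacher_models takes the pretrained-weights directory as its first argument, applies GeM pooling with the given gem_p, and returns a model that exposes embed_dim. A small sketch of the registry in use (the weights directory is a placeholder in the repo's usual /path/to convention, and loading requires the corresponding checkpoint file to exist there):

from models.teachers import teacher_models

teacher = teacher_models["resnet101_gem"]("/path/to/pretrained_weights", gem_p=3.0)
teacher.eval()
print(teacher.embed_dim)  # output dimensionality reported by the checkpoint meta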
/models/dolg/resnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | """ResNe(X)t models."""
9 | 
10 | import os
11 | import torch
12 | import torch.nn as nn
13 | from .net import init_weights
14 | from .config import cfg
15 | 
16 | 
17 | cfg.merge_from_file(os.path.join(os.path.dirname(__file__), "dolg_config.yaml"))
18 | cfg.freeze()
19 | 
20 | # Stage depths for ImageNet models
21 | _IN_STAGE_DS = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}
22 | 
23 | 
24 | def get_trans_fun(name):
25 |     """Retrieves the transformation function by name."""
26 |     trans_funs = {
27 |         "basic_transform": BasicTransform,
28 |         "bottleneck_transform": BottleneckTransform,
29 |     }
30 |     err_str = "Transformation function '{}' not supported"
31 |     assert name in trans_funs.keys(), err_str.format(name)
32 |     return trans_funs[name]
33 | 
34 | 
35 | class ResHead(nn.Module):
36 |     """ResNet head: AvgPool, 1x1."""
37 | 
38 |     def __init__(self, w_in, nc):
39 |         super(ResHead, self).__init__()
40 |         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
41 |         self.fc = nn.Linear(w_in, nc, bias=True)
42 | 
43 |     def forward(self, x):
44 |         x = self.avg_pool(x)
45 |         x = x.view(x.size(0), -1)
46 |         x = self.fc(x)
47 |         return x
48 | 
49 | 
50 | class GlobalHead(nn.Module):
51 |     def __init__(self, w_in, nc):
52 |         super(GlobalHead, self).__init__()
53 |         self.pool = GeneralizedMeanPoolingP()
54 |         self.fc = nn.Linear(w_in, nc, bias=True)
55 | 
56 |     def forward(self, x):
57 |         x = self.pool(x)
58 |         x = x.view(x.size(0), -1)
59 |         x = self.fc(x)
60 |         return x
61 | 
62 | 
63 | class GeneralizedMeanPooling(nn.Module):
64 |     r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes.
65 |     The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)`
66 |         - At p = infinity, one gets Max Pooling
67 |         - At p = 1, one gets Average Pooling
68 |     The output is of size H x W, for any input size.
69 |     The number of output features is equal to the number of input planes.
70 |     Args:
71 |         output_size: the target output size of the image of the form H x W.
72 |                      Can be a tuple (H, W) or a single H for a square image H x H
73 |                      H and W can be either a ``int``, or ``None`` which means the size will
74 |                      be the same as that of the input.
75 |     """
76 | 
77 |     def __init__(self, norm, output_size=1, eps=1e-6):
78 |         super(GeneralizedMeanPooling, self).__init__()
79 |         assert norm > 0
80 |         self.p = float(norm)
81 |         self.output_size = output_size
82 |         self.eps = eps
83 | 
84 |     def forward(self, x):
85 |         x = x.clamp(min=self.eps).pow(self.p)
86 |         return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. 
/ self.p) 87 | 88 | def __repr__(self): 89 | return self.__class__.__name__ + '(' \ 90 | + str(self.p) + ', ' \ 91 | + 'output_size=' + str(self.output_size) + ')' 92 | 93 | 94 | class GeneralizedMeanPoolingP(GeneralizedMeanPooling): 95 | """ Same, but norm is trainable 96 | """ 97 | 98 | def __init__(self, norm=3, output_size=1, eps=1e-6): 99 | super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps) 100 | self.p = nn.Parameter(torch.ones(1) * norm) 101 | 102 | 103 | class BasicTransform(nn.Module): 104 | """Basic transformation: 3x3, BN, ReLU, 3x3, BN.""" 105 | 106 | def __init__(self, w_in, w_out, stride, w_b=None, num_gs=1): 107 | err_str = "Basic transform does not support w_b and num_gs options" 108 | assert w_b is None and num_gs == 1, err_str 109 | super(BasicTransform, self).__init__() 110 | self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False) 111 | self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 112 | self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) 113 | self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False) 114 | self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 115 | self.b_bn.final_bn = True 116 | 117 | def forward(self, x): 118 | for layer in self.children(): 119 | x = layer(x) 120 | return x 121 | 122 | 123 | class BottleneckTransform(nn.Module): 124 | """Bottleneck transformation: 1x1, BN, ReLU, 3x3, BN, ReLU, 1x1, BN.""" 125 | 126 | def __init__(self, w_in, w_out, stride, w_b, num_gs): 127 | super(BottleneckTransform, self).__init__() 128 | # MSRA -> stride=2 is on 1x1; TH/C2 -> stride=2 is on 3x3 129 | (s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride) 130 | self.a = nn.Conv2d(w_in, w_b, 1, stride=s1, padding=0, bias=False) 131 | self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 132 | self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) 133 | self.b = nn.Conv2d(w_b, w_b, 3, stride=s3, padding=1, groups=num_gs, bias=False) 134 | self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 135 | self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) 136 | self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False) 137 | self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 138 | self.c_bn.final_bn = True 139 | 140 | def forward(self, x): 141 | for layer in self.children(): 142 | x = layer(x) 143 | return x 144 | 145 | 146 | class ResBlock(nn.Module): 147 | """Residual block: x + F(x).""" 148 | 149 | def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1): 150 | super(ResBlock, self).__init__() 151 | # Use skip connection with projection if shape changes 152 | self.proj_block = (w_in != w_out) or (stride != 1) 153 | if self.proj_block: 154 | self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False) 155 | self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 156 | self.f = trans_fun(w_in, w_out, stride, w_b, num_gs) 157 | self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) 158 | 159 | def forward(self, x): 160 | if self.proj_block: 161 | x = self.bn(self.proj(x)) + self.f(x) 162 | else: 163 | x = x + self.f(x) 164 | x = self.relu(x) 165 | return x 166 | 167 | 168 | class ResStage(nn.Module): 169 | """Stage of ResNet.""" 170 | 171 | def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1): 172 | super(ResStage, self).__init__() 173 | for i in range(d): 174 | b_stride = stride if i == 0 else 1 175 | b_w_in = w_in if i == 0 else w_out 176 | trans_fun = 
get_trans_fun(cfg.RESNET.TRANS_FUN) 177 | res_block = ResBlock(b_w_in, w_out, b_stride, trans_fun, w_b, num_gs) 178 | self.add_module("b{}".format(i + 1), res_block) 179 | 180 | def forward(self, x): 181 | for block in self.children(): 182 | x = block(x) 183 | return x 184 | 185 | 186 | class ResStemIN(nn.Module): 187 | """ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool.""" 188 | 189 | def __init__(self, w_in, w_out): 190 | super(ResStemIN, self).__init__() 191 | self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False) 192 | self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 193 | self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) 194 | self.pool = nn.MaxPool2d(3, stride=2, padding=1) 195 | 196 | def forward(self, x): 197 | for layer in self.children(): 198 | x = layer(x) 199 | return x 200 | 201 | 202 | class ResNet(nn.Module): 203 | """ResNet model.""" 204 | 205 | def __init__(self): 206 | super(ResNet, self).__init__() 207 | self._construct() 208 | self.apply(init_weights) 209 | 210 | def _construct(self): 211 | g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP 212 | (d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH] 213 | w_b = gw * g 214 | self.stem = ResStemIN(3, 64) 215 | self.s1 = ResStage(64, 256, stride=1, d=d1, w_b=w_b, num_gs=g) 216 | self.s2 = ResStage(256, 512, stride=2, d=d2, w_b=w_b * 2, num_gs=g) 217 | self.s3 = ResStage(512, 1024, stride=2, d=d3, w_b=w_b * 4, num_gs=g) 218 | self.s4 = ResStage(1024, 2048, stride=2, d=d4, w_b=w_b * 8, num_gs=g) 219 | #self.head = ResHead(2048, nc=cfg.MODEL.HEADS.REDUCTION_DIM) 220 | #self.head = GlobalHead(2048, nc=cfg.MODEL.HEADS.REDUCTION_DIM) 221 | 222 | def forward(self, x): 223 | x = self.stem(x) 224 | x1 = self.s1(x) 225 | x2 = self.s2(x1) 226 | x3 = self.s3(x2) 227 | x4 = self.s4(x3) 228 | #x = self.head(x4) 229 | return x3, x4 230 | -------------------------------------------------------------------------------- /svd_pca_learn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import logging 5 | import random 6 | import numpy as np 7 | 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.distributed as dist 11 | import torch.multiprocessing as mp 12 | import torch.nn as nn 13 | import torchvision.transforms as T 14 | from PIL import Image, ImageFilter, ImageOps 15 | 16 | from datasets.svd.core import MetaData 17 | from models.teachers import teacher_models 18 | from models.gem_pooling import GeneralizedMeanPooling 19 | import utils 20 | import metric 21 | 22 | 23 | def parse_args(): 24 | """ 25 | Parse input arguments 26 | """ 27 | parser = argparse.ArgumentParser(description="PCA training and saving.") 28 | parser.add_argument( 29 | "-t", "--teacher", default=None, type=str, metavar="ARCH", 30 | choices=list(teacher_models.keys()), help="name of teacher model" 31 | ) 32 | parser.add_argument( 33 | "--num_workers", default=8, type=int, metavar="N", 34 | help="number of data loader workers" 35 | ) 36 | parser.add_argument( 37 | "--world_size", default=8, type=int, 38 | help="number of workers" 39 | ) 40 | parser.add_argument( 41 | "--dist_url", default="tcp://localhost:1234", type=str, 42 | help='url used to set up distributed training' 43 | ) 44 | parser.add_argument( 45 | '--rank', default=0, type=int, 46 | help='node rank for distributed training' 47 | ) 48 | parser.add_argument( 49 | '--embed_dim', default=512, type=int, 50 | help='embedding dimension' 51 | ) 52 | parser.add_argument( 53 | 
"-b", "--batch_size", default=256, type=int, 54 | help="batch size" 55 | ) 56 | parser.add_argument( 57 | '-p', default=1, type=float, 58 | help='power rate' 59 | ) 60 | parser.add_argument( 61 | "--dump_to", required=True, type=str, 62 | help="dump learned pca transformer to" 63 | ) 64 | parser.add_argument( 65 | "-dm", "--dataset_meta", default="config/svd.yaml", type=str, metavar="FILE", 66 | help="dataset meta file" 67 | ) 68 | return parser.parse_args() 69 | 70 | 71 | def main(): 72 | args = parse_args() 73 | 74 | utils.setup_logger() 75 | logger = logging.getLogger("pca training") 76 | 77 | logger.info(vars(args)) 78 | 79 | ngpus_per_node = torch.cuda.device_count() 80 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 81 | 82 | 83 | def main_worker(gpu, ngpus_per_node, args): 84 | cudnn.benchmark = True 85 | 86 | args.gpu = gpu 87 | 88 | if args.dist_url == "env://" and args.rank == -1: 89 | args.rank = int(os.environ["RANK"]) 90 | args.rank = args.rank * ngpus_per_node + gpu 91 | 92 | utils.setup_logger(log_path=None) 93 | logger = logging.getLogger("worker " + str(args.rank)) 94 | 95 | dist.init_process_group( 96 | backend="nccl", 97 | init_method=args.dist_url, 98 | world_size=args.world_size, 99 | rank=args.rank 100 | ) 101 | 102 | torch.cuda.set_device(args.gpu) 103 | device = torch.device("cuda:"+str(torch.cuda.current_device())) 104 | logger.info(f"Using device cuda:{torch.cuda.current_device()}") 105 | 106 | model = teacher_models[args.teacher](gem_p=args.p) 107 | model.cuda(device) 108 | if utils.has_batchnorms(model): 109 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 110 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device]) 111 | 112 | pca_wtn = learn_pca_whitening(args, model) 113 | 114 | if args.rank == 0: 115 | torch.save(pca_wtn.state_dict(), args.dump_to) 116 | 117 | 118 | def sample_frames(video_ids, frm_cnt, n=60): 119 | random.seed(0) 120 | image_id_list = list() 121 | 122 | for video_id in video_ids: 123 | frame_ids = random.sample(list(range(frm_cnt[video_id])), min(n, frm_cnt[video_id])) 124 | image_ids = [f"{video_id}/{frame_id:04d}" for frame_id in frame_ids] 125 | image_id_list += image_ids 126 | 127 | return image_id_list 128 | 129 | 130 | @torch.no_grad() 131 | def learn_pca_whitening(args, model): 132 | model.eval() 133 | logger = logging.getLogger("pca training") 134 | 135 | dataset_meta_cfg = utils.load_config(args.dataset_meta) 136 | dataset_meta = MetaData(dataset_meta_cfg) 137 | 138 | 139 | # video_ids = dataset_meta.train_ids + dataset_meta.unlabeled_ids 140 | video_ids = dataset_meta.train_ids 141 | rank = utils.get_rank() 142 | if rank == 0: 143 | image_id_list = sample_frames(video_ids, dataset_meta.frm_cnt) 144 | objects = [image_id_list] 145 | else: 146 | objects = [None] 147 | dist.broadcast_object_list( 148 | objects, src=0, 149 | device=torch.device("cuda:"+str(torch.cuda.current_device())) 150 | ) 151 | image_id_list = objects[0] 152 | logger.info(f"totally {len(image_id_list)} images.") 153 | 154 | transform = T.Compose([ 155 | T.RandomResizedCrop(224, scale=(0.4, 1.)), 156 | T.RandomApply([ 157 | T.ColorJitter(0.4, 0.4, 0.2, 0.1) 158 | ], p=0.8), 159 | T.RandomGrayscale(p=0.2), 160 | T.RandomApply([GaussianBlur([.1, 2.])], p=0.5), 161 | T.RandomApply([Solarize()], p=0.2), 162 | T.RandomHorizontalFlip(p=0.5), 163 | T.ToTensor(), 164 | T.Normalize( 165 | mean=[0.485, 0.456, 0.406], 166 | std=[0.229, 0.224, 0.225] 167 | ) 168 | ]) 169 | train_dataset = SVD( 170 | 
dataset_meta.frm_root_path, image_id_list, transform 171 | ) 172 | train_sampler = torch.utils.data.distributed.DistributedSampler( 173 | train_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=False, drop_last=True 174 | ) 175 | train_loader = torch.utils.data.DataLoader( 176 | train_dataset, batch_size=args.batch_size, sampler=train_sampler, 177 | num_workers=args.num_workers, pin_memory=True 178 | ) 179 | 180 | rank = utils.get_rank() 181 | if rank == 0: 182 | pca_wtn = utils.PCA(dim1=model.module.embed_dim, dim2=args.embed_dim) 183 | 184 | logger.info("extracting pca training features") 185 | features = extract_features(model, train_loader) 186 | 187 | if rank == 0: 188 | pca_wtn.train_pca(features) 189 | return pca_wtn 190 | else: 191 | return None 192 | 193 | 194 | @torch.no_grad() 195 | def extract_features(model, data_loader): 196 | rank = utils.get_rank() 197 | if rank == 0: 198 | logger = logging.getLogger("extract_feature") 199 | else: 200 | logger = None 201 | metric_logger = metric.MetricLogger(logger, delimiter=" ") 202 | log_freq = len(data_loader) // 16 if len(data_loader) >= 16 else len(data_loader) 203 | if rank == 0: 204 | features = torch.zeros(len(data_loader.dataset), model.module.embed_dim) 205 | for batch_idx, (samples, index) in metric_logger.log_every(data_loader, log_freq): 206 | samples = samples.cuda(non_blocking=True) 207 | index = index.cuda(non_blocking=True) 208 | 209 | feats = model(samples) 210 | 211 | feats = nn.functional.normalize(feats, p=2, dim=-1) 212 | 213 | index_all = gather_to_main(index) 214 | feats_all = gather_to_main(feats) 215 | 216 | if rank == 0: 217 | features.index_copy_(0, index_all.cpu(), feats_all.cpu()) 218 | 219 | if rank == 0: 220 | return features 221 | else: 222 | return None 223 | 224 | 225 | def gather_to_main(x): 226 | world_size = utils.get_world_size() 227 | rank = utils.get_rank() 228 | 229 | if rank == 0: 230 | gather_x = [torch.zeros_like(x) for _ in range(world_size)] 231 | else: 232 | gather_x = None 233 | dist.gather(x, gather_x if rank == 0 else None, dst=0) 234 | 235 | if rank == 0: 236 | gather_x = torch.cat(gather_x, dim=0) 237 | 238 | return gather_x 239 | 240 | 241 | class SVD(torch.utils.data.dataset.Dataset): 242 | def __init__(self, root_path, img_id_list, transform): 243 | super().__init__() 244 | self.root_path = root_path 245 | self.img_id_list = img_id_list 246 | self.t = transform 247 | 248 | def __getitem__(self, i): 249 | img_id = self.img_id_list[i] 250 | img_path = os.path.join( 251 | self.root_path, img_id+".jpg" 252 | ) 253 | 254 | img = Image.open(img_path) 255 | img = img.convert("RGB") 256 | img = self.t(img) 257 | 258 | return img, i 259 | 260 | def __len__(self): 261 | return len(self.img_id_list) 262 | 263 | 264 | class GaussianBlur(object): 265 | def __init__(self, sigma=[.1, 2.]): 266 | self.sigma = sigma 267 | 268 | def __call__(self, x): 269 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 270 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 271 | return x 272 | 273 | 274 | class Solarize(object): 275 | """Solarize augmentation from BYOL: https://arxiv.org/abs/2006.07733""" 276 | 277 | def __call__(self, x): 278 | return ImageOps.solarize(x) 279 | 280 | 281 | if __name__ == "__main__": 282 | main() 283 | -------------------------------------------------------------------------------- /datasets/svd/core.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import logging 5 | 6 | 
from PIL import Image
7 | from collections import defaultdict
8 | from torch.utils.data.dataset import Dataset
9 | 
10 | 
11 | __all__ = [
12 |     "Frame",
13 |     "Video",
14 |     "MetaData"
15 | ]
16 | 
17 | 
18 | class Frame(Dataset):
19 |     """
20 |     Args:
21 |         root_path (str): root path of the frame directory
22 |         frm_cnt (dict): dict that maps each video id to its frame count
23 |         transform (callable): preprocessing applied to each frame; expected to
24 |             return a tuple of views (e.g. TwoCropsTransform)
25 |     """
26 |     def __init__(self, root_path, frm_cnt, transform):
27 |         super().__init__()
28 |         self._root_path = root_path
29 |         self._frm_cnt = frm_cnt
30 | 
31 |         self.t = transform
32 | 
33 |     def __getitem__(self, k):
34 |         """
35 |         Args:
36 |             k (tuple): video id, frame id.
37 | 
38 |         Returns:
39 |             tuple: the unpacked views of the frame (Tensor), frame path (str)
40 |         """
41 |         video_id, frame_id = k
42 |         num_of_frame = self._frm_cnt[video_id]
43 |         assert frame_id < num_of_frame
44 | 
45 |         frame_file = f"{frame_id:04d}.jpg"
46 |         frame_path = os.path.join(self._root_path, video_id, frame_file)
47 | 
48 |         frame = Image.open(frame_path)
49 |         frame = frame.convert("RGB")
50 |         frame = self.t(frame)
51 | 
52 |         return *frame, f"{video_id}/{frame_file}"
53 | 
54 | 
55 | class Video(Dataset):
56 |     """
57 |     Args:
58 |         root_path (str): root path of the frame directory
59 |         frm_cnt (dict): dict that maps each video id to its frame count
60 |         max_frames (int): maximum number of frames sampled from a video
61 |         stride (int): frame sample stride
62 |     """
63 |     def __init__(self, root_path, frm_cnt, transform, max_frames=60, stride=1):
64 |         super().__init__()
65 |         self._root_path = root_path
66 |         self._frm_cnt = frm_cnt
67 |         self._max_frames = max_frames
68 |         self._stride = stride
69 | 
70 |         self.t = transform
71 | 
72 |     def __getitem__(self, k):
73 |         """
74 |         Args:
75 |             k (str): video id.
76 | 
77 |         Returns:
78 |             tuple: preprocessed frames (Tensor), video id (str), number of frames (int)
79 |         """
80 |         video_id = k
81 |         num_of_frame = self._frm_cnt[video_id]
82 | 
83 |         frame_ids = [f"{i:04d}.jpg" for i in range(0, min(num_of_frame, self._max_frames), self._stride)]
84 | 
85 |         frame_paths = []
86 |         for frame_id in frame_ids:
87 |             frame_path = os.path.join(self._root_path, video_id, frame_id)
88 |             frame_paths.append(frame_path)
89 | 
90 |         frames = []
91 |         for frame_path in frame_paths:
92 |             frame = Image.open(frame_path)
93 |             frame = frame.convert("RGB")
94 |             frame = self.t(frame)
95 |             frames.append(frame)
96 |         frames = torch.stack(frames, dim=0)
97 | 
98 |         num_frame = frames.size(0)
99 | 
100 |         return frames, video_id, num_frame
101 | 
102 |     @property
103 |     def max_frames(self):
104 |         return self._max_frames
105 | 
106 | 
107 | class MetaData(object):
108 |     def __init__(self, cfg):
109 |         self._cfg = cfg
110 |         self._query_ids = None
111 |         self._labeled_ids = None
112 |         self._unlabeled_ids = None
113 |         self._train_groundtruth = None
114 |         self._test_groundtruth = None
115 |         self._test_query_ids = None
116 |         self._test_labeled_ids = None
117 |         self._train_groups = None
118 |         self._train_pairs = None
119 |         self._train_ids = None
120 |         self._frm_cnt = None
121 | 
122 |     def _load_ids(self, id_file):
123 |         ids = list()
124 |         assert os.path.exists(id_file), f"file {id_file} does not exist!"
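        # id files list one "<video_id>.mp4" entry per line; the suffix is
        # stripped below so the ids can be used directly as frame directory names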
125 | with open(id_file, 'r') as f: 126 | for l in f: 127 | l = l.strip().replace('.mp4', '') 128 | ids.append(l) 129 | return ids 130 | 131 | def _load_groundtruth(self, groundtruth_file): 132 | queries = list() 133 | labeled = list() 134 | gdtruth = defaultdict(dict) 135 | with open(groundtruth_file, 'r') as f: 136 | for l in f: 137 | l = l.strip().split(' ') 138 | qid = l[0].replace('.mp4', '') 139 | cid = l[1].replace('.mp4', '') 140 | gt = int(l[2]) 141 | gdtruth[qid][cid] = gt 142 | 143 | if qid not in queries: 144 | queries.append(qid) 145 | if cid not in labeled: 146 | labeled.append(cid) 147 | return queries, labeled, gdtruth 148 | 149 | def _load_frame_cnts(self, frame_count_file): 150 | frame_cnts = dict() 151 | with open(frame_count_file, "r") as f: 152 | for l in f: 153 | l = l.strip().split(" ") 154 | frame_cnts[l[0]] = int(l[1]) 155 | 156 | return frame_cnts 157 | 158 | @property 159 | def query_ids(self): 160 | if self._query_ids is None: 161 | self._query_ids = self._load_ids(self._cfg['query_id']) 162 | return self._query_ids 163 | 164 | @property 165 | def labeled_ids(self): 166 | if self._labeled_ids is None: 167 | self._labeled_ids = self._load_ids(self._cfg['labeled_id']) 168 | return self._labeled_ids 169 | 170 | @property 171 | def unlabeled_ids(self): 172 | if self._unlabeled_ids is None: 173 | self._unlabeled_ids = self._load_ids(self._cfg['unlabeled_id']) 174 | return self._unlabeled_ids 175 | 176 | @property 177 | def all_video_ids(self): 178 | if self._query_ids is None: 179 | self._query_ids = self._load_ids(self._cfg['query_id']) 180 | if self._labeled_ids is None: 181 | self._labeled_ids = self._load_ids(self._cfg['labeled_id']) 182 | if self._unlabeled_ids is None: 183 | self._unlabeled_ids = self._load_ids(self._cfg['unlabeled_id']) 184 | return (self._query_ids + self._labeled_ids + self._unlabeled_ids) 185 | 186 | @property 187 | def test_groundtruth(self): 188 | if self._test_groundtruth is None: 189 | self._test_query_ids, self._test_labeled_ids, self._test_groundtruth = \ 190 | self._load_groundtruth(self._cfg['test_groundtruth']) 191 | return self._test_groundtruth 192 | 193 | @property 194 | def test_query_ids(self): 195 | if self._test_query_ids is None: 196 | self._test_query_ids, self._test_labeled_ids, self._test_groundtruth = \ 197 | self._load_groundtruth(self._cfg['test_groundtruth']) 198 | return self._test_query_ids 199 | 200 | @property 201 | def test_labeled_ids(self): 202 | if self._test_labeled_ids is None: 203 | self._test_query_ids, self._test_labeled_ids, self._test_groundtruth = \ 204 | self._load_groundtruth(self._cfg['test_groundtruth']) 205 | return self._test_labeled_ids 206 | 207 | @property 208 | def train_groundtruth(self): 209 | if self._train_groundtruth is None: 210 | _, _, self._train_groundtruth = self._load_groundtruth(self._cfg['train_groundtruth']) 211 | return self._train_groundtruth 212 | 213 | @property 214 | def train_groups(self): 215 | """ 216 | train_groups should be deprecated in the future: pair matching has no transitivity. 
217 | """ 218 | if self._train_groups is not None: 219 | return self._train_groups 220 | 221 | groundtruth = self.train_groundtruth 222 | self._train_groups = list() 223 | for qid, cdict in groundtruth.items(): 224 | group = None 225 | for g in self._train_groups: 226 | if qid in g: 227 | group = g 228 | break 229 | for cid, isp in cdict.items(): 230 | if isp: 231 | for g in self._train_groups: 232 | if cid in g: 233 | group = g 234 | break 235 | if group is None: 236 | group = set() 237 | self._train_groups.append(group) 238 | group.add(qid) 239 | for cid, isp in cdict.items(): 240 | if isp: 241 | group.add(cid) 242 | 243 | return self._train_groups 244 | 245 | @property 246 | def train_pairs(self): 247 | if self._train_pairs is not None: 248 | return self._train_pairs 249 | 250 | groundtruth = self.train_groundtruth 251 | self._train_pairs = list() 252 | for qid, cdict in groundtruth.items(): 253 | for cid, isp in cdict.items(): 254 | if isp: 255 | self._train_pairs.append((qid, cid)) 256 | 257 | return self._train_pairs 258 | 259 | @property 260 | def train_ids(self): 261 | if self._train_ids is not None: 262 | return self._train_ids 263 | 264 | self._train_ids = list() 265 | for qid, cdict in self.train_groundtruth.items(): 266 | if qid not in self._train_ids: 267 | self._train_ids.append(qid) 268 | for cid, _ in cdict.items(): 269 | if cid not in self._train_ids: 270 | self._train_ids.append(cid) 271 | 272 | return self._train_ids 273 | 274 | @property 275 | def frm_cnt(self): 276 | if self._frm_cnt is None: 277 | self._frm_cnt = self._load_frame_cnts(self._cfg["frame_count_file"]) 278 | return self._frm_cnt 279 | 280 | @property 281 | def frm_root_path(self): 282 | return self._cfg["frame_root_path"] 283 | -------------------------------------------------------------------------------- /models/delg/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
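# A minimal usage sketch (illustrative, not from the original sources) for the
# MetaData class defined above in datasets/svd/core.py: every property is
# loaded lazily from the paths in the meta yaml on first access and then
# cached, so the yaml must provide an entry for each key that is actually
# touched (query_id, labeled_id, unlabeled_id, train_groundtruth,
# test_groundtruth, frame_count_file, frame_root_path).
#
#     import utils
#     from datasets.svd.core import MetaData
#
#     meta = MetaData(utils.load_config("config/svd.yaml"))
#     print(f"{len(meta.train_ids)} training videos")
#     print(f"{meta.frm_cnt[meta.train_ids[0]]} frames in the first video")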
5 | 6 | """ResNe(X)t models.""" 7 | 8 | import os 9 | import math 10 | import torch 11 | import torch.nn as nn 12 | from .config import cfg 13 | 14 | 15 | # Stage depths for ImageNet models 16 | _IN_STAGE_DS = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)} 17 | 18 | 19 | def get_trans_fun(name): 20 | """Retrieves the transformation function by name.""" 21 | trans_funs = { 22 | "basic_transform": BasicTransform, 23 | "bottleneck_transform": BottleneckTransform, 24 | } 25 | err_str = "Transformation function '{}' not supported" 26 | assert name in trans_funs.keys(), err_str.format(name) 27 | return trans_funs[name] 28 | 29 | 30 | class ResHead(nn.Module): 31 | """ResNet head: AvgPool, 1x1.""" 32 | 33 | def __init__(self, w_in, nc): 34 | super(ResHead, self).__init__() 35 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 36 | self.fc = nn.Linear(w_in, nc, bias=True) 37 | 38 | def forward(self, x): 39 | x = self.avg_pool(x) 40 | x = x.view(x.size(0), -1) 41 | x = self.fc(x) 42 | return x 43 | 44 | 45 | class GlobalHead(nn.Module): 46 | def __init__(self, w_in, nc): 47 | super(GlobalHead, self).__init__() 48 | self.pool = GeneralizedMeanPoolingP() 49 | self.fc = nn.Linear(w_in, nc, bias=True) 50 | 51 | def forward(self, x): 52 | x = self.pool(x) 53 | x = x.view(x.size(0), -1) 54 | x = self.fc(x) 55 | return x 56 | 57 | 58 | class GeneralizedMeanPooling(nn.Module): 59 | r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes. 60 | The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)` 61 | - At p = infinity, one gets Max Pooling 62 | - At p = 1, one gets Average Pooling 63 | The output is of size H x W, for any input size. 64 | The number of output features is equal to the number of input planes. 65 | Args: 66 | output_size: the target output size of the image of the form H x W. 67 | Can be a tuple (H, W) or a single H for a square image H x H 68 | H and W can be either a ``int``, or ``None`` which means the size will 69 | be the same as that of the input. 70 | """ 71 | 72 | def __init__(self, norm, output_size=1, eps=1e-6): 73 | super(GeneralizedMeanPooling, self).__init__() 74 | assert norm > 0 75 | self.p = float(norm) 76 | self.output_size = output_size 77 | self.eps = eps 78 | 79 | def forward(self, x): 80 | x = x.clamp(min=self.eps).pow(self.p) 81 | return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. 
/ self.p) 82 | 83 | def __repr__(self): 84 | return self.__class__.__name__ + '(' \ 85 | + str(self.p) + ', ' \ 86 | + 'output_size=' + str(self.output_size) + ')' 87 | 88 | 89 | class GeneralizedMeanPoolingP(GeneralizedMeanPooling): 90 | """ Same, but norm is trainable 91 | """ 92 | 93 | def __init__(self, norm=3, output_size=1, eps=1e-6): 94 | super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps) 95 | self.p = nn.Parameter(torch.ones(1) * norm) 96 | 97 | 98 | class BasicTransform(nn.Module): 99 | """Basic transformation: 3x3, BN, ReLU, 3x3, BN.""" 100 | 101 | def __init__(self, w_in, w_out, stride, w_b=None, num_gs=1): 102 | err_str = "Basic transform does not support w_b and num_gs options" 103 | assert w_b is None and num_gs == 1, err_str 104 | super(BasicTransform, self).__init__() 105 | self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False) 106 | self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 107 | self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) 108 | self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False) 109 | self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 110 | self.b_bn.final_bn = True 111 | 112 | def forward(self, x): 113 | for layer in self.children(): 114 | x = layer(x) 115 | return x 116 | 117 | 118 | class BottleneckTransform(nn.Module): 119 | """Bottleneck transformation: 1x1, BN, ReLU, 3x3, BN, ReLU, 1x1, BN.""" 120 | 121 | def __init__(self, w_in, w_out, stride, w_b, num_gs): 122 | super(BottleneckTransform, self).__init__() 123 | # MSRA -> stride=2 is on 1x1; TH/C2 -> stride=2 is on 3x3 124 | (s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride) 125 | self.a = nn.Conv2d(w_in, w_b, 1, stride=s1, padding=0, bias=False) 126 | self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 127 | self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) 128 | self.b = nn.Conv2d(w_b, w_b, 3, stride=s3, padding=1, groups=num_gs, bias=False) 129 | self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 130 | self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) 131 | self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False) 132 | self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 133 | self.c_bn.final_bn = True 134 | 135 | def forward(self, x): 136 | for layer in self.children(): 137 | x = layer(x) 138 | return x 139 | 140 | 141 | class ResBlock(nn.Module): 142 | """Residual block: x + F(x).""" 143 | 144 | def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1): 145 | super(ResBlock, self).__init__() 146 | # Use skip connection with projection if shape changes 147 | self.proj_block = (w_in != w_out) or (stride != 1) 148 | if self.proj_block: 149 | self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False) 150 | self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 151 | self.f = trans_fun(w_in, w_out, stride, w_b, num_gs) 152 | self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) 153 | 154 | def forward(self, x): 155 | if self.proj_block: 156 | x = self.bn(self.proj(x)) + self.f(x) 157 | else: 158 | x = x + self.f(x) 159 | x = self.relu(x) 160 | return x 161 | 162 | 163 | class ResStage(nn.Module): 164 | """Stage of ResNet.""" 165 | 166 | def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1): 167 | super(ResStage, self).__init__() 168 | for i in range(d): 169 | b_stride = stride if i == 0 else 1 170 | b_w_in = w_in if i == 0 else w_out 171 | trans_fun = get_trans_fun(cfg.RESNET.TRANS_FUN) 
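            # only the first block of a stage applies the stage stride and the
            # w_in -> w_out width change; the remaining d - 1 blocks keep
            # stride 1 and run at w_out channels throughout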
172 | res_block = ResBlock(b_w_in, w_out, b_stride, trans_fun, w_b, num_gs) 173 | self.add_module("b{}".format(i + 1), res_block) 174 | 175 | def forward(self, x): 176 | for block in self.children(): 177 | x = block(x) 178 | return x 179 | 180 | @staticmethod 181 | def complexity(cx, w_in, w_out, stride, d, w_b=None, num_gs=1): 182 | for i in range(d): 183 | b_stride = stride if i == 0 else 1 184 | b_w_in = w_in if i == 0 else w_out 185 | trans_f = get_trans_fun(cfg.RESNET.TRANS_FUN) 186 | cx = ResBlock.complexity(cx, b_w_in, w_out, b_stride, trans_f, w_b, num_gs) 187 | return cx 188 | 189 | 190 | class ResStemIN(nn.Module): 191 | """ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool.""" 192 | 193 | def __init__(self, w_in, w_out): 194 | super(ResStemIN, self).__init__() 195 | self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False) 196 | self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 197 | self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) 198 | self.pool = nn.MaxPool2d(3, stride=2, padding=1) 199 | 200 | def forward(self, x): 201 | for layer in self.children(): 202 | x = layer(x) 203 | return x 204 | 205 | 206 | def init_weights(m): 207 | """Performs ResNet-style weight initialization.""" 208 | if isinstance(m, nn.Conv2d): 209 | # Note that there is no bias due to BN 210 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 211 | m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out)) 212 | elif isinstance(m, nn.BatchNorm2d): 213 | zero_init_gamma = cfg.BN.ZERO_INIT_FINAL_GAMMA 214 | zero_init_gamma = hasattr(m, "final_bn") and m.final_bn and zero_init_gamma 215 | m.weight.data.fill_(0.0 if zero_init_gamma else 1.0) 216 | m.bias.data.zero_() 217 | elif isinstance(m, nn.Linear): 218 | m.weight.data.normal_(mean=0.0, std=0.01) 219 | m.bias.data.zero_() 220 | 221 | 222 | class ResNet(nn.Module): 223 | """ResNet model.""" 224 | 225 | def __init__(self): 226 | super(ResNet, self).__init__() 227 | self._construct() 228 | self.apply(init_weights) 229 | 230 | def _construct(self): 231 | g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP 232 | (d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH] 233 | w_b = gw * g 234 | self.stem = ResStemIN(3, 64) 235 | self.s1 = ResStage(64, 256, stride=1, d=d1, w_b=w_b, num_gs=g) 236 | self.s2 = ResStage(256, 512, stride=2, d=d2, w_b=w_b * 2, num_gs=g) 237 | self.s3 = ResStage(512, 1024, stride=2, d=d3, w_b=w_b * 4, num_gs=g) 238 | self.s4 = ResStage(1024, 2048, stride=2, d=d4, w_b=w_b * 8, num_gs=g) 239 | #self.head = ResHead(2048, nc=cfg.MODEL.HEADS.REDUCTION_DIM) 240 | self.head = GlobalHead(2048, nc=cfg.MODEL.HEADS.REDUCTION_DIM) 241 | 242 | def forward(self, x): 243 | x = self.stem(x) 244 | x1 = self.s1(x) 245 | x2 = self.s2(x1) 246 | x3 = self.s3(x2) 247 | x4 = self.s4(x3) 248 | x = self.head(x4) 249 | return x, x3 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 |    You may obtain a copy of the License at
194 | 
195 |        http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 |    Unless required by applicable law or agreed to in writing, software
198 |    distributed under the License is distributed on an "AS IS" BASIS,
199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 |    See the License for the specific language governing permissions and
201 |    limitations under the License.
202 | 
--------------------------------------------------------------------------------
/datasets/svd/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import logging
5 | import torch.nn as nn
6 | import torch.distributed as dist
7 | from collections import defaultdict, OrderedDict
8 | from metric import APScorer, SmoothedValue, MetricLogger
9 | from utils import get_rank, get_world_size
10 | 
11 | 
12 | __all__ = [
13 |     "DistEvaluator",
14 |     "MultiTeacherEvaluator"
15 | ]
16 | 
17 | 
18 | class DistEvaluator(object):
19 |     def __init__(self, device, test_groundtruth, topk, max_frames):
20 |         self.rank = get_rank()
21 |         self.world_size = get_world_size()
22 |         self.device = device
23 | 
24 |         self.gdtruth = test_groundtruth
25 |         self.ranks = defaultdict(list)
26 | 
27 |         self.max_frames = max_frames
28 | 
29 |         if self.rank == 0:
30 |             self.topk = topk
31 |             self.scorer = APScorer()
32 |             self.res = OrderedDict()
33 |             for k in self.topk:
34 |                 self.res["top-"+str(k)] = SmoothedValue(window_size=None)
35 |             self.res["top-inf"] = SmoothedValue(window_size=None)
36 |             self.logger = logging.getLogger("svd.dist_eval."+str(self.rank))
37 | 
38 |     @torch.no_grad()
39 |     def __call__(self,
40 |         model, query_loader, labeled_loader, unlabeled_loader,
41 |         full_eval, sim_fn, dump_to=None
42 |     ):
43 | 
44 |         if self.rank == 0:
45 |             self.logger.info("Start evaluation.")
46 |             self.logger.info("Processing Query Samples...")
47 | 
48 |         dist.barrier()
49 | 
50 |         metric_logger = MetricLogger(delimiter="  ", logger=self.logger if self.rank == 0 else None)
51 | 
52 |         qfeats = []
53 |         qlens = []
54 |         qids = []
55 | 
56 |         for idx, batch in metric_logger.log_every(query_loader, log_freq=16):
57 |             feats, lens, ids = self.forward_query(model, batch, sim_fn)
58 | 
59 |             qfeats.append(feats)
60 |             qlens.append(lens)
61 |             qids += ids
62 | 
63 |         qfeats = torch.cat(qfeats, dim=0)
64 |         qlens = torch.cat(qlens, dim=0)
65 |         self.gather_query(qfeats, qlens, qids)
66 | 
67 |         if self.rank == 0:
68 |             self.logger.info(f"Query gathered to all workers, {len(self.qfeats)} queries in total.")
69 | 
70 |             self.logger.info("Processing Labeled Samples.")
71 | 
72 |         for idx, batch in metric_logger.log_every(labeled_loader, log_freq=16):
73 |             feats, lens, ids = self.forward_labeled(model, batch, sim_fn)
74 |             sims = self.cal_sim(feats, lens, sim_fn)
75 | 
76 |             self.handle_labeled(sims, ids)
77 | 
78 |         if full_eval:
79 | 
80 |             if self.rank == 0:
81 |                 self.logger.info("Processing Features of Unlabeled Samples.")
82 | 
83 |             for idx, batch in metric_logger.log_every(unlabeled_loader, log_freq=1024):
84 |                 feats, lens, ids = self.forward_unlabeled(model, batch, sim_fn)
85 |                 sims = self.cal_sim(feats, lens, sim_fn)
86 | 
87 |                 self.handle_unlabeled(sims, ids)
88 | 
89 |         self.sync_ranks()
90 | 
91 |         # eval result
92 |         if self.rank == 0:
93 |             evalr = self.score()
94 |             self.logger.info(" | ".join([f"{k} mAP: {v:.4f}" for k, v in evalr.items()]))
95 | 
96 |             if dump_to is not None:
97 |                 # dump evaluation result to file
98 |                 self.dump(dump_to)
99 | 
100 |             return evalr
101 |         
else: 102 | return None 103 | 104 | def forward_batch(self, model, batch, sim_fn): 105 | frames, n_frames, ids = batch 106 | frames = frames.to(self.device) 107 | n_frames = n_frames.to(self.device) 108 | 109 | frames = model(frames) 110 | 111 | if sim_fn == "fme" or sim_fn == "fmx": 112 | _frames= [] 113 | i = 0 114 | for nf in n_frames: 115 | if sim_fn == "fme": 116 | _frames.append(frames[i:i+nf].mean(dim=0)) 117 | elif sim_fn == "fmx": 118 | _frames.append(frames[i:i+nf].max(dim=0)[0]) 119 | i += nf 120 | frames = torch.stack(_frames, dim=0) 121 | elif sim_fn == "sme" or sim_fn == "smx" or sim_fn == "cf": 122 | _frames = torch.zeros( 123 | len(n_frames), self.max_frames, frames.size(1), 124 | dtype=frames.dtype, device=frames.device 125 | ) 126 | s = 0 127 | for i, nf in enumerate(n_frames): 128 | _frames[i][:nf] = frames[s:s+nf] 129 | s += nf 130 | frames = _frames 131 | else: 132 | raise NotImplementedError(f"{sim_fn} not implemented.") 133 | 134 | frames = nn.functional.normalize(frames, p=2, dim=-1) 135 | 136 | return frames, n_frames, ids 137 | 138 | def forward_query(self, model, batch, sim_fn): 139 | return self.forward_batch(model, batch, sim_fn) 140 | 141 | def forward_labeled(self, model, batch, sim_fn): 142 | return self.forward_batch(model, batch, sim_fn) 143 | 144 | def forward_unlabeled(self, model, batch, sim_fn): 145 | return self.forward_batch(model, batch, sim_fn) 146 | 147 | def cal_sim(self, feats, lens, sim_fn): 148 | if sim_fn == "fme" or sim_fn == "fmx": 149 | return self.qfeats.mm(feats.t()) 150 | 151 | # Q x C x F x F 152 | sims = self.qfeats.unsqueeze(1).matmul(feats.transpose(1, 2)) 153 | mask = torch.ones_like(sims, dtype=torch.bool) 154 | for i in range(mask.size(0)): 155 | for j in range(mask.size(1)): 156 | mask[i, j, self.qlens[i]:, :] = False 157 | mask[i, j, :, lens[j]:] = False 158 | 159 | if sim_fn == "cf": 160 | sims = self.cal_sim_chamfer(sims, mask) 161 | elif sim_fn == "sme": 162 | sims = self.cal_sim_mean(sims, mask) 163 | elif sim_fn == "smx": 164 | sims = self.cal_sim_max(sims, mask) 165 | else: 166 | raise NotImplementedError(f"{sim_fn} is not implemented.") 167 | 168 | return sims 169 | 170 | def cal_sim_chamfer(self, sim_mat, mask): 171 | sim_mat.masked_fill_(mask.logical_not(), float("-inf")) 172 | sim_mat = sim_mat.max(dim=-1)[0] 173 | is_inf = sim_mat.isinf() 174 | sim_mat = sim_mat.masked_fill_(is_inf, 0) 175 | sim_mat = sim_mat.sum(dim=-1).div(is_inf.logical_not().sum(dim=-1)) 176 | 177 | return sim_mat 178 | 179 | def cal_sim_max(self, sim_mat, mask): 180 | sim_mat.masked_fill_(mask.logical_not(), float("-inf")) 181 | sim_mat = sim_mat.flatten(2).max(dim=-1)[0] 182 | return sim_mat 183 | 184 | def cal_sim_mean(self, sim_mat, mask): 185 | sim_mat.masked_fill_(mask.logical_not(), 0) 186 | sim_mat = sim_mat.sum(dim=(2,3)).div(mask.sum(dim=(2,3))) 187 | return sim_mat 188 | 189 | def gather_query(self, feats, lens, ids): 190 | num_queries_gather = [torch.tensor(0, dtype=torch.long, device=self.device) for _ in range(self.world_size)] 191 | dist.all_gather(num_queries_gather, torch.tensor(feats.size(0), dtype=torch.long, device=self.device), async_op=False) 192 | 193 | self.qfeats = [ 194 | torch.zeros([num_queries_gather[i],*feats.size()[1:]], dtype=feats.dtype, device=self.device) 195 | for i in range(self.world_size) 196 | ] 197 | self.qfeats[self.rank] = feats 198 | for i in range(self.world_size): 199 | dist.broadcast(self.qfeats[i], src=i, async_op=False) 200 | self.qfeats = torch.cat(self.qfeats, dim=0) 201 | 202 | self.qlens = [ 
203 | torch.zeros(num_queries_gather[i], dtype=lens.dtype, device=self.device) 204 | for i in range(self.world_size) 205 | ] 206 | self.qlens[self.rank] = lens 207 | for i in range(self.world_size): 208 | dist.broadcast(self.qlens[i], src=i, async_op=False) 209 | self.qlens = torch.cat(self.qlens, dim=0) 210 | 211 | ids_gather = [None for _ in range(self.world_size)] 212 | dist.all_gather_object(ids_gather, ids) 213 | self.qids = sum(ids_gather, []) 214 | 215 | def sync_ranks(self): 216 | for qid in self.qids: 217 | ranks_gather = [None for _ in range(self.world_size)] 218 | if self.rank == 0: 219 | dist.gather_object(self.ranks[qid], object_gather_list=ranks_gather, dst=0) 220 | self.ranks[qid] = sum(ranks_gather, []) 221 | else: 222 | dist.gather_object(self.ranks[qid], object_gather_list=None, dst=0) 223 | dist.barrier() 224 | 225 | def handle_labeled(self, sims, ids): 226 | sims = sims.cpu().tolist() 227 | for i, qid in enumerate(self.qids): 228 | sim = [] 229 | cids = [] 230 | for j, cid in enumerate(ids): 231 | if cid in self.gdtruth[qid]: 232 | sim.append(sims[i][j]) 233 | cids.append(cid) 234 | self.ranks[qid] += list(zip(sim, cids, [self.gdtruth[qid][cid] for cid in cids])) 235 | 236 | def handle_unlabeled(self, sims, ids): 237 | sims = sims.cpu().tolist() 238 | for i, qid in enumerate(self.qids): 239 | self.ranks[qid] += list(zip(sims[i], ids, [0]*len(ids))) 240 | 241 | def score(self): 242 | self.aps = defaultdict(OrderedDict) 243 | for qid in self.qids: 244 | self.ranks[qid].sort(key=lambda x: x[0], reverse=True) 245 | for k in self.topk: 246 | sorted_labels = [] 247 | for i in self.ranks[qid][:k]: 248 | sorted_labels.append(i[2]) 249 | ap = self.scorer.score(sorted_labels) 250 | self.aps[qid]["top-"+str(k)] = ap 251 | self.res["top-"+str(k)].update(ap) 252 | 253 | sorted_labels = [] 254 | for i in self.ranks[qid]: 255 | sorted_labels.append(i[2]) 256 | ap = self.scorer.score(sorted_labels) 257 | self.aps[qid]["top-inf"] = ap 258 | self.res["top-inf"].update(ap) 259 | 260 | return {k: v.avg for k, v in self.res.items()} 261 | 262 | def dump(self, dump_to, topk=100): 263 | record = list() 264 | for qid in self.qids: 265 | item = {'qid': qid, 'ap': self.aps[qid], 'ranking': [], 'positive': []} 266 | for score, vid, label in self.ranks[qid][:topk]: 267 | d = { 268 | 'score': score, 269 | 'id': vid, 270 | 'label': label 271 | } 272 | item['ranking'].append(d) 273 | 274 | pos = [cid for cid, ispos in self.gdtruth[qid].items() if ispos] 275 | for i, (score, vid, _) in enumerate(self.ranks[qid]): 276 | if vid in pos: 277 | item['positive'].append( 278 | { 279 | "id": vid, 280 | "rank": i, 281 | "score": score 282 | } 283 | ) 284 | record.append(item) 285 | record.sort(key=lambda x: x['ap']['top-inf']) 286 | os.makedirs(dump_to, exist_ok=True) 287 | with open(os.path.join(dump_to, "ret_res.json"), 'w') as f: 288 | f.write(json.dumps(record, sort_keys=True, indent=4)) 289 | 290 | 291 | class MultiTeacherEvaluator(DistEvaluator): 292 | def forward_batch(self, model, batch, *args, **kwargs): 293 | frames, lens, ids = batch 294 | frames = frames.to(self.device) 295 | lens = lens.to(self.device) 296 | 297 | frames = model(frames, *args, **kwargs) 298 | for i in range(len(frames)): 299 | _frames= [] 300 | s = 0 301 | for l in lens: 302 | _frames.append(frames[i][s:s+l].mean(dim=0)) 303 | # _frames.append(frames[i][s:s+l].max(dim=0)[0]) 304 | s += l 305 | frames[i] = torch.stack(_frames, dim=0) 306 | frames[i] = nn.functional.normalize(frames[i], p=2, dim=-1) 307 | 308 | frames = 
torch.cat(frames, dim=-1) 309 | 310 | return frames, lens, ids 311 | 312 | 313 | class FinetuneEvaluator(DistEvaluator): 314 | def forward_batch(self, model, batch, sim_fn): 315 | x, n_frames, ids = batch 316 | x = x.to(self.device) 317 | n_frames = n_frames.to(self.device) 318 | 319 | x = model(x, n_frames) 320 | 321 | if sim_fn == "fme" or sim_fn == "fmx": 322 | _x = [] 323 | for i in range(x.size(0)): 324 | if sim_fn == "fme": 325 | _x.append(x[i][:n_frames[i]].mean(dim=0)) 326 | elif sim_fn == "fmx": 327 | _x.append(x[i][:n_frames[i]].max(dim=0)[0]) 328 | x = torch.stack(_x, dim=0) 329 | 330 | x = nn.functional.normalize(x, p=2, dim=-1) 331 | 332 | return x, n_frames, ids -------------------------------------------------------------------------------- /oxford_paris_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import argparse 5 | import os 6 | import sys 7 | import logging 8 | import numpy as np 9 | 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | import torch.distributed as dist 13 | import torch.multiprocessing as mp 14 | import torch.nn as nn 15 | import torchvision.transforms as T 16 | from PIL import Image 17 | 18 | import datasets.oxford_paris as oxford_paris 19 | from models.distill_model import archs 20 | from models.teachers import teacher_models 21 | from models.gem_pooling import GeneralizedMeanPooling 22 | from models.pca_layer import pca_layers 23 | import utils 24 | import metric 25 | 26 | 27 | def parse_args(): 28 | """ 29 | Parse input arguments 30 | """ 31 | parser = argparse.ArgumentParser(description="Evaluation on roxford5k or rparis6k.") 32 | parser.add_argument( 33 | "-dp", "--data_path", default="/path/to/datasets", type=str, 34 | help="root path to dataset" 35 | ) 36 | parser.add_argument( 37 | "-d", "--dataset", default="roxford5k", type=str, 38 | choices=["roxford5k", "rparis6k"], help="dataset name" 39 | ) 40 | parser.add_argument( 41 | "-a", "--arch", default=None, type=str, metavar="ARCH", 42 | choices=list(archs.keys()), help="architecture of backbone model" 43 | ) 44 | parser.add_argument( 45 | "-t", "--teacher", default=None, type=str, metavar="ARCH", 46 | choices=list(teacher_models.keys()), help="name of teacher model" 47 | ) 48 | parser.add_argument( 49 | "--imsize", default=None, type=int, 50 | help="input image shape" 51 | ) 52 | parser.add_argument( 53 | "--num_workers", default=8, type=int, metavar="N", 54 | help="number of data loader workers" 55 | ) 56 | parser.add_argument( 57 | "--world_size", default=8, type=int, 58 | help="number of workers per node" 59 | ) 60 | parser.add_argument( 61 | "--dist_url", default="tcp://localhost:2023", type=str, 62 | help='url used to set up distributed evaluation' 63 | ) 64 | parser.add_argument( 65 | '--rank', default=0, type=int, 66 | help='node rank for distributed training' 67 | ) 68 | parser.add_argument( 69 | "-r", "--resume", default=None, type=str, metavar="DIR", 70 | help="checkpoint model to resume" 71 | ) 72 | parser.add_argument( 73 | "--path_to_pretrained_weights", default="/path/to/pretrained_weights", type=str, metavar="DIR", 74 | help="path to pretrained teacher models and whitening weights" 75 | ) 76 | parser.add_argument( 77 | "-ms", "--multiscale", action="store_true", 78 | help="multiscale testing" 79 | ) 80 | parser.add_argument( 81 | '--embed_dim', default=512, type=int, 82 | help='embedding dimension' 83 | ) 84 | parser.add_argument( 85 | '-p', default=3, type=float, 86 | 
help='power rate' 87 | ) 88 | parser.add_argument( 89 | "--plus1m", action="store_true", 90 | help="plus 1M distractors" 91 | ) 92 | parser.add_argument( 93 | "--pca", action="store_true", 94 | help="whether to perform PCA-Whitening on teacher model" 95 | ) 96 | return parser.parse_args() 97 | 98 | 99 | def main(): 100 | args = parse_args() 101 | 102 | utils.setup_logger() 103 | logger = logging.getLogger("image retrieval") 104 | 105 | logger.info(vars(args)) 106 | 107 | # by default use all available gpus 108 | ngpus_per_node = torch.cuda.device_count() 109 | 110 | # distributed evaluation 111 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 112 | 113 | 114 | def main_worker(gpu, ngpus_per_node, args): 115 | cudnn.benchmark = True 116 | 117 | args.gpu = gpu 118 | 119 | if args.dist_url == "env://" and args.rank == -1: 120 | args.rank = int(os.environ["RANK"]) 121 | args.rank = args.rank * ngpus_per_node + gpu 122 | 123 | utils.setup_logger(log_path=None) 124 | logger = logging.getLogger("worker " + str(args.rank)) 125 | 126 | dist.init_process_group( 127 | backend="nccl", 128 | init_method=args.dist_url, 129 | world_size=args.world_size, 130 | rank=args.rank 131 | ) 132 | 133 | torch.cuda.set_device(args.gpu) 134 | device = torch.device("cuda:"+str(torch.cuda.current_device())) 135 | logger.info(f"Using device cuda:{torch.cuda.current_device()}") 136 | 137 | # arguments arch and teacher are mutually exclusive 138 | assert args.arch is None or args.teacher is None 139 | 140 | # build student model 141 | if args.arch is not None: 142 | model = archs[args.arch](pretrained=False, num_classes=args.embed_dim) 143 | model.avgpool = GeneralizedMeanPooling(args.p) 144 | 145 | if args.resume is not None: 146 | checkpoint_file = args.resume 147 | if os.path.isfile(checkpoint_file): 148 | logger.info(f"Loading checkpoint \"{checkpoint_file}\"...") 149 | checkpoint = torch.load(checkpoint_file, map_location='cpu') 150 | else: 151 | logger.error(f"=> No checkpoint found at '{checkpoint_file}'.") 152 | sys.exit() 153 | 154 | state_dict = checkpoint["state_dict"] 155 | for k in list(state_dict.keys()): 156 | if k.startswith('module.base_encoder'): 157 | state_dict[k[len("module.base_encoder."):]] = state_dict.pop(k) 158 | else: 159 | state_dict.pop(k) 160 | model.load_state_dict(state_dict, strict=True) 161 | logger.info(f"Loaded checkpoint.") 162 | 163 | # build teacher model 164 | if args.teacher is not None: 165 | model = teacher_models[args.teacher](args.path_to_pretrained_weights, gem_p=args.p) 166 | 167 | model.cuda(device) 168 | 169 | if utils.has_batchnorms(model): 170 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 171 | 172 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device]) 173 | 174 | model.eval() 175 | 176 | # for teacher model, a pca-whitening layer is optional to be used 177 | pca_wtn = None 178 | if args.teacher is not None and args.pca: 179 | pca_wtn = pca_layers[args.teacher]().cuda(device) 180 | model.module.embed_dim = pca_wtn.dim2 181 | 182 | # no image scale 183 | transform_list = [ 184 | T.ToTensor(), 185 | T.Normalize( 186 | mean=[0.485, 0.456, 0.406], 187 | std=[0.229, 0.224, 0.225] 188 | ) 189 | ] 190 | # DELG and DOLG's input channel is BGR rather than default RGB 191 | if args.teacher is not None and (args.teacher.endswith("delg") or args.teacher.endswith("dolg")): 192 | transform_list.append(utils.RGB2BGR()) 193 | transform = T.Compose(transform_list) 194 | dataset_query = oxford_paris.OxfordParisDataset( 195 | 
args.data_path, args.dataset, "query", 196 | transform=transform, imsize=args.imsize 197 | ) 198 | dataset_database = oxford_paris.OxfordParisDataset( 199 | args.data_path, args.dataset, "database", 200 | transform=transform, imsize=args.imsize 201 | ) 202 | sampler_query = torch.utils.data.distributed.DistributedSampler( 203 | dataset_query, num_replicas=args.world_size, rank=args.rank, shuffle=False, drop_last=False 204 | ) 205 | sampler_database = torch.utils.data.distributed.DistributedSampler( 206 | dataset_database, num_replicas=args.world_size, rank=args.rank, shuffle=False, drop_last=False 207 | ) 208 | data_loader_query = torch.utils.data.DataLoader( 209 | dataset_query, 210 | batch_size=1, shuffle=False, sampler=sampler_query, 211 | pin_memory=True, num_workers=args.num_workers 212 | ) 213 | data_loader_database = torch.utils.data.DataLoader( 214 | dataset_database, 215 | batch_size=1, shuffle=False, sampler=sampler_database, 216 | pin_memory=True, num_workers=args.num_workers 217 | ) 218 | 219 | logger.info(f"database: {len(dataset_database)} imgs") 220 | logger.info(f"query: {len(dataset_query)} imgs") 221 | 222 | query_features = extract_features( 223 | model, data_loader_query, 224 | pca_wtn=pca_wtn, 225 | multiscale=args.multiscale 226 | ) 227 | database_features = extract_features( 228 | model, data_loader_database, 229 | pca_wtn=pca_wtn, 230 | multiscale=args.multiscale 231 | ) 232 | 233 | if args.plus1m: 234 | dataset_distractor = oxford_paris.OxfordParisDataset( 235 | args.data_path, "revisitop1m", "distractor", 236 | transform=transform, imsize=args.imsize 237 | ) 238 | sampler_distractor = torch.utils.data.distributed.DistributedSampler( 239 | dataset_distractor, num_replicas=args.world_size, rank=args.rank, shuffle=False, drop_last=False 240 | ) 241 | data_loader_distractor = torch.utils.data.DataLoader( 242 | dataset_distractor, 243 | batch_size=1, shuffle=False, sampler=sampler_distractor, 244 | pin_memory=True, num_workers=args.num_workers 245 | ) 246 | logger.info(f"distractor: {len(dataset_distractor)} imgs") 247 | distractor_features = extract_features( 248 | model, data_loader_distractor, 249 | pca_wtn=pca_wtn, multiscale=args.multiscale 250 | ) 251 | 252 | # calculate metrics on the main process 253 | if args.rank == 0: 254 | # Step 1: normalize features 255 | database_features = nn.functional.normalize(database_features, dim=1, p=2) 256 | query_features = nn.functional.normalize(query_features, dim=1, p=2) 257 | if args.plus1m: 258 | distractor_features = nn.functional.normalize(distractor_features, dim=1, p=2) 259 | database_features = torch.cat([database_features, distractor_features], dim=0) 260 | 261 | ############################################################################ 262 | # Step 2: similarity 263 | sim = torch.mm(database_features, query_features.T) 264 | ranks = torch.argsort(-sim, dim=0).cpu().numpy() 265 | 266 | ############################################################################ 267 | # Step 3: evaluate 268 | gnd = dataset_database.cfg['gnd'] 269 | # evaluate ranks 270 | ks = [1, 5, 10] 271 | # search for easy & hard 272 | gnd_t = [] 273 | for i in range(len(gnd)): 274 | g = {} 275 | g['ok'] = np.concatenate([gnd[i]['easy'], gnd[i]['hard']]) 276 | g['junk'] = np.concatenate([gnd[i]['junk']]) 277 | gnd_t.append(g) 278 | mapM, apsM, mprM, prsM = oxford_paris.compute_map(ranks, gnd_t, ks) 279 | # search for hard 280 | gnd_t = [] 281 | for i in range(len(gnd)): 282 | g = {} 283 | g['ok'] = np.concatenate([gnd[i]['hard']]) 284 | 
g['junk'] = np.concatenate([gnd[i]['junk'], gnd[i]['easy']]) 285 | gnd_t.append(g) 286 | mapH, apsH, mprH, prsH = oxford_paris.compute_map(ranks, gnd_t, ks) 287 | logger.info('>> {}: mAP M: {}, H: {}'.format(args.dataset, np.around(mapM*100, decimals=2), np.around(mapH*100, decimals=2))) 288 | logger.info('>> {}: mP@k{} M: {}, H: {}'.format(args.dataset, np.array(ks), np.around(mprM*100, decimals=2), np.around(mprH*100, decimals=2))) 289 | 290 | dist.barrier() 291 | 292 | 293 | @torch.no_grad() 294 | def extract_features(model, data_loader, pca_wtn=None, multiscale=False): 295 | rank = utils.get_rank() 296 | if rank == 0: 297 | logger = logging.getLogger("extract_feature") 298 | else: 299 | logger = None 300 | metric_logger = metric.MetricLogger(logger, delimiter=" ") 301 | log_freq = len(data_loader) // 16 if len(data_loader) >= 16 else len(data_loader) 302 | if rank == 0: 303 | features = torch.zeros(len(data_loader.dataset), model.module.embed_dim) 304 | for batch_idx, (samples, index) in metric_logger.log_every(data_loader, log_freq): 305 | samples = samples.cuda(non_blocking=True) 306 | index = index.cuda(non_blocking=True) 307 | 308 | if multiscale: 309 | feats = multiscale_feature(samples, model, pca_wtn=pca_wtn) 310 | else: 311 | feats = model(samples) 312 | if pca_wtn is not None: 313 | feats = nn.functional.normalize(feats, p=2, dim=-1) 314 | feats = pca_wtn(feats) 315 | 316 | index_all = gather_to_main(index) 317 | feats_all = gather_to_main(feats) 318 | 319 | if rank == 0: 320 | features.index_copy_(0, index_all.cpu(), feats_all.cpu()) 321 | 322 | if rank == 0: 323 | return features 324 | else: 325 | return None 326 | 327 | 328 | def gather_to_main(x): 329 | world_size = utils.get_world_size() 330 | rank = utils.get_rank() 331 | 332 | if rank == 0: 333 | gather_x = [torch.zeros_like(x) for _ in range(world_size)] 334 | else: 335 | gather_x = None 336 | dist.gather(x, gather_x if rank == 0 else None, dst=0) 337 | 338 | if rank == 0: 339 | gather_x = torch.cat(gather_x, dim=0) 340 | 341 | return gather_x 342 | 343 | 344 | def multiscale_feature(samples, model, pca_wtn=None): 345 | scale = [1, 1/2**(1/2), 1/2] 346 | v = None 347 | for s in scale: 348 | if s == 1: 349 | inp = samples 350 | else: 351 | inp = nn.functional.interpolate(samples, scale_factor=s, mode="bilinear", align_corners=False) 352 | feats = model(inp) 353 | feats = nn.functional.normalize(feats, p=2, dim=-1) 354 | if pca_wtn is not None: 355 | feats = pca_wtn(feats) 356 | feats = nn.functional.normalize(feats, p=2, dim=-1) 357 | 358 | if v is None: 359 | v = feats 360 | else: 361 | v += feats 362 | v /= len(scale) 363 | return v 364 | 365 | 366 | if __name__ == "__main__": 367 | main() 368 | -------------------------------------------------------------------------------- /svd_distill.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import logging 5 | import yaml 6 | import random 7 | 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.distributed as dist 11 | import torch.multiprocessing as mp 12 | import torch.nn as nn 13 | import torchvision.transforms as T 14 | from PIL import ImageFilter, ImageOps 15 | 16 | from models import distill_model 17 | from datasets.svd.core import Frame, MetaData 18 | from datasets.svd.loader import DistributedDistillLoader 19 | from loss import MultiTeacherDistillLoss 20 | import utils 21 | import metric 22 | 23 | 24 | def parse_args(): 25 | """ 26 | Parse input arguments 
27 |     """
28 |     parser = argparse.ArgumentParser(description="Knowledge Distill on SVD dataset.")
29 |     parser.add_argument(
30 |         "-a", "--arch", default="mobilenet_v2", type=str, metavar="ARCH",
31 |         choices=list(distill_model.archs.keys()), help="name of backbone model"
32 |     )
33 |     parser.add_argument(
34 |         "-ts", "--teachers", required=True, nargs="+", metavar="LIST",
35 |         help="teacher models to distill"
36 |     )
37 |     parser.add_argument(
38 |         "-c", "--ckpt_dir", required=True, type=str, metavar="DIR",
39 |         help="directory to save checkpoint files"
40 |     )
41 |     parser.add_argument(
42 |         "-dm", "--dataset_meta", default="config/svd.yaml", type=str, metavar="FILE",
43 |         help="dataset meta file"
44 |     )
45 |     parser.add_argument(
46 |         "-b", "--batch_size", default=64, type=int, metavar="N",
47 |         help="training batch size"
48 |     )
49 |     parser.add_argument(
50 |         "--num_workers", default=8, type=int, metavar="N",
51 |         help="number of data loader workers"
52 |     )
53 |     parser.add_argument(
54 |         "-r", "--resume", default=None, type=str, metavar="DIR",
55 |         help="checkpoint model to resume"
56 |     )
57 |     parser.add_argument(
58 |         "--path_to_pretrained_weights", default="/path/to/pretrained_weights", type=str, metavar="DIR",
59 |         help="path to pretrained teacher models and whitening weights"
60 |     )
61 |     parser.add_argument(
62 |         "--world_size", default=8, type=int,
63 |         help="number of workers"
64 |     )
65 |     parser.add_argument(
66 |         "--dist_url", default="tcp://localhost:2023", type=str,
67 |         help='url used to set up distributed training'
68 |     )
69 |     parser.add_argument(
70 |         '--rank', default=0, type=int,
71 |         help='node rank for distributed training'
72 |     )
73 |     parser.add_argument(
74 |         "--start_epoch", type=int, metavar="N", default=0,
75 |         help="start from epoch i"
76 |     )
77 |     parser.add_argument(
78 |         "--epochs", default=200, type=int, metavar="N",
79 |         help="training epochs"
80 |     )
81 |     parser.add_argument(
82 |         "--warmup_epochs", default=5, type=int, metavar="N",
83 |         help="number of warmup epochs"
84 |     )
85 |     parser.add_argument(
86 |         "--lr", default=0.3, type=float, metavar="N",
87 |         help="initial learning rate"
88 |     )
89 |     parser.add_argument(
90 |         "-t", default=0.05, type=float, metavar="N",
91 |         help="temperature of the distillation loss"
92 |     )
93 |     parser.add_argument(
94 |         '--wd', '--weight-decay', default=1e-6, type=float, metavar='W', dest='weight_decay',
95 |         help='weight decay (default: 1e-6)'
96 |     )
97 |     parser.add_argument(
98 |         "--snapshot_step", default=10, type=int, metavar="N",
99 |         help="interval to dump checkpoint"
100 |     )
101 |     parser.add_argument(
102 |         '-p', default=1, type=float,
103 |         help='power rate'
104 |     )
105 |     parser.add_argument(
106 |         '--embed_dim', default=512, type=int,
107 |         help='embedding dimension'
108 |     )
109 |     parser.add_argument(
110 |         "-s", "--strategy", default="maxmin", type=str,
111 |         help="similarity fusion strategy"
112 |     )
113 |     return parser.parse_args()
114 | 
115 | 
116 | def main():
117 |     args = parse_args()
118 |     os.makedirs(args.ckpt_dir, exist_ok=True)
119 |     with open(os.path.join(args.ckpt_dir, "config.yaml"), "w") as f:
120 |         yaml.dump(vars(args), f)
121 | 
122 |     utils.setup_logger()
123 |     logger = logging.getLogger("svd_distill")
124 | 
125 |     logger.info(vars(args))
126 | 
127 |     ngpus_per_node = torch.cuda.device_count()
128 | 
129 |     mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))  # one process per visible GPU
130 | 
131 | 
132 | def main_worker(gpu, ngpus_per_node, args):
133 |     cudnn.benchmark = True
134 | 
135 |     args.gpu = gpu
136 | 
137 |     if args.dist_url == "env://" 
and args.rank == -1: 138 | args.rank = int(os.environ["RANK"]) 139 | args.rank = args.rank * ngpus_per_node + gpu 140 | 141 | if args.rank == 0: 142 | utils.setup_logger(args.ckpt_dir) 143 | else: 144 | utils.setup_logger(log_path=None) 145 | logger = logging.getLogger("dist_worker " + str(args.rank)) 146 | 147 | dist.init_process_group( 148 | backend="nccl", 149 | init_method=args.dist_url, 150 | world_size=args.world_size, 151 | rank=args.rank 152 | ) 153 | 154 | torch.cuda.set_device(args.gpu) 155 | device = torch.device("cuda:"+str(torch.cuda.current_device())) 156 | logger.info(f"Using device cuda:{torch.cuda.current_device()}") 157 | 158 | dataset_meta_cfg = utils.load_config(args.dataset_meta) 159 | dataset_meta = MetaData(dataset_meta_cfg) 160 | 161 | train_dataset = Frame( 162 | dataset_meta.frm_root_path, 163 | dataset_meta.frm_cnt, 164 | TwoCropsTransform() 165 | ) 166 | train_loader = DistributedDistillLoader.build( 167 | train_dataset, dataset_meta, 168 | batch_size=args.batch_size, num_workers=args.num_workers 169 | ) 170 | 171 | model = distill_model.CrossViewDistillModel(args) 172 | logger.info(f"Build model {model.__class__.__name__}") 173 | n_parameters = sum([p.data.nelement() for p in model.parameters()]) 174 | n_trainable_parameters = sum([p.data.nelement() for p in model.parameters() if p.requires_grad]) 175 | logger.info(f"Number of parameters: {n_parameters}") 176 | logger.info(f"Number of trainable parameters: {n_trainable_parameters}") 177 | 178 | model.cuda(device) 179 | 180 | if utils.has_batchnorms(model): 181 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 182 | 183 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device]) 184 | 185 | # args.lr = args.lr * args.batch_size * args.world_size / 256 186 | 187 | if args.arch.startswith("vit"): 188 | optimizer = torch.optim.AdamW( 189 | filter(lambda p: p.requires_grad, model.parameters()), 190 | lr=args.lr 191 | ) 192 | else: 193 | # optimizer = LARS( 194 | # filter(lambda p: p.requires_grad, model.parameters()), 195 | # lr=args.lr, 196 | # weight_decay=args.weight_decay 197 | # ) 198 | optimizer = torch.optim.Adam( 199 | filter(lambda p: p.requires_grad, model.parameters()), 200 | lr=args.lr, 201 | weight_decay=args.weight_decay 202 | ) 203 | 204 | # warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR( 205 | # optimizer, 206 | # start_factor=1e-6, end_factor=1, total_iters=args.warmup_epochs * len(train_loader) 207 | # ) 208 | # cosine_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 209 | # optimizer, 210 | # T_max=(args.epochs - args.warmup_epochs) * len(train_loader), eta_min=1e-6 211 | # ) 212 | # lr_scheduler = torch.optim.lr_scheduler.SequentialLR( 213 | # optimizer, 214 | # [warmup_lr_scheduler, cosine_lr_scheduler], 215 | # milestones=[args.warmup_epochs * len(train_loader)] 216 | # ) 217 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 218 | optimizer, 219 | T_max=args.epochs * len(train_loader), eta_min=1e-6 220 | ) 221 | 222 | loss_fn = MultiTeacherDistillLoss( 223 | dt=args.t, teachers=args.teachers, s=args.strategy 224 | ).to(device) 225 | 226 | if args.resume is not None: 227 | checkpoint_file = args.resume 228 | if os.path.isfile(checkpoint_file): 229 | logger.info(f"Loading checkpoint \"{checkpoint_file}\"...") 230 | checkpoint = torch.load(checkpoint_file, map_location='cpu') 231 | else: 232 | logger.error(f"=> No checkpoint found at '{checkpoint_file}'.") 233 | sys.exit() 234 | 235 | model.load_state_dict(checkpoint["state_dict"], 
strict=True) 236 | args.start_epoch = checkpoint["epoch"] 237 | optimizer.load_state_dict(checkpoint["optimizer"]) 238 | lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) 239 | loss_fn.load_state_dict(checkpoint["loss_fn"]) 240 | logger.info(f"Loaded checkpoint.") 241 | 242 | for epoch in range(args.start_epoch, args.epochs): 243 | train_one_epoch( 244 | model, 245 | train_loader, 246 | optimizer, 247 | lr_scheduler, 248 | loss_fn, 249 | device, 250 | epoch+1, 251 | args.epochs 252 | ) 253 | 254 | if (epoch + 1) % args.snapshot_step == 0 and args.rank == 0: 255 | utils.save_checkpoint( 256 | { 257 | "epoch": epoch+1, 258 | "state_dict": model.state_dict(), 259 | "optimizer": optimizer.state_dict(), 260 | "lr_scheduler": lr_scheduler.state_dict(), 261 | "loss_fn": loss_fn.state_dict() 262 | }, 263 | False, 264 | args.ckpt_dir, 265 | filename=f"checkpoint_{epoch+1}.pt" 266 | ) 267 | 268 | 269 | def train_one_epoch(model, train_loader, optimizer, lr_scheduler, loss_fn, device, epoch, total_epochs): 270 | model.train() 271 | 272 | rank = utils.get_rank() 273 | if rank == 0: 274 | logger = logging.getLogger("svd_distill_train") 275 | else: 276 | logger = None 277 | metric_logger = metric.MetricLogger(logger, delimiter=" ") 278 | header = f'Epoch: [{epoch}/{total_epochs}]' 279 | log_freq = len(train_loader) // 16 if len(train_loader) >= 16 else len(train_loader) 280 | 281 | train_loader.batch_sampler.set_epoch(epoch) 282 | 283 | for batch_idx, batch in metric_logger.log_every(train_loader, log_freq, header=header, iterations=len(train_loader)): 284 | x1, x2, _ = batch 285 | 286 | x1 = x1.to(device) 287 | x2 = x2.to(device) 288 | 289 | predicts = model(x1, x2) 290 | 291 | loss_dict = loss_fn(*predicts) 292 | loss = sum(loss_dict.values()) 293 | 294 | optimizer.zero_grad() 295 | loss.backward() 296 | optimizer.step() 297 | 298 | if rank == 0 and batch_idx % log_freq == 0: 299 | logger.info(f"learning rate: {lr_scheduler.get_last_lr()[0]:.8f}") 300 | logger.info(loss_fn) 301 | lr_scheduler.step() 302 | 303 | loss_dict = {k: v.item() for k, v in loss_dict.items()} 304 | metric_logger.update(**loss_dict) 305 | 306 | metric_logger.sync() 307 | if rank == 0: 308 | logger.info("Averaged stats: " + str(metric_logger)) 309 | 310 | 311 | class TwoCropsTransform: 312 | def __init__(self): 313 | self.transform1 = T.Compose([ 314 | T.RandomResizedCrop(224, scale=(0.4, 1.)), 315 | T.RandomApply([ 316 | T.ColorJitter(0.4, 0.4, 0.2, 0.1) 317 | ], p=0.8), 318 | T.RandomGrayscale(p=0.2), 319 | T.RandomApply([GaussianBlur([.1, 2.])], p=1.), 320 | T.RandomHorizontalFlip(p=0.5), 321 | T.ToTensor(), 322 | T.Normalize( 323 | mean=[0.485, 0.456, 0.406], 324 | std=[0.229, 0.224, 0.225] 325 | ) 326 | ]) 327 | self.transform2 = T.Compose([ 328 | T.RandomResizedCrop(224, scale=(0.4, 1.)), 329 | T.RandomApply([ 330 | T.ColorJitter(0.4, 0.4, 0.2, 0.1) 331 | ], p=0.8), 332 | T.RandomGrayscale(p=0.2), 333 | T.RandomApply([GaussianBlur([.1, 2.])], p=0.1), 334 | T.RandomApply([Solarize()], p=0.2), 335 | T.RandomHorizontalFlip(p=0.5), 336 | T.ToTensor(), 337 | T.Normalize( 338 | mean=[0.485, 0.456, 0.406], 339 | std=[0.229, 0.224, 0.225] 340 | ) 341 | ]) 342 | 343 | def __call__(self, x): 344 | im1 = self.transform1(x) 345 | im2 = self.transform2(x) 346 | return im1, im2 347 | 348 | 349 | class GaussianBlur(object): 350 | def __init__(self, sigma=[.1, 2.]): 351 | self.sigma = sigma 352 | 353 | def __call__(self, x): 354 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 355 | x = 
x.filter(ImageFilter.GaussianBlur(radius=sigma)) 356 | return x 357 | 358 | 359 | class Solarize(object): 360 | """Solarize augmentation from BYOL: https://arxiv.org/abs/2006.07733""" 361 | 362 | def __call__(self, x): 363 | return ImageOps.solarize(x) 364 | 365 | 366 | class LARS(torch.optim.Optimizer): 367 | """ 368 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 369 | """ 370 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 371 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 372 | super().__init__(params, defaults) 373 | 374 | @torch.no_grad() 375 | def step(self): 376 | for g in self.param_groups: 377 | for p in g['params']: 378 | dp = p.grad 379 | 380 | if dp is None: 381 | continue 382 | 383 | if p.ndim > 1: # if not normalization gamma/beta or bias 384 | dp = dp.add(p, alpha=g['weight_decay']) 385 | param_norm = torch.norm(p) 386 | update_norm = torch.norm(dp) 387 | one = torch.ones_like(param_norm) 388 | q = torch.where(param_norm > 0., 389 | torch.where(update_norm > 0, 390 | (g['trust_coefficient'] * param_norm / update_norm), one), 391 | one) 392 | dp = dp.mul(q) 393 | 394 | param_state = self.state[p] 395 | if 'mu' not in param_state: 396 | param_state['mu'] = torch.zeros_like(p) 397 | mu = param_state['mu'] 398 | mu.mul_(g['momentum']).add_(dp) 399 | p.add_(mu, alpha=-g['lr']) 400 | 401 | 402 | 403 | if __name__ == "__main__": 404 | main() 405 |
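# A minimal usage sketch of the TwoCropsTransform defined above, assuming a
# hypothetical local image "frame.jpg"; each call returns two independently
# augmented 224x224 views of the same frame, which train_one_epoch passes to
# the model as a positive pair.
from PIL import Image

two_crops = TwoCropsTransform()
img = Image.open("frame.jpg").convert("RGB")
view1, view2 = two_crops(img)  # each a float tensor of shape (3, 224, 224)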
type=str, metavar="DIR", 62 | help="path to pretrained teacher models and whitening weights" 63 | ) 64 | parser.add_argument( 65 | "--world_size", default=8, type=int, 66 | help="number of workers" 67 | ) 68 | parser.add_argument( 69 | "--dist_url", default="tcp://localhost:2023", type=str, 70 | help='url used to set up distributed training' 71 | ) 72 | parser.add_argument( 73 | '--rank', default=0, type=int, 74 | help='node rank for distributed training' 75 | ) 76 | parser.add_argument( 77 | "--start_epoch", type=int, metavar="N", default=0, 78 | help="start from epoch i" 79 | ) 80 | parser.add_argument( 81 | "--epochs", default=200, type=int, metavar="N", 82 | help="number of training epochs" 83 | ) 84 | parser.add_argument( 85 | "--lr", default=1e-3, type=float, metavar="N", 86 | help="initial learning rate" 87 | ) 88 | parser.add_argument( 89 | '--wd', '--weight-decay', default=1e-6, type=float, metavar='W', dest='weight_decay', 90 | help='weight decay (default: 1e-6)' 91 | ) 92 | parser.add_argument( 93 | "--snapshot_step", default=10, type=int, metavar="N", 94 | help="interval to dump checkpoint" 95 | ) 96 | parser.add_argument( 97 | '--embed_dim', default=512, type=int, 98 | help='embedding dimension' 99 | ) 100 | parser.add_argument( 101 | '-tt', default=0.05, type=float, 102 | help='teacher distill temperature' 103 | ) 104 | parser.add_argument( 105 | '-st', default=0.05, type=float, 106 | help='student distill temperature' 107 | ) 108 | parser.add_argument( 109 | '-p', default=3, type=float, 110 | help='power rate' 111 | ) 112 | parser.add_argument( 113 | "--imsize", default=512, type=int, 114 | help="input image shape" 115 | ) 116 | parser.add_argument( 117 | "-s", "--strategy", default="maxmin", type=str, 118 | help="similarity fusion strategy" 119 | ) 120 | return parser.parse_args() 121 | 122 | 123 | def main(): 124 | args = parse_args() 125 | 126 | # create checkpoint directory and save configurations 127 | os.makedirs(args.ckpt_dir, exist_ok=True) 128 | with open(os.path.join(args.ckpt_dir, "config.yaml"), "w") as f: 129 | yaml.dump(vars(args), f) 130 | 131 | utils.setup_logger() 132 | logger = logging.getLogger("gld_distill") 133 | 134 | logger.info(vars(args)) 135 | 136 | ngpus_per_node = torch.cuda.device_count() 137 | 138 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 139 | 140 | 141 | def main_worker(gpu, ngpus_per_node, args): 142 | cudnn.benchmark = True 143 | 144 | args.gpu = gpu 145 | 146 | if args.dist_url == "env://" and args.rank == -1: 147 | args.rank = int(os.environ["RANK"]) 148 | args.rank = args.rank * ngpus_per_node + gpu 149 | 150 | if args.rank == 0: 151 | utils.setup_logger(args.ckpt_dir) 152 | else: 153 | utils.setup_logger(log_path=None) 154 | logger = logging.getLogger("worker " + str(args.rank)) 155 | 156 | dist.init_process_group( 157 | backend="nccl", 158 | init_method=args.dist_url, 159 | world_size=args.world_size, 160 | rank=args.rank 161 | ) 162 | 163 | torch.cuda.set_device(args.gpu) 164 | device = torch.device("cuda:"+str(torch.cuda.current_device())) 165 | logger.info(f"Using device cuda:{torch.cuda.current_device()}") 166 | 167 | # parse landmarks 168 | df = pd.read_csv(os.path.join(args.gld_root_path, "meta/train_clean.csv")) 169 | landmark_list = df[["images"]].values.tolist() 170 | landmark_list = list(map(lambda x: x[0].split(" "), landmark_list)) 171 | landmark_list = list(group for group in landmark_list if len(group) >= 2) 172 | 173 | transform = T.Compose([ 174 | T.RandomResizedCrop(args.imsize, 
scale=(0.4, 1.)), 175 | T.RandomHorizontalFlip(p=0.5), 176 | T.ToTensor(), 177 | T.Normalize( 178 | mean=[0.485, 0.456, 0.406], 179 | std=[0.229, 0.224, 0.225] 180 | ) 181 | ]) 182 | train_dataset = GLDImageDataset(os.path.join(args.gld_root_path, "train"), transform) 183 | train_sampler = DistributedGroupSampler( 184 | args.rank, args.world_size, args.batch_size, landmark_list, 185 | shuffle=True, drop_last=True 186 | ) 187 | train_loader = torch.utils.data.DataLoader( 188 | train_dataset, batch_sampler=train_sampler, 189 | num_workers=args.num_workers, pin_memory=True 190 | ) 191 | 192 | model = distill_model.MultiTeacherDistillModel(args) 193 | logger.info(f"Build model {model.__class__.__name__}") 194 | n_parameters = sum([p.data.nelement() for p in model.parameters()]) 195 | n_trainable_parameters = sum([p.data.nelement() for p in model.parameters() if p.requires_grad]) 196 | logger.info(f"Number of parameters: {n_parameters}") 197 | logger.info(f"Number of trainable parameters: {n_trainable_parameters}") 198 | 199 | model.cuda(device) 200 | 201 | if utils.has_batchnorms(model): 202 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 203 | 204 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device]) 205 | 206 | optimizer = torch.optim.Adam( 207 | filter(lambda p: p.requires_grad, model.parameters()), 208 | lr=args.lr, 209 | weight_decay=args.weight_decay 210 | ) 211 | 212 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 213 | optimizer, 214 | T_max=args.epochs * len(train_loader), eta_min=1e-6 215 | ) 216 | 217 | loss_fn = MultiTeacherDistillLoss( 218 | st=args.st, tt=args.tt, 219 | s=args.strategy, teachers=args.teachers 220 | ).to(device) 221 | 222 | if args.resume is not None: 223 | checkpoint_file = args.resume 224 | if os.path.isfile(checkpoint_file): 225 | logger.info(f"Loading checkpoint \"{checkpoint_file}\"...") 226 | checkpoint = torch.load(checkpoint_file, map_location='cpu') 227 | else: 228 | logger.error(f"=> No checkpoint found at '{checkpoint_file}'.") 229 | sys.exit() 230 | 231 | model.load_state_dict(checkpoint["state_dict"], strict=False) 232 | args.start_epoch = checkpoint["epoch"] 233 | optimizer.load_state_dict(checkpoint["optimizer"]) 234 | lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) 235 | loss_fn.load_state_dict(checkpoint["loss_fn"]) 236 | logger.info(f"Loaded checkpoint.") 237 | 238 | for epoch in range(args.start_epoch, args.epochs): 239 | train_one_epoch( 240 | model, 241 | train_loader, 242 | optimizer, 243 | lr_scheduler, 244 | loss_fn, 245 | device, 246 | epoch, 247 | args.epochs 248 | ) 249 | 250 | if (epoch + 1) % args.snapshot_step == 0 and args.rank == 0: 251 | utils.save_checkpoint( 252 | { 253 | "epoch": epoch+1, 254 | "state_dict": model.state_dict(), 255 | "optimizer": optimizer.state_dict(), 256 | "lr_scheduler": lr_scheduler.state_dict(), 257 | "loss_fn": loss_fn.state_dict() 258 | }, 259 | False, 260 | args.ckpt_dir, 261 | filename=f"checkpoint_{epoch+1}.pt" 262 | ) 263 | 264 | 265 | def train_one_epoch(model, train_loader, optimizer, lr_scheduler, loss_fn, device, epoch, total_epochs): 266 | model.train() 267 | 268 | rank = utils.get_rank() 269 | if rank == 0: 270 | logger = logging.getLogger("gld_distill_train") 271 | else: 272 | logger = None 273 | metric_logger = metric.MetricLogger(logger, delimiter=" ") 274 | header = f'Epoch: [{epoch+1}/{total_epochs}]' 275 | iters_per_epoch = len(train_loader) 276 | log_freq = iters_per_epoch // 16 if iters_per_epoch >= 16 else iters_per_epoch 277 | 
278 | train_loader.batch_sampler.set_epoch(epoch) 279 | 280 | for batch_idx, batch in metric_logger.log_every(train_loader, log_freq, header=header, iterations=iters_per_epoch): 281 | x, img_ids = batch 282 | 283 | batch_size = x.size(0) // 2 284 | 285 | x1 = x[:batch_size] 286 | x2 = x[batch_size:] 287 | 288 | x1 = x1.to(device) 289 | x2 = x2.to(device) 290 | 291 | predicts = model(x1, x2) 292 | 293 | loss_dict = loss_fn(*predicts) 294 | loss = sum(loss_dict.values()) 295 | 296 | optimizer.zero_grad() 297 | loss.backward() 298 | optimizer.step() 299 | 300 | if rank == 0 and batch_idx % log_freq == 0: 301 | logger.info(f"learning rate: {lr_scheduler.get_last_lr()[0]:.8f}") 302 | logger.info(loss_fn) 303 | lr_scheduler.step() 304 | 305 | loss_dict = {k: v.item() for k, v in loss_dict.items()} 306 | metric_logger.update(**loss_dict) 307 | 308 | metric_logger.sync() 309 | if rank == 0: 310 | logger.info("Averaged stats: " + str(metric_logger)) 311 | 312 | 313 | class GLDImageDataset(torch.utils.data.dataset.Dataset): 314 | def __init__(self, root_path, transform=None): 315 | super().__init__() 316 | self.root_path = root_path 317 | self.t = transform 318 | 319 | def __getitem__(self, img_id): 320 | img_path = os.path.join( 321 | self.root_path, 322 | img_id[0], img_id[1], img_id[2], 323 | img_id+".jpg" 324 | ) 325 | 326 | img = Image.open(img_path) 327 | img = img.convert("RGB") 328 | if self.t is not None: 329 | img = self.t(img) 330 | 331 | return img, img_id 332 | 333 | def __len__(self): 334 | raise NotImplementedError("dataset is keyed by image id; its size is determined by the batch sampler") 335 | 336 | 337 | class DistributedSampler(torch.utils.data.sampler.Sampler): 338 | def __init__(self, rank, num_replicas, batch_size, samples, seed=0, shuffle=True, drop_last=True): 339 | self._rank = rank 340 | self._num_replicas = num_replicas 341 | self._batch_size = batch_size 342 | self._samples = samples 343 | self._seed = seed 344 | self._shuffle = shuffle 345 | self._drop_last = drop_last 346 | self._epoch = 0 347 | 348 | if self._drop_last and len(self._samples) % self._num_replicas != 0: 349 | self._num_samples = math.ceil( 350 | (len(self._samples) - self._num_replicas) / self._num_replicas 351 | ) 352 | else: 353 | self._num_samples = math.ceil(len(self._samples) / self._num_replicas) 354 | self._total_size = self._num_samples * self._num_replicas 355 | 356 | def __len__(self): 357 | return math.ceil(self._num_samples/self._batch_size) 358 | 359 | def set_epoch(self, epoch): 360 | self._epoch = epoch 361 | 362 | def _get_subset(self): 363 | if self._shuffle: 364 | # deterministically shuffle based on epoch and seed 365 | g = torch.Generator() 366 | g.manual_seed(self._seed + self._epoch) 367 | indices = torch.randperm(len(self._samples), generator=g).tolist() 368 | else: 369 | indices = list(range(len(self._samples))) 370 | 371 | if not self._drop_last: 372 | # add extra samples to make it evenly divisible 373 | padding_size = self._total_size - len(indices) 374 | if padding_size <= len(indices): 375 | indices += indices[:padding_size] 376 | else: 377 | indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] 378 | else: 379 | # remove tail of data to make it evenly divisible.
380 | indices = indices[:self._total_size] 381 | assert len(indices) == self._total_size 382 | 383 | # subsample 384 | indices = indices[self._rank:self._total_size:self._num_replicas] 385 | assert len(indices) == self._num_samples 386 | 387 | return [self._samples[i] for i in indices] 388 | 389 | 390 | class DistributedGroupSampler(DistributedSampler): 391 | def __init__(self, rank, num_replicas, batch_size, groups, seed=0, shuffle=True, drop_last=True): 392 | super().__init__( 393 | rank, num_replicas, batch_size, groups, 394 | seed=seed, shuffle=shuffle, drop_last=drop_last 395 | ) 396 | random.seed(rank) 397 | 398 | logger = logging.getLogger("dist_group_sampler."+str(rank)) 399 | logger.info(self) 400 | 401 | def __str__(self): 402 | return f"| Distributed Group Sampler | {self._num_samples} groups | iters {len(self)} | {self._batch_size} per batch" 403 | 404 | def __iter__(self): 405 | subset = self._get_subset() 406 | for i in range(0, len(subset), self._batch_size): 407 | groups = subset[i:i+self._batch_size] 408 | img_id_pairs = [random.sample(group, 2) for group in groups] 409 | yield [pair[0] for pair in img_id_pairs] + [pair[1] for pair in img_id_pairs] 410 | 411 | 412 | if __name__ == "__main__": 413 | main() 414 | --------------------------------------------------------------------------------
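# A minimal sketch of the batches DistributedGroupSampler yields, assuming the
# toy landmark groups below (hypothetical image ids) and the classes defined
# above. Each batch is a flat list of image ids whose first and second halves
# line up into same-landmark positive pairs, matching the
# x[:batch_size] / x[batch_size:] split in train_one_epoch.
groups = [["a1", "a2", "a3"], ["b1", "b2"], ["c1", "c2", "c3", "c4"]]
sampler = DistributedGroupSampler(rank=0, num_replicas=1, batch_size=2, groups=groups)
sampler.set_epoch(0)
for batch in sampler:
    # e.g. ["a2", "c1", "a3", "c4"]: (a2, a3) and (c1, c4) come from the same landmark
    print(batch)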