├── figures
│   ├── 233
│   └── overview.jpg
├── README.md
├── test_mm_tta.py
└── train_universityMM.py
--------------------------------------------------------------------------------
/figures/233:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/figures/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HRT00/CVGL-3D/HEAD/figures/overview.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SkyLink: Unifying Street-Satellite Geo-Localization via UAV-Mediated 3D Scene Alignment
Official implementation of the ACM MM 2025 UAV challenge paper SkyLink (https://codalab.lisn.upsaclay.fr/competitions/22073). Team name: XMUSmart.

## News
🚩 **2025.07.07: Coming soon! Code will be released upon the paper's publication.**

🚩 **2025.07.24: Our paper has been accepted by the ACM MM 2025 UAVM workshop. The training code has been released!**

🚩 **2025.11.03: Our paper has been published! Link: https://dl.acm.org/doi/10.1145/3728482.3757392**

## Description 📜
This research focuses on ground-satellite geo-localization. We aim at robust feature retrieval under viewpoint variation and propose the novel SkyLink method. We integrate 3D scene information reconstructed from multi-scale UAV images as a bridge between the street and satellite viewpoints, and perform feature alignment through self-supervised and cross-view contrastive learning.

## Framework 🖇️
![Framework](figures/overview.jpg)

## Requirements
### Installation
Create a conda environment and install dependencies:
```bash
conda create -n cvgl python=3.10
conda activate cvgl

# Install matching versions of PyTorch and torchvision
conda install pytorch torchvision cudatoolkit
```

## Quick Start
To train the SkyLink model, run:
```bash
python train_universityMM.py
```
To test the SkyLink model with test-time augmentation (TTA), as used in the competition, run:
```bash
python test_mm_tta.py
```
We reconstruct point clouds from multi-view UAV images with the VGGT model (https://github.com/facebookresearch/vggt); a minimal reconstruction sketch is given at the end of this README.

## Acknowledgements

## Cite
If you find our work useful, please consider citing:
```bibtex
@article{zhang2025skylink,
  title={SkyLink: Unifying Street-Satellite Geo-Localization via UAV-Mediated 3D Scene Alignment},
  author={Zhang, Hongyang and Liu, Yinhao and Kuang, Zhenyu},
  journal={arXiv preprint arXiv:2509.24783},
  year={2025}
}
```

## Contact
If you have any questions about this project, please feel free to contact hyzhang@stu.xmu.edu.cn.
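## Point-Cloud Reconstruction (Sketch)
For reference, a minimal sketch of the VGGT-based reconstruction mentioned in Quick Start. It assumes the public VGGT API (`VGGT.from_pretrained`, `load_and_preprocess_images`) as documented in the VGGT repository; the image paths are placeholders, so adapt them to your own UAV views.
```python
import torch
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images

device = "cuda" if torch.cuda.is_available() else "cpu"

# Pretrained VGGT checkpoint (name as given in the VGGT README)
model = VGGT.from_pretrained("facebook/VGGT-1B").to(device)
model.eval()

# Multi-view UAV images of one scene (placeholder paths)
image_names = ["uav_view_01.jpg", "uav_view_02.jpg", "uav_view_03.jpg"]
images = load_and_preprocess_images(image_names).to(device)

with torch.no_grad():
    # The prediction dict contains camera parameters, depth maps and
    # per-pixel 3D point maps that can be exported as a point cloud;
    # see the VGGT repository for the exact output keys.
    predictions = model(images)
```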
--------------------------------------------------------------------------------
/test_mm_tta.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

import argparse
import time

import numpy as np
import torch
import torch.nn.functional as F
import ttach as tta
from sklearn.metrics.pairwise import euclidean_distances

from cvcities_base.model import TimmModel
from cvcities_base.dataset.university import get_transforms
from utils.image_folder import CustomData160k_sat, CustomData160k_drone

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

######################################################################
# Options
# --------
parser = argparse.ArgumentParser(description='Testing')
parser.add_argument('--gpu_ids', default='4', type=str, help='gpu_ids: e.g. 0  0,1,2  0,2')
parser.add_argument('--which_epoch', default='last', type=str, help='0,1,2,3... or last')
parser.add_argument('--test_dir', default='/home/zhanghy/MM/data/University-1652/test', type=str, help='./test_data')
parser.add_argument('--batchsize', default=8, type=int, help='batch size')
parser.add_argument('--views', default=2, type=int, help='views')
parser.add_argument('--query_name', default='query_street_name.txt', type=str, help='file listing the query images')
opt = parser.parse_args()

str_ids = opt.gpu_ids.split(',')
test_dir = opt.test_dir
query_name = opt.query_name
ms = [1]

######################################################################
# Load Data
# ---------
# We use the torchvision and torch.utils.data packages for loading the data.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
img_size = (448, 448)

val_transforms, _, _ = get_transforms(img_size, mean=mean, std=std)

data_dir = test_dir

image_datasets = {}
image_datasets['gallery_satellite'] = CustomData160k_sat(os.path.join(data_dir, 'workshop_gallery_satellite'), val_transforms)
image_datasets['query_street'] = CustomData160k_drone(os.path.join(data_dir, 'workshop_query_street'), val_transforms, query_name=query_name)
print(image_datasets.keys())

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                              shuffle=False, num_workers=16)
               for x in ['gallery_satellite', 'query_street']}

use_gpu = torch.cuda.is_available()

######################################################################
# Extract features
# ----------------
# Extract features from a trained model.
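# The ttach wrapper used in extract_feature() below averages the model's
# outputs over all composed transforms. A rough manual equivalent, shown for
# illustration only (this assumes ttach's Compose yields transformers that
# expose augment_image(); merging details may differ across ttach versions):
#
#   augmented = [t.augment_image(input_img) for t in tta_transforms]
#   outputs = torch.stack([model(a) for a in augmented]).mean(dim=0)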
def fliplr(img):
    '''Flip an image batch horizontally.'''
    inv_idx = torch.arange(img.size(3) - 1, -1, -1).long()  # N x C x H x W
    img_flip = img.index_select(3, inv_idx)
    return img_flip


def which_view(name):
    if 'satellite' in name:
        return 1
    elif 'street' in name:
        return 2
    elif 'drone' in name:
        return 3
    else:
        print('unknown view')
    return -1


tta_transforms = tta.Compose([
    tta.HorizontalFlip(),
    tta.Rotate90(angles=[0, 90]),
])


def extract_feature(model, dataloader):
    features = []
    model_tta = tta.ClassificationTTAWrapper(model, tta_transforms)  # wrap the model with TTA
    model_tta.eval()

    if use_gpu:
        model_tta = model_tta.to(device)

    for data in dataloader:
        img, _ = data
        input_img = img.to(device)

        with torch.no_grad():
            outputs = model_tta(input_img)  # applies every TTA transform and merges the results

        feature = F.normalize(outputs, dim=-1)
        features.append(feature.cpu().data)

    features = torch.cat(features, dim=0)
    return features


def get_SatId_160k(img_path):
    labels = []
    paths = []
    for path, v in img_path:
        labels.append(v)
        paths.append(path)
    return labels, paths


def compute_distmat(query, gallery):
    # L2-normalize, then compute pairwise Euclidean distances
    query = query / query.norm(dim=1, keepdim=True)
    gallery = gallery / gallery.norm(dim=1, keepdim=True)
    distmat = euclidean_distances(query.cpu().numpy(), gallery.cpu().numpy())
    return distmat


def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3):
    # Simplified variant of k-reciprocal re-ranking (k2 is unused here).
    # Reference: https://github.com/layumi/Person_reID_baseline_pytorch/blob/master/re_ranking.py
    original_dist = np.concatenate(
        [np.concatenate([q_q_dist, q_g_dist], axis=1),
         np.concatenate([q_g_dist.T, g_g_dist], axis=1)],
        axis=0
    )
    original_dist = np.power(original_dist, 2).astype(np.float32)
    original_dist = np.transpose(
        1. * original_dist / np.max(original_dist, axis=0))
    V = np.zeros_like(original_dist).astype(np.float32)
    initial_rank = np.argsort(original_dist).astype(np.int32)

    query_num = q_g_dist.shape[0]
    all_num = original_dist.shape[0]

    for i in range(all_num):
        forward_k_neigh_index = initial_rank[i, :k1 + 1]
        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
        fi = np.where(backward_k_neigh_index == i)[0]
        k_reciprocal_index = forward_k_neigh_index[fi]
        weight = np.exp(-original_dist[i, k_reciprocal_index])
        V[i, k_reciprocal_index] = weight / np.sum(weight)

    original_dist = original_dist[:query_num, query_num:]
    V_qe = V[:query_num, :]
    V_ge = V[query_num:, :]
    dist = 1 - np.dot(V_qe, V_ge.T)
    final_dist = (1 - lambda_value) * original_dist + lambda_value * dist
    return final_dist


def get_result_rank10(qf, gf, gl):
    # Cosine similarity between one query feature and all gallery features
    query = qf.view(-1, 1)
    score = torch.mm(gf, query)
    score = score.squeeze(1).cpu()
    score = score.numpy()
    index = np.argsort(score)  # ascending; reverse for descending similarity
    index = index[::-1]
    rank10_index = index[0:10]
    result_rank10 = gl[rank10_index]
    return result_rank10


if __name__ == "__main__":
    ######################################################################
    # Load the trained model
    print('-------test-----------')

    class Configuration:

        backbone_arch = 'dinov2_vitl14'
        model_name = 'dinov2_vitl14_MixVPR'
        agg_arch = 'MixVPR'
        agg_config = {'in_channels': 1024,
                      'in_h': 32,  # determined by the input image size
                      'in_w': 32,
                      'out_channels': 1024,
                      'mix_depth': 2,
                      'mlp_ratio': 1,
                      'out_rows': 4}
        layer1 = 7
        checkpoint_start = ''
        # PointCLIP adapter settings
        num_views: int = 10
        backbone_name: str = 'RN101'
        backbone_channel: int = 512
        adapter_ratio: float = 0.6
        adapter_init: float = 0.5
        adapter_dropout: float = 0.1
        use_pretrained: bool = True

    args = Configuration()
    model = TimmModel(model_name=args.model_name, args=args,
                      pretrained=True,
                      img_size=img_size, backbone_arch=args.backbone_arch, agg_arch=args.agg_arch,
                      agg_config=args.agg_config, layer1=args.layer1)
    print(model)

    if args.checkpoint_start:  # skip when no checkpoint path is set
        print("Start from:", args.checkpoint_start)
        model_state_dict = torch.load(args.checkpoint_start)
        model.load_state_dict(model_state_dict, strict=False)

    model = model.eval()
    if use_gpu:
        model = model.to(device)

    # Extract features
    since = time.time()

    query_name = 'query_street'
    gallery_name = 'gallery_satellite'

    which_gallery = which_view(gallery_name)
    which_query = which_view(query_name)

    gallery_path = image_datasets[gallery_name].imgs
    gallery_label, gallery_path = get_SatId_160k(gallery_path)

    print('%d -> %d:' % (which_query, which_gallery))

    with torch.no_grad():
        print('-------------------extract query feature----------------------')
        query_feature = extract_feature(model, dataloaders[query_name])
        print('-------------------extract gallery feature----------------------')
        gallery_feature = extract_feature(model, dataloaders[gallery_name])
        print('--------------------------ending extract-------------------------------')

    time_elapsed = time.time() - since
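    # Note: compute_distmat() and re_ranking() are defined above but not wired
    # into this script. A sketch of applying them at this point (an assumption,
    # not the configuration used in the competition):
    #
    #   q_g = compute_distmat(query_feature, gallery_feature)
    #   q_q = compute_distmat(query_feature, query_feature)
    #   g_g = compute_distmat(gallery_feature, gallery_feature)
    #   final_dist = re_ranking(q_g, q_q, g_g)  # then rank by ascending distance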
    print('Test complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    query_feature = query_feature.to(device)
    gallery_feature = gallery_feature.to(device)

    save_filename = 'answer.txt'
    if os.path.isfile(save_filename):
        os.remove(save_filename)
    results_rank10 = []
    print(len(query_feature))
    gallery_label = np.array(gallery_label)
    for i in range(len(query_feature)):
        result_rank10 = get_result_rank10(query_feature[i], gallery_feature, gallery_label)
        results_rank10.append(result_rank10)

    results_rank10 = np.vstack(results_rank10)
    with open(save_filename, 'w') as f:
        for row in results_rank10:
            f.write('\t'.join(map(str, row)) + '\n')
--------------------------------------------------------------------------------
/train_universityMM.py:
--------------------------------------------------------------------------------
import os
import time
import shutil
import sys
import torch
from dataclasses import dataclass
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader
from transformers import get_constant_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup, \
    get_cosine_schedule_with_warmup

from cvgl_base.dataset.university import U1652DatasetEval, U1652DatasetTrain, get_transforms
from cvgl_base.utils import setup_system, Logger
from cvgl_base.trainer import train
from cvgl_base.evaluate.university import evaluate
from cvgl_base.loss.loss import InfoNCE
from cvgl_base.loss.blocks_infoNCE import blocks_InfoNCE
from cvgl_base.loss.DSA_loss import DSA_loss
from cvgl_base.loss.supcontrast import SupConLoss
from cvgl_base.model import TimmModel

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

@dataclass
class Configuration:
    # Model
    model = 'dinov2_vitl14_MixVPR'

    # Backbone
    backbone_arch = 'dinov2_vitl14'
    pretrained = True
    layer1 = 7
    use_cls = True
    norm_descs = True

    # Aggregator (feature aggregation method)
    agg_arch = 'MixVPR'
    agg_config = {'in_channels': 1024,  # 768 for vitb14 | 1536 for vitg14 | 1024 for vitl14
                  'in_h': 32,  # determined by the input image size
                  'in_w': 32,
                  'out_channels': 1024,
                  'mix_depth': 2,
                  'mlp_ratio': 1,
                  'out_rows': 4}
    # Override model image size
    img_size: int = 448
    new_hight = 448
    new_width = 448

    # Training
    mixed_precision: bool = True
    custom_sampling: bool = True  # use custom sampling instead of random
    seed = 1
    epochs: int = 40
    batch_size: int = 4  # real_batch_size = 2 * batch_size; 8 for vitb14 | 2 for vitg14
    verbose: bool = True
    gpu_ids: tuple = (0, 1)  # GPU ids for training

    # Eval
    batch_size_eval: int = 32  # 64 for vitb14 | 16 for vitg14 | 32 for vitl14
    eval_every_n_epoch: int = 1  # eval every n epochs
    normalize_features: bool = True
    eval_gallery_n: int = -1  # -1 for all, or an int

    # Optimizer
    clip_grad = 100.  # None | float
    decay_exclude_bias: bool = False
    grad_checkpointing: bool = False  # gradient checkpointing
    use_sgd = True

    # Loss
    label_smoothing: float = 0.1

    # Learning Rate
    lr: float = 0.0005  # 1e-4 for ViT | 1e-1 for CNN
    scheduler: str = "cosine"  # "polynomial" | "cosine" | "constant" | None
    warmup_epochs: float = 0.1
    lr_end: float = 0.0001  # only for "polynomial"

    # Dataset
    dataset: str = 'U1652-G2S'  # 'U1652-G2S' | 'U1652-S2G'
    data_folder: str = "your_data_path"

    # Augment Images
    prob_flip: float = 0.5  # flip the satellite image and the ground image simultaneously

    # Save path for model checkpoints
    model_path: str = ""

    # Eval before training
    zero_shot: bool = False

    # Checkpoint to start from
    checkpoint_start = None

    # set num_workers to 0 on Windows
    num_workers: int = 0 if os.name == 'nt' else 7

    # train on GPU if available
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

    # for better performance
    cudnn_benchmark: bool = True

    # make cudnn deterministic
    cudnn_deterministic: bool = False

    # PointCLIP adapter settings
    num_views: int = 10
    backbone_name: str = 'RN101'
    backbone_channel: int = 512
    adapter_ratio: float = 0.6
    adapter_init: float = 0.5
    adapter_dropout: float = 0.09
    use_pretrained: bool = True

# -----------------------------------------------------------------------------#
# Train Config                                                                  #
# -----------------------------------------------------------------------------#

config = Configuration()

if config.dataset == 'U1652-G2S':
    config.query_folder_train = f'{config.data_folder}/train/satellite'
    config.gallery_folder_train = f'{config.data_folder}/train/street_new'
    config.pointcloud_folder_train = f'{config.data_folder}/train/drone_3D'
    config.query_folder_test = f'{config.data_folder}/test/query_street'
    config.gallery_folder_test = f'{config.data_folder}/test/gallery_satellite'
elif config.dataset == 'U1652-S2G':
    config.query_folder_train = f'{config.data_folder}/train/satellite'
    config.gallery_folder_train = f'{config.data_folder}/train/street'
    config.pointcloud_folder_train = f'{config.data_folder}/train/drone_3D'
    config.query_folder_test = f'{config.data_folder}/test/query_satellite'
    config.gallery_folder_test = f'{config.data_folder}/test/gallery_street'

if __name__ == '__main__':

    model_path = "{}/{}/{}".format(config.model_path,
                                   config.model,
                                   time.strftime("%Y-%m-%d_%H%M%S"))

    if not os.path.exists(model_path):
        os.makedirs(model_path)
    shutil.copyfile(os.path.basename(__file__), "{}/train.py".format(model_path))

    # Redirect print to both console and log file
    sys.stdout = Logger(os.path.join(model_path, 'log.txt'))

    setup_system(seed=config.seed,
                 cudnn_benchmark=config.cudnn_benchmark,
                 cudnn_deterministic=config.cudnn_deterministic)

    # -----------------------------------------------------------------------------#
    # Model                                                                         #
    # -----------------------------------------------------------------------------#
    print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())))

    print("\nModel: {}".format(config.model))

    model = TimmModel(args=config,
                      model_name=config.model,
                      pretrained=True,
                      img_size=config.img_size, backbone_arch=config.backbone_arch, agg_arch=config.agg_arch,
                      agg_config=config.agg_config, layer1=config.layer1, neck='no', num_classes=701)
    print(model)

    data_config = model.get_config()
    print(data_config)
    mean = data_config["mean"]
    std = data_config["std"]

    img_size = (config.img_size, config.img_size)

    # Activate gradient checkpointing
    if config.grad_checkpointing:
        model.set_grad_checkpointing(True)

    # Load pretrained checkpoint
    if config.checkpoint_start is not None:
        print("Start from:", config.checkpoint_start)
        model_state_dict = torch.load(config.checkpoint_start)
        model.load_state_dict(model_state_dict, strict=False)

    # Data parallel
    print("GPUs available:", torch.cuda.device_count())
    if torch.cuda.device_count() > 1 and len(config.gpu_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=config.gpu_ids)

    # Model to device
    model = model.to(config.device)

    print("\nImage Size Query:", img_size)
    print("Image Size Ground:", img_size)
    print("Mean: {}".format(mean))
    print("Std: {}\n".format(std))

    # -----------------------------------------------------------------------------#
    # DataLoader                                                                    #
    # -----------------------------------------------------------------------------#

    # Transforms
    val_transforms, train_sat_transforms, train_drone_transforms = get_transforms(img_size, mean=mean, std=std)

    # Train
    train_dataset = U1652DatasetTrain(query_folder=config.query_folder_train,
                                      gallery_folder=config.gallery_folder_train,
                                      pointcloud_folder=config.pointcloud_folder_train,
                                      transforms_query=train_sat_transforms,
                                      transforms_gallery=train_drone_transforms,
                                      prob_flip=config.prob_flip,
                                      shuffle_batch_size=config.batch_size,
                                      )

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers,
                                  shuffle=not config.custom_sampling,
                                  pin_memory=True)

    # Test Query Images
    query_dataset_test = U1652DatasetEval(data_folder=config.query_folder_test,
                                          mode="query",
                                          transforms=val_transforms,
                                          )

    query_dataloader_test = DataLoader(query_dataset_test,
                                       batch_size=config.batch_size_eval,
                                       num_workers=config.num_workers,
                                       shuffle=False,
                                       pin_memory=True)

    # Test Gallery Images
    gallery_dataset_test = U1652DatasetEval(data_folder=config.gallery_folder_test,
                                            mode="gallery",
                                            transforms=val_transforms,
                                            sample_ids=query_dataset_test.get_sample_ids(),
                                            gallery_n=config.eval_gallery_n,
                                            )

    gallery_dataloader_test = DataLoader(gallery_dataset_test,
                                         batch_size=config.batch_size_eval,
                                         num_workers=config.num_workers,
                                         shuffle=False,
                                         pin_memory=True)

    print("Query Images Test:", len(query_dataset_test))
    print("Gallery Images Test:", len(gallery_dataset_test))

    # -----------------------------------------------------------------------------#
    # Loss                                                                          #
    # -----------------------------------------------------------------------------#

    loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=config.label_smoothing)
    loss_function1 = InfoNCE(loss_function=loss_fn,
                             device=config.device)
    loss_function2 = blocks_InfoNCE(loss_function=loss_fn, device=config.device)
    loss_function3 = DSA_loss(loss_function=loss_fn, device=config.device)
    loss_function4 = SupConLoss(device=config.device)

    loss_function = {
        'InfoNCE': loss_function1,
        'blocks_InfoNCE': loss_function2,
        'DSA': loss_function3,
        'SupCon': loss_function4,
    }

    if config.mixed_precision:
        scaler = GradScaler(init_scale=2. ** 10)
    else:
        scaler = None

    # -----------------------------------------------------------------------------#
    # Optimizer                                                                     #
    # -----------------------------------------------------------------------------#

    if config.decay_exclude_bias:
        param_optimizer = list(model.named_parameters())
        no_decay = ["bias", "LayerNorm.bias"]
        optimizer_parameters = [
            {
                "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                "weight_decay": 0.01,
            },
            {
                "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = torch.optim.AdamW(optimizer_parameters, lr=config.lr)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)

    # Note: when use_sgd is enabled, SGD replaces the AdamW optimizer configured above.
    if config.use_sgd:
        optimizer = torch.optim.SGD(model.parameters(), lr=config.lr)

    # -----------------------------------------------------------------------------#
    # Scheduler                                                                     #
    # -----------------------------------------------------------------------------#

    train_steps = len(train_dataloader) * config.epochs
    warmup_steps = len(train_dataloader) * config.warmup_epochs

    if config.scheduler == "polynomial":
        print("\nScheduler: polynomial - max LR: {} - end LR: {}".format(config.lr, config.lr_end))
        scheduler = get_polynomial_decay_schedule_with_warmup(optimizer,
                                                              num_training_steps=train_steps,
                                                              lr_end=config.lr_end,
                                                              power=1.5,
                                                              num_warmup_steps=warmup_steps)
    elif config.scheduler == "cosine":
        print("\nScheduler: cosine - max LR: {}".format(config.lr))
        scheduler = get_cosine_schedule_with_warmup(optimizer,
                                                    num_training_steps=train_steps,
                                                    num_warmup_steps=warmup_steps)
    elif config.scheduler == "constant":
        print("\nScheduler: constant - max LR: {}".format(config.lr))
        scheduler = get_constant_schedule_with_warmup(optimizer,
                                                      num_warmup_steps=warmup_steps)
    else:
        scheduler = None

    print("Warmup Epochs: {} - Warmup Steps: {}".format(str(config.warmup_epochs).ljust(2), warmup_steps))
    print("Train Epochs: {} - Train Steps: {}".format(config.epochs, train_steps))

    # -----------------------------------------------------------------------------#
    # Zero Shot                                                                     #
    # -----------------------------------------------------------------------------#
    if config.zero_shot:
        print("\n{}[{}]{}".format(30 * "-", "Zero Shot", 30 * "-"))

        r1_test = evaluate(config=config,
                           model=model,
                           query_loader=query_dataloader_test,
                           gallery_loader=gallery_dataloader_test,
                           ranks=[1, 5, 10],
                           step_size=1000,
                           cleanup=True)

    # -----------------------------------------------------------------------------#
    # Shuffle                                                                       #
    # -----------------------------------------------------------------------------#
    if config.custom_sampling:
        train_dataloader.dataset.shuffle()
    # -----------------------------------------------------------------------------#
    # Train                                                                         #
    # -----------------------------------------------------------------------------#
    start_epoch = 0
    best_score = 0

    for epoch in range(1, config.epochs + 1):

        print("\n{}[{}/Epoch: {}]{}".format(30 * "-", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), epoch, 30 * "-"))

        train_loss = train(config,
                           model,
                           dataloader=train_dataloader,
                           loss_function=loss_function,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           scaler=scaler)

        print("Epoch: {}, Train Loss = {:.3f}, Lr = {:.6f}".format(epoch,
                                                                   train_loss,
                                                                   optimizer.param_groups[0]['lr']))

        # Evaluate
        if (epoch % config.eval_every_n_epoch == 0 and epoch > 1) or epoch == config.epochs:

            print("\n{}[{}]{}".format(30 * "-", "Evaluate", 30 * "-"))

            r1_test = evaluate(config=config,
                               model=model,
                               query_loader=query_dataloader_test,
                               gallery_loader=gallery_dataloader_test,
                               ranks=[1, 5, 10],
                               step_size=1000,
                               cleanup=True)

            if r1_test > best_score:

                best_score = r1_test

                if torch.cuda.device_count() > 1 and len(config.gpu_ids) > 1:
                    torch.save(model.module.state_dict(),
                               '{}/weights_e{}_{:.4f}.pth'.format(model_path, epoch, r1_test))
                else:
                    torch.save(model.state_dict(), '{}/weights_e{}_{:.4f}.pth'.format(model_path, epoch, r1_test))
            elif r1_test > 26.0:  # also keep non-best checkpoints that clear this R@1 threshold
                if torch.cuda.device_count() > 1 and len(config.gpu_ids) > 1:
                    torch.save(model.module.state_dict(),
                               '{}/weights_e{}_{:.4f}.pth'.format(model_path, epoch, r1_test))
                else:
                    torch.save(model.state_dict(), '{}/weights_e{}_{:.4f}.pth'.format(model_path, epoch, r1_test))

        if config.custom_sampling:
            train_dataloader.dataset.shuffle()

    if torch.cuda.device_count() > 1 and len(config.gpu_ids) > 1:
        torch.save(model.module.state_dict(), '{}/weights_end.pth'.format(model_path))
    else:
        torch.save(model.state_dict(), '{}/weights_end.pth'.format(model_path))
--------------------------------------------------------------------------------