├── figures
│   ├── 233
│   └── overview.jpg
├── README.md
├── test_mm_tta.py
└── train_universityMM.py
/figures/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HRT00/CVGL-3D/HEAD/figures/overview.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SkyLink: Unifying Street-Satellite Geo-Localization via UAV-Mediated 3D Scene Alignment
2 | Official implementation of the ACM MM 2025 UAVM workshop paper SkyLink, developed for the UAV challenge (https://codalab.lisn.upsaclay.fr/competitions/22073). Team: XMUSmart.
3 |
4 | ## News
5 | 🚩 **2025.07.07: Coming soon! Code will be released upon the paper's publication.**
6 |
7 | 🚩 **2025.07.24: Our paper has been accepted by ACM MM 2025 UAVM. The training code has been released!**
8 |
9 | 🚩 **2025.11.03: Our paper has been published! The link is: https://dl.acm.org/doi/10.1145/3728482.3757392**
10 |
11 | ## Description 📜
12 | This research focuses on ground-satellite geo-localization. We aim at robust feature retrieval under viewpoint variation and propose the novel SkyLink method. We integrate 3D scene information reconstructed from multi-scale UAV images as a bridge between the street and satellite viewpoints, and perform feature alignment through self-supervised and cross-view contrastive learning.
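For intuition, cross-view contrastive alignment can be written as a symmetric InfoNCE over paired street/satellite embeddings. The sketch below is illustrative only; the losses actually used for training live in `cvgl_base/loss/` (see `train_universityMM.py`).

```python
import torch
import torch.nn.functional as F

def info_nce(street_feat: torch.Tensor, sat_feat: torch.Tensor,
             temperature: float = 0.07) -> torch.Tensor:
    """Symmetric InfoNCE for a batch of paired street/satellite embeddings (illustrative sketch)."""
    street = F.normalize(street_feat, dim=-1)
    sat = F.normalize(sat_feat, dim=-1)
    logits = street @ sat.t() / temperature  # (B, B) cosine-similarity logits
    labels = torch.arange(logits.size(0), device=logits.device)  # matching pairs sit on the diagonal
    # Treat retrieval in both directions as classification over the batch.
    return 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels))
```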
13 |
14 | ## Framework 🖇️
15 | ![Overview of SkyLink](figures/overview.jpg)
16 |
17 | ## Requirements
18 | ### Installation
19 | Create a conda environment and install dependencies:
20 | ```bash
21 | conda create -n cvgl python=3.10
22 | conda activate cvgl
23 |
24 | # Install matching versions of torch and torchvision for your CUDA setup
25 | conda install pytorch torchvision cudatoolkit
26 | ```
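Beyond torch and torchvision, the training and test scripts also import `transformers`, `ttach`, `scikit-learn`, and `numpy`, and the `TimmModel` wrapper suggests `timm` is needed as well. Assuming a standard pip setup, something like the following should cover them:
```bash
pip install timm transformers ttach scikit-learn numpy
```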
27 |
28 | ## Quick Start
29 | To train the SkyLink model, run:
30 | ```bash
31 | python train_universityMM.py
32 | ```
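Before training, edit the `Configuration` dataclass in `train_universityMM.py`: point `data_folder` (currently `"your_data_path"`) at your local University-1652 root and set `model_path` to where checkpoints and logs should be written.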
33 | To test the SkyLink model with test-time augmentation (TTA), as used in the competition, run:
34 | ```bash
35 | python test_mm_tta.py
36 | ```
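The test script writes the top-10 retrieved satellite labels for each street query, tab-separated, to `answer.txt` in the competition submission format.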
37 | We reconstruct point clouds from multi-view UAV images with the VGGT model; see https://github.com/facebookresearch/vggt for details.
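As a reference, a minimal reconstruction sketch following the VGGT quickstart might look like the snippet below; the model name and helper functions are taken from the VGGT README, so please check that repository for the current API.

```python
import torch
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images

device = "cuda" if torch.cuda.is_available() else "cpu"
model = VGGT.from_pretrained("facebook/VGGT-1B").to(device)

# Multi-view UAV images of one scene (paths are placeholders).
images = load_and_preprocess_images(["view_0.jpg", "view_1.jpg", "view_2.jpg"]).to(device)

with torch.no_grad():
    predictions = model(images)  # per-view depth / point maps used to build the point cloud
```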
38 |
39 | ## Acknowledgements
40 |
41 | ## Cite
42 | If you find our work useful, please consider citing:
43 | ```bibtex
44 | @article{zhang2025skylink,
45 | title={SkyLink: Unifying Street-Satellite Geo-Localization via UAV-Mediated 3D Scene Alignment},
46 | author={Zhang, Hongyang and Liu, Yinhao and Kuang, Zhenyu},
47 | journal={arXiv preprint arXiv:2509.24783},
48 | year={2025}
49 | }
50 | ```
51 |
52 | ## Contact
53 | If you have any questions about this project, please feel free to contact hyzhang@stu.xmu.edu.cn.
54 |
--------------------------------------------------------------------------------
/test_mm_tta.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function, division
4 |
5 | import os
6 | os.environ["CUDA_VISIBLE_DEVICES"] = "4"
7 |
8 | import argparse
9 | import torch
10 | import torch.nn as nn
11 | import torch.optim as optim
12 | import torch.nn.functional as F
13 | from torch.optim import lr_scheduler
14 | from torch.autograd import Variable
15 | import torch.backends.cudnn as cudnn
16 | import numpy as np
17 | import ttach as tta
18 |
19 | import time
20 |
21 | from cvcities_base.model import TimmModel
22 | from cvcities_base.dataset.university import get_transforms
23 | from utils.image_folder import CustomData160k_sat, CustomData160k_drone
24 |
25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26 | print(f"Running on device: {device}")
27 |
28 | # Options
29 | # --------
30 | parser = argparse.ArgumentParser(description='Training')
31 | parser.add_argument('--gpu_ids',default='4', type=str,help='gpu_ids: e.g. 0 0,1,2 0,2')
32 | parser.add_argument('--which_epoch',default='last', type=str, help='0,1,2,3...or last')
33 | parser.add_argument('--test_dir',default='/home/zhanghy/MM/data/University-1652/test',type=str, help='./test_data')
34 | parser.add_argument('--batchsize', default=8, type=int, help='batchsize')
35 | parser.add_argument('--views', default=2, type=int, help='views')
36 | parser.add_argument('--query_name', default='query_street_name.txt', type=str,help='load query image')
37 | opt = parser.parse_args()
38 |
39 | str_ids = opt.gpu_ids.split(',')
40 | #which_epoch = opt.which_epoch
41 | test_dir = opt.test_dir
42 | query_name = opt.query_name
43 | ms = [1]
44 |
45 | ######################################################################
46 | # Load Data
47 | # ---------
48 | #
49 | # We will use torchvision and torch.utils.data packages for loading the
50 | # data.
51 | #
52 | mean = [0.485, 0.456, 0.406]
53 | std = [0.229, 0.224, 0.225]
54 | img_size = (448, 448)
55 |
56 | val_transforms, _, _ = get_transforms(img_size, mean=mean, std=std)
57 |
58 | data_dir = test_dir
59 |
60 | image_datasets = {}
61 | image_datasets['gallery_satellite'] = CustomData160k_sat(os.path.join(data_dir, 'workshop_gallery_satellite'), val_transforms)
62 | image_datasets['query_street'] = CustomData160k_drone(os.path.join(data_dir,'workshop_query_street'), val_transforms, query_name = query_name)
63 | print(image_datasets.keys())
64 |
65 |
66 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
67 | shuffle=False, num_workers=16) for x in
68 | ['gallery_satellite','query_street']}
69 |
70 | use_gpu = torch.cuda.is_available()
71 |
72 | ######################################################################
73 | # Extract feature
74 | # ----------------------
75 | #
76 | # Extract feature from a trained model.
77 | #
78 | def fliplr(img):
79 | '''flip horizontal'''
80 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W
81 | img_flip = img.index_select(3,inv_idx)
82 | return img_flip
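# Note: fliplr() above is kept for reference but is never called in this
# script; horizontal flipping is handled by the ttach TTA wrapper below.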
83 |
84 | def which_view(name):
85 | if 'satellite' in name:
86 | return 1
87 | elif 'street' in name:
88 | return 2
89 | elif 'drone' in name:
90 | return 3
91 | else:
92 | print('unknown view')
93 | return -1
94 |
95 |
96 | tta_transforms = tta.Compose([
97 | tta.HorizontalFlip(),
98 | tta.Rotate90(angles=[0, 90]),
99 | ])
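# ClassificationTTAWrapper (used in extract_feature below) runs the model on
# every combination of these transforms (2 flips x 2 rotations = 4 views) and
# merges the outputs (averaging by default) for more viewpoint-robust features.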
100 |
101 | def extract_feature(model, dataloader):
102 | features = []
103 | model_tta = tta.ClassificationTTAWrapper(model, tta_transforms)  # wrap the model with TTA
104 | model_tta.eval()
105 |
106 | if use_gpu:
107 | model_tta = model_tta.to(device)
108 |
109 | for data in dataloader:
110 | img, _ = data
111 | input_img = img.to(device)  # torch.autograd.Variable is a no-op in modern PyTorch
112 |
113 | with torch.no_grad():
114 | outputs = model_tta(input_img)  # applies every TTA transform and fuses (averages) the outputs
115 |
116 | feature = F.normalize(outputs, dim=-1)
117 | features.append(feature.cpu().data)
118 |
119 | features = torch.cat(features, dim=0)
120 | return features
121 |
122 | def get_SatId_160k(img_path):
123 | labels = []
124 | paths = []
125 | for path,v in img_path:
126 | labels.append(v)
127 | paths.append(path)
128 | return labels, paths
129 |
130 | from sklearn.metrics.pairwise import euclidean_distances  # torch and numpy are already imported above
133 |
134 | def compute_distmat(query, gallery):
135 | # L2 normalize
136 | query = query / query.norm(dim=1, keepdim=True)
137 | gallery = gallery / gallery.norm(dim=1, keepdim=True)
138 | distmat = euclidean_distances(query.cpu().numpy(), gallery.cpu().numpy())
139 | return distmat
140 |
141 | def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3):
142 | # Simplified from https://github.com/layumi/Person_reID_baseline_pytorch/blob/master/re_ranking.py (k2 / local query expansion is omitted, so k2 is unused here)
143 | original_dist = np.concatenate(
144 | [np.concatenate([q_q_dist, q_g_dist], axis=1),
145 | np.concatenate([q_g_dist.T, g_g_dist], axis=1)],
146 | axis=0
147 | )
148 | original_dist = np.power(original_dist, 2).astype(np.float32)
149 | original_dist = np.transpose(1. * original_dist / np.max(original_dist, axis=0))
150 | V = np.zeros_like(original_dist).astype(np.float32)
151 | initial_rank = np.argsort(original_dist).astype(np.int32)
152 |
153 | query_num = q_g_dist.shape[0]
154 | all_num = original_dist.shape[0]
155 |
156 | for i in range(all_num):
157 | forward_k_neigh_index = initial_rank[i, :k1+1]
158 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1+1]
159 | fi = np.where(backward_k_neigh_index == i)[0]
160 | k_reciprocal_index = forward_k_neigh_index[fi]
161 | weight = np.exp(-original_dist[i, k_reciprocal_index])
162 | V[i, k_reciprocal_index] = weight / np.sum(weight)
163 | original_dist = original_dist[:query_num, query_num:]
164 | V_qe = V[:query_num, :]
165 | V_ge = V[query_num:, :]
166 | dist = 1 - np.dot(V_qe, V_ge.T)
167 | final_dist = (1 - lambda_value) * original_dist + lambda_value * dist
168 | return final_dist
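# Note: compute_distmat() and re_ranking() are not invoked in __main__ below.
# A typical (untested) way to combine them for re-ranked retrieval would be:
#   q_g = compute_distmat(query_feature, gallery_feature)
#   q_q = compute_distmat(query_feature, query_feature)
#   g_g = compute_distmat(gallery_feature, gallery_feature)
#   final_dist = re_ranking(q_g, q_q, g_g, k1=20, k2=6, lambda_value=0.3)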
169 |
170 | def get_result_rank10(qf,gf,gl):
171 | query = qf.view(-1,1)
172 | score = torch.mm(gf, query)
173 | score = score.squeeze(1).cpu()
174 | score = score.numpy()
175 | index = np.argsort(score)
176 | index = index[::-1]
177 | rank10_index = index[0:10]
178 | result_rank10 = gl[rank10_index]
179 | return result_rank10
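# Since extract_feature() L2-normalizes the outputs, the inner product above is
# cosine similarity; sorting scores in descending order yields the 10 best-matching
# satellite labels for each street query.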
180 |
181 |
182 | if __name__ == "__main__":
183 | ######################################################################
184 | # Load Collected data Trained model
185 | print('-------test-----------')
186 |
187 | class Configuration:
188 |
189 | backbone_arch = 'dinov2_vitl14'
190 | model_name = 'dinov2_vitl14_MixVPR'
191 | agg_arch = 'MixVPR'
192 | agg_config = {'in_channels': 1024,
193 | 'in_h': 32,  # depends on the input image size
194 | 'in_w': 32,
195 | 'out_channels': 1024,
196 | 'mix_depth': 2,
197 | 'mlp_ratio': 1,
198 | 'out_rows': 4}
199 | layer1 = 7
200 | checkpoint_start = ''
201 | # point clip
202 | num_views: int = 10
203 | backbone_name: str = 'RN101'
204 | backbone_channel: int = 512
205 | adapter_ratio: float = 0.6
206 | adapter_init: float = 0.5
207 | adapter_dropout: float = 0.1
208 | use_pretrained: bool = True
209 |
210 | args = Configuration()
211 | model = TimmModel(model_name=args.model_name, args=args,
212 | pretrained=True,
213 | img_size=img_size, backbone_arch=args.backbone_arch, agg_arch=args.agg_arch,
214 | agg_config=args.agg_config, layer1=args.layer1)
215 | print(model)
216 |
217 | if args.checkpoint_start:  # the default '' means there is no checkpoint to load
218 | print("Start from:", args.checkpoint_start)
219 | model_state_dict = torch.load(args.checkpoint_start)
220 | model.load_state_dict(model_state_dict, strict=False)
221 |
222 | model = model.eval()
223 | if use_gpu:
224 | # model = model.cuda()
225 | model = model.to(device)
226 |
227 | # Extract feature
228 | since = time.time()
229 |
230 | query_name = 'query_street'  # note: shadows the earlier query_name (the .txt list); now a dataset key
231 | gallery_name = 'gallery_satellite'
232 |
233 | which_gallery = which_view(gallery_name)
234 | which_query = which_view(query_name)
235 |
236 | gallery_path = image_datasets[gallery_name].imgs
237 | gallery_label, gallery_path = get_SatId_160k(gallery_path)
238 |
239 | print('%d -> %d:'%(which_query, which_gallery))
240 |
241 | with torch.no_grad():
242 | print('-------------------extract query feature----------------------')
243 | query_feature = extract_feature(model, dataloaders[query_name])
244 | print('-------------------extract gallery feature----------------------')
245 | gallery_feature = extract_feature(model,dataloaders[gallery_name])
246 | print('--------------------------ending extract-------------------------------')
247 |
248 | time_elapsed = time.time() - since
249 | print('Test complete in {:.0f}m {:.0f}s'.format(
250 | time_elapsed // 60, time_elapsed % 60))
251 |
252 | query_feature = query_feature.to(device)
253 | gallery_feature = gallery_feature.to(device)
254 |
255 | save_filename = 'answer.txt'
256 | if os.path.isfile(save_filename):
257 | os.remove(save_filename)
258 | results_rank10 = []
259 | print(len(query_feature))
260 | gallery_label = np.array(gallery_label)
261 | for i in range(len(query_feature)):
262 | result_rank10 = get_result_rank10(query_feature[i], gallery_feature, gallery_label)
263 | results_rank10.append(result_rank10)
264 |
265 | results_rank10 = np.vstack(results_rank10)
268 | with open(save_filename, 'w') as f:
269 | for row in results_rank10:
270 | f.write('\t'.join(map(str, row)) + '\n')
271 |
--------------------------------------------------------------------------------
/train_universityMM.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import math
4 | import shutil
5 | import sys
6 | import torch
7 | from dataclasses import dataclass
8 | from torch.cuda.amp import GradScaler
9 | from torch.utils.data import DataLoader
10 | from transformers import get_constant_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup, \
11 | get_cosine_schedule_with_warmup
12 |
13 | from cvgl_base.dataset.university import U1652DatasetEval, U1652DatasetTrain, get_transforms
14 | from cvgl_base.utils import setup_system, Logger
15 | from cvgl_base.trainer import train
16 | from cvgl_base.evaluate.university import evaluate
17 | from cvgl_base.loss.loss import InfoNCE
18 | from cvgl_base.loss.blocks_infoNCE import blocks_InfoNCE
19 | from cvgl_base.loss.DSA_loss import DSA_loss
20 | from cvgl_base.loss.supcontrast import SupConLoss
21 | from cvgl_base.model import TimmModel
23 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
24 |
25 | @dataclass
26 | class Configuration:
27 | # Model
28 | model = 'dinov2_vitl14_MixVPR'
29 |
30 | # backbone
31 | backbone_arch = 'dinov2_vitl14'
32 | pretrained = True
33 | layer1 = 7
34 | use_cls = True
35 | norm_descs = True
36 |
37 | # Aggregator
38 | agg_arch = 'MixVPR'
39 | agg_config = {'in_channels': 1024, # 768 for vitb14 | 1536 for vitg14 | 1024 for vitl14
40 | 'in_h': 32,  # depends on the input image size
41 | 'in_w': 32,
42 | 'out_channels': 1024,
43 | 'mix_depth': 2,
44 | 'mlp_ratio': 1,
45 | 'out_rows': 4}
46 | # Override model image size
47 | img_size: int = 448
48 | new_hight = 448
49 | new_width = 448
50 |
51 | # Training
52 | mixed_precision: bool = True
53 | custom_sampling: bool = True # use custom sampling instead of random
54 | seed = 1
55 | epochs: int = 40
56 | batch_size: int = 4 # keep in mind real_batch_size = 2 * batch_size # 8 for vitb14 | 2 for vitg14
57 | verbose: bool = True
58 | gpu_ids: tuple = (0, 1) # GPU ids for training
59 |
60 | # Eval
61 | batch_size_eval: int = 32 # 64 for vitb14 | 16 for vitg14 | 32 for vitl14
62 | eval_every_n_epoch: int = 1 # eval every n Epoch
63 | normalize_features: bool = True
64 | eval_gallery_n: int = -1 # -1 for all or int
65 |
66 | # Optimizer
67 | clip_grad = 100. # None | float
68 | decay_exclude_bias: bool = False
69 | grad_checkpointing: bool = False # Gradient Checkpointing
70 | use_sgd = True
71 |
72 | # Loss
73 | label_smoothing: float = 0.1
74 |
75 | # Learning Rate
76 | lr: float = 0.0005 # 1 * 10^-4 for ViT | 1 * 10^-1 for CNN
77 | scheduler: str = "cosine" # "polynomial" | "cosine" | "constant" | None
78 | warmup_epochs: float = 0.1
79 | lr_end: float = 0.0001 # only for "polynomial"
80 |
81 | # Dataset
82 | dataset: str = 'U1652-G2S' # 'U1652-G2S' | 'U1652-S2G'
83 | data_folder: str = "your_data_path"
84 |
85 | # Augment Images
86 | prob_flip: float = 0.5 # flipping the sat image and drone image simultaneously
87 |
88 | # Savepath for model checkpoints
89 | model_path: str = ""
90 |
91 | # Eval before training
92 | zero_shot: bool = False
93 |
94 | # Checkpoint to start from
95 | checkpoint_start = None
96 |
97 | # set num_workers to 0 if on Windows
98 | num_workers: int = 0 if os.name == 'nt' else 7
99 |
100 | # train on GPU if available
101 | device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
102 |
103 | # for better performance
104 | cudnn_benchmark: bool = True
105 |
106 | # make cudnn deterministic
107 | cudnn_deterministic: bool = False
108 |
109 | # point clip
110 | num_views: int = 10
111 | backbone_name: str = 'RN101'
112 | backbone_channel: int = 512
113 | adapter_ratio: float = 0.6
114 | adapter_init: float = 0.5
115 | adapter_dropout: float = 0.09
116 | use_pretrained: bool = True
117 |
118 | # -----------------------------------------------------------------------------#
119 | # Train Config #
120 | # -----------------------------------------------------------------------------#
121 |
122 | config = Configuration()
123 |
124 | if config.dataset == 'U1652-G2S':
125 | config.query_folder_train = f'{config.data_folder}/train/satellite'
126 | config.gallery_folder_train = f'{config.data_folder}/train/street_new'
127 | config.pointcloud_folder_train = f'{config.data_folder}/train/drone_3D'
128 | config.query_folder_test = f'{config.data_folder}/test/query_street'
129 | config.gallery_folder_test = f'{config.data_folder}/test/gallery_satellite'
130 | elif config.dataset == 'U1652-S2G':
131 | config.query_folder_train = f'{config.data_folder}/train/satellite'
132 | config.gallery_folder_train = f'{config.data_folder}/train/street'
133 | config.pointcloud_folder_train = f'{config.data_folder}/train/drone_3D'
134 | config.query_folder_test = f'{config.data_folder}/test/query_satellite'
135 | config.gallery_folder_test = f'{config.data_folder}/test/gallery_street'
136 |
137 | if __name__ == '__main__':
138 |
139 | model_path = "{}/{}/{}".format(config.model_path,
140 | config.model,
141 | time.strftime("%Y-%m-%d_%H%M%S"))
142 |
143 | if not os.path.exists(model_path):
144 | os.makedirs(model_path)
145 | shutil.copyfile(__file__, "{}/train.py".format(model_path))  # use the full path so this works from any cwd
146 |
147 | # Redirect print to both console and log file
148 | sys.stdout = Logger(os.path.join(model_path, 'log.txt'))
149 |
150 | setup_system(seed=config.seed,
151 | cudnn_benchmark=config.cudnn_benchmark,
152 | cudnn_deterministic=config.cudnn_deterministic)
153 |
154 | # -----------------------------------------------------------------------------#
155 | # Model #
156 | # -----------------------------------------------------------------------------#
157 | print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())))
158 |
159 | print("\nModel: {}".format(config.model))
160 |
161 | model = TimmModel(args=config,
162 | model_name=config.model,
163 | pretrained=True,
164 | img_size=config.img_size, backbone_arch=config.backbone_arch, agg_arch=config.agg_arch,
165 | agg_config=config.agg_config, layer1=config.layer1, neck='no', num_classes=701,)
166 | print(model)
167 |
168 | data_config = model.get_config()
169 | print(data_config)
170 | mean = data_config["mean"]
171 | std = data_config["std"]
172 |
173 | img_size = (config.img_size, config.img_size)
174 |
175 | # Activate gradient checkpointing
176 | if config.grad_checkpointing:
177 | model.set_grad_checkpointing(True)
178 |
179 | # Load pretrained Checkpoint
180 | if config.checkpoint_start is not None:
181 | print("Start from:", config.checkpoint_start)
182 | model_state_dict = torch.load(config.checkpoint_start)
183 | model.load_state_dict(model_state_dict, strict=False)
184 |
185 | # Data parallel
186 | print("GPUs available:", torch.cuda.device_count())
187 | if torch.cuda.device_count() > 1 and len(config.gpu_ids) > 1:
188 | model = torch.nn.DataParallel(model, device_ids=config.gpu_ids)
189 |
190 | # Model to device
191 | model = model.to(config.device)
192 |
193 | print("\nImage Size Query:", img_size)
194 | print("Image Size Ground:", img_size)
195 | print("Mean: {}".format(mean))
196 | print("Std: {}\n".format(std))
197 |
198 | # -----------------------------------------------------------------------------#
199 | # DataLoader #
200 | # -----------------------------------------------------------------------------#
201 |
202 | # Transforms
203 | val_transforms, train_sat_transforms, train_drone_transforms = get_transforms(img_size, mean=mean, std=std)
204 |
205 | # Train
206 | train_dataset = U1652DatasetTrain(query_folder=config.query_folder_train,
207 | gallery_folder=config.gallery_folder_train,
208 | pointcloud_folder=config.pointcloud_folder_train,
209 | transforms_query=train_sat_transforms,
210 | transforms_gallery=train_drone_transforms,
211 | prob_flip=config.prob_flip,
212 | shuffle_batch_size=config.batch_size,
213 | )
214 |
215 | train_dataloader = DataLoader(train_dataset,
216 | batch_size=config.batch_size,
217 | num_workers=config.num_workers,
218 | shuffle=not config.custom_sampling,
219 | pin_memory=True)
220 |
221 | # Reference Satellite Images
222 | query_dataset_test = U1652DatasetEval(data_folder=config.query_folder_test,
223 | mode="query",
224 | transforms=val_transforms,
225 | )
226 |
227 | query_dataloader_test = DataLoader(query_dataset_test,
228 | batch_size=config.batch_size_eval,
229 | num_workers=config.num_workers,
230 | shuffle=False,
231 | pin_memory=True)
232 |
233 | # Query Ground Images Test
234 | gallery_dataset_test = U1652DatasetEval(data_folder=config.gallery_folder_test,
235 | mode="gallery",
236 | transforms=val_transforms,
237 | sample_ids=query_dataset_test.get_sample_ids(),
238 | gallery_n=config.eval_gallery_n,
239 | )
240 |
241 | gallery_dataloader_test = DataLoader(gallery_dataset_test,
242 | batch_size=config.batch_size_eval,
243 | num_workers=config.num_workers,
244 | shuffle=False,
245 | pin_memory=True)
246 |
247 | print("Query Images Test:", len(query_dataset_test))
248 | print("Gallery Images Test:", len(gallery_dataset_test))
249 |
250 | # -----------------------------------------------------------------------------#
251 | # Loss #
252 | # -----------------------------------------------------------------------------#
253 |
254 | loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=config.label_smoothing)
255 | loss_function1 = InfoNCE(loss_function=loss_fn,
256 | device=config.device,
257 | )
258 | loss_function2 = blocks_InfoNCE(loss_function=loss_fn, device=config.device,)
259 | loss_function3 = DSA_loss(loss_function=loss_fn, device=config.device,)
260 | loss_function4 = SupConLoss(device=config.device)
261 |
262 | loss_function = {
263 | 'InfoNCE': loss_function1,
264 | 'blocks_InfoNCE': loss_function2,
265 | 'DSA': loss_function3,
266 | 'SupCon': loss_function4,
267 | }
268 |
269 | if config.mixed_precision:
270 | scaler = GradScaler(init_scale=2. ** 10)
271 | else:
272 | scaler = None
273 |
274 | # -----------------------------------------------------------------------------#
275 | # optimizer #
276 | # -----------------------------------------------------------------------------#
277 |
278 | if config.decay_exclude_bias:
279 | param_optimizer = list(model.named_parameters())
280 | no_decay = ["bias", "LayerNorm.bias"]
281 | optimizer_parameters = [
282 | {
283 | "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
284 | "weight_decay": 0.01,
285 | },
286 | {
287 | "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
288 | "weight_decay": 0.0,
289 | },
290 | ]
291 | optimizer = torch.optim.AdamW(optimizer_parameters, lr=config.lr)
292 | else:
293 | optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
294 |
295 | if config.use_sgd:
296 | optimizer = torch.optim.SGD(model.parameters(), lr=config.lr)  # note: this replaces the AdamW optimizer configured above
297 |
298 | # -----------------------------------------------------------------------------#
299 | # Scheduler #
300 | # -----------------------------------------------------------------------------#
301 |
302 | train_steps = len(train_dataloader) * config.epochs
303 | warmup_steps = int(len(train_dataloader) * config.warmup_epochs)
304 |
305 | if config.scheduler == "polynomial":
306 | print("\nScheduler: polynomial - max LR: {} - end LR: {}".format(config.lr, config.lr_end))
307 | scheduler = get_polynomial_decay_schedule_with_warmup(optimizer,
308 | num_training_steps=train_steps,
309 | lr_end=config.lr_end,
310 | power=1.5,
311 | num_warmup_steps=warmup_steps)
312 | elif config.scheduler == "cosine":
313 | print("\nScheduler: cosine - max LR: {}".format(config.lr))
314 | scheduler = get_cosine_schedule_with_warmup(optimizer,
315 | num_training_steps=train_steps,
316 | num_warmup_steps=warmup_steps)
317 | elif config.scheduler == "constant":
318 | print("\nScheduler: constant - max LR: {}".format(config.lr))
319 | scheduler = get_constant_schedule_with_warmup(optimizer,
320 | num_warmup_steps=warmup_steps)
321 | else:
322 | scheduler = None
323 |
324 | print("Warmup Epochs: {} - Warmup Steps: {}".format(str(config.warmup_epochs).ljust(2), warmup_steps))
325 | print("Train Epochs: {} - Train Steps: {}".format(config.epochs, train_steps))
326 |
327 | # -----------------------------------------------------------------------------#
328 | # Zero Shot #
329 | # -----------------------------------------------------------------------------#
330 | if config.zero_shot:
331 | print("\n{}[{}]{}".format(30 * "-", "Zero Shot", 30 * "-"))
332 |
333 | r1_test = evaluate(config=config,
334 | model=model,
335 | query_loader=query_dataloader_test,
336 | gallery_loader=gallery_dataloader_test,
337 | ranks=[1, 5, 10],
338 | step_size=1000,
339 | cleanup=True)
340 |
341 | # -----------------------------------------------------------------------------#
342 | # Shuffle #
343 | # -----------------------------------------------------------------------------#
344 | if config.custom_sampling:
345 | train_dataloader.dataset.shuffle()
346 |
347 | # -----------------------------------------------------------------------------#
348 | # Train #
349 | # -----------------------------------------------------------------------------#
350 | start_epoch = 0
351 | best_score = 0
352 |
353 | for epoch in range(1, config.epochs + 1):
354 |
355 | print("\n{}[{}/Epoch: {}]{}".format(30*"-",time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), epoch, 30*"-"))
356 |
357 | train_loss = train(config,
358 | model,
359 | dataloader=train_dataloader,
360 | loss_function=loss_function,
361 | optimizer=optimizer,
362 | scheduler=scheduler,
363 | scaler=scaler)
364 |
365 | print("Epoch: {}, Train Loss = {:.3f}, Lr = {:.6f}".format(epoch,
366 | train_loss,
367 | optimizer.param_groups[0]['lr']))
368 |
369 | # evaluate
370 | if (epoch % config.eval_every_n_epoch == 0 and epoch > 1) or epoch == config.epochs:
371 |
372 | print("\n{}[{}]{}".format(30 * "-", "Evaluate", 30 * "-"))
373 |
374 | r1_test = evaluate(config=config,
375 | model=model,
376 | query_loader=query_dataloader_test,
377 | gallery_loader=gallery_dataloader_test,
378 | ranks=[1, 5, 10],
379 | step_size=1000,
380 | cleanup=True)
381 |
382 | if r1_test > best_score:
383 |
384 | best_score = r1_test
385 |
386 | if torch.cuda.device_count() > 1 and len(config.gpu_ids) > 1:
387 | torch.save(model.module.state_dict(),
388 | '{}/weights_e{}_{:.4f}.pth'.format(model_path, epoch, r1_test))
389 | else:
390 | torch.save(model.state_dict(), '{}/weights_e{}_{:.4f}.pth'.format(model_path, epoch, r1_test))
391 | elif r1_test > 26.0:  # also keep checkpoints that clear a fixed R@1 threshold
392 | if torch.cuda.device_count() > 1 and len(config.gpu_ids) > 1:
393 | torch.save(model.module.state_dict(),
394 | '{}/weights_e{}_{:.4f}.pth'.format(model_path, epoch, r1_test))
395 | else:
396 | torch.save(model.state_dict(), '{}/weights_e{}_{:.4f}.pth'.format(model_path, epoch, r1_test))
397 |
398 | if config.custom_sampling:
399 | train_dataloader.dataset.shuffle()
400 |
401 | if torch.cuda.device_count() > 1 and len(config.gpu_ids) > 1:
402 | torch.save(model.module.state_dict(), '{}/weights_end.pth'.format(model_path))
403 | else:
404 | torch.save(model.state_dict(), '{}/weights_end.pth'.format(model_path))
405 |
--------------------------------------------------------------------------------