├── .gitignore ├── configs └── crowd_sense.json ├── dist_train.sh ├── eval.py ├── inference.py ├── misc ├── saver_builder.py └── tools.py ├── models ├── backbone.py ├── cropper.py ├── extractor.py ├── fuse.py ├── model.py ├── roi.py ├── transformer.py ├── tri_cropper.py └── tri_sim_ot_b.py ├── optimizer ├── __init__.py ├── optimizer_builder.py └── scheduler_builder.py ├── readme.md ├── requirements.txt ├── result └── video_results_test.json ├── statics └── imgs │ ├── c3.gif │ ├── c4.gif │ ├── p2.gif │ ├── p3.gif │ ├── pipeline.png │ ├── sim.mp4 │ └── tinywow_sim_58930123.gif ├── train.py ├── tri_dataset.py ├── tri_eingine.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # python gitignore 2 | *.pyc 3 | outputs/* 4 | datasets/* -------------------------------------------------------------------------------- /configs/crowd_sense.json: -------------------------------------------------------------------------------- 1 | { 2 | "Dataset": { 3 | "train": { 4 | "root": "datasets/sensecrowd/train", 5 | "ann_dir": "datasets/sensecrowd/new_annotations", 6 | "size_divisor": 32, 7 | "batch_size": 1, 8 | "num_workers": 8, 9 | "shuffle": true, 10 | "drop_last": true, 11 | "cache_mode": false, 12 | "max_len": 3000 13 | }, 14 | "val": { 15 | "root": "datasets/sensecrowd/val", 16 | "ann_dir": "datasets/sensecrowd/new_annotations", 17 | "size_divisor": 32, 18 | "batch_size": 1, 19 | "num_workers": 8, 20 | "shuffle": false, 21 | "drop_last": false, 22 | "cache_mode": false, 23 | "max_len": 3000 24 | }, 25 | "test": { 26 | "root": "datasets/sensecrowd/test", 27 | "ann_dir": "datasets/sensecrowd/new_annotations", 28 | "size_divisor": 32, 29 | "batch_size": 1, 30 | "num_workers": 8, 31 | "shuffle": false, 32 | "drop_last": false, 33 | "cache_mode": false, 34 | "max_len": 3000 35 | } 36 | }, 37 | "Optimizer": { 38 | "type": "AdamW", 39 | "lr": 0.0001, 40 | "betas": [ 41 | 0.9, 42 | 0.999 43 | ], 44 | "eps": 1e-08, 45 | "weight_decay": 0.000001 46 | }, 47 | "Scheduler": { 48 | "type": "cosine", 49 | "T_max": 200, 50 | "eta_min":0.000000001, 51 | "ema":false, 52 | "ema_annel_strategy": "cos", 53 | "ema_annel_epochs":10, 54 | "ema_lr":0.000000001, 55 | "ema_weight":0.9, 56 | "ema_start_epoch":90 57 | }, 58 | "Saver": { 59 | "save_dir": "./outputs", 60 | "save_interval": 1, 61 | "save_start_epoch": 0, 62 | "save_num_per_epoch": 2, 63 | "max_save_num": 20, 64 | "save_best": true, 65 | "metric":"all_match_acc", 66 | "reverse": true 67 | }, 68 | "Logger": { 69 | "delimiter": "\t", 70 | "print_freq": 25, 71 | "header": "" 72 | }, 73 | "Misc": { 74 | "epochs":201, 75 | "use_tensorboard": true, 76 | "tensorboard_dir": "./outputs", 77 | "clip_max_norm": 10, 78 | "val_freq":1 79 | }, 80 | "Drawer": { 81 | "draw_freq": 25, 82 | "output_dir": "./outputs", 83 | "draw_original": true, 84 | "draw_denseMap": true, 85 | "draw_output": true, 86 | "mean": [ 87 | 0.485, 88 | 0.456, 89 | 0.406 90 | ], 91 | "std": [ 92 | 0.229, 93 | 0.224, 94 | 0.225 95 | ] 96 | } 97 | } -------------------------------------------------------------------------------- /dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_PROC=$1 3 | shift 4 | torchrun --master_port 29505 --nproc_per_node=$NUM_PROC train.py -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import 
os 4 | with open("result/video_results_test.json","r") as f: 5 | video_results=json.load(f) 6 | 7 | anno_root="datasets/sensecrowd/new_annotations" 8 | 9 | 10 | gt_video_num_list=[] 11 | gt_video_len_list=[] 12 | pred_video_num_list=[] 13 | pred_matched_num_list=[] 14 | gt_matched_num_list=[] 15 | for video_name in video_results: 16 | video_len=0 17 | anno_path=os.path.join(anno_root,video_name+".txt") 18 | with open(anno_path,"r") as f: 19 | lines=f.readlines() 20 | all_ids=set() 21 | 22 | for line in lines: 23 | line=line.strip().split(" ") 24 | data=[float(x) for x in line[3:] if x!=""] 25 | if len(data)>0: 26 | data=np.array(data) 27 | data=np.reshape(data,(-1,7)) 28 | ids=data[:,6].reshape(-1,1) 29 | for id in ids: 30 | all_ids.add(int(id[0])) 31 | info=video_results[video_name] 32 | gt_video_num=len(all_ids) 33 | pred_video_num=info["video_num"] 34 | pred_video_num_list.append(pred_video_num) 35 | gt_video_num_list.append(gt_video_num) 36 | gt_video_len_list.append(info["frame_num"]) 37 | 38 | 39 | 40 | MAE=np.mean(np.abs(np.array(gt_video_num_list)-np.array(pred_video_num_list))) 41 | MSE=np.mean(np.square(np.array(gt_video_num_list)-np.array(pred_video_num_list))) 42 | WRAE=np.sum(np.abs(np.array(gt_video_num_list)-np.array(pred_video_num_list))*np.array(gt_video_len_list)/np.array(gt_video_num_list)/np.sum(gt_video_len_list)) 43 | RMSE=np.sqrt(MSE) 44 | 45 | print(f"MAE:{MAE:.2f}, MSE:{MSE:.2f}, WRAE:{WRAE*100:.2f}%, RMSE:{RMSE:.2f}") -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import shutil 5 | import time 6 | from asyncio.log import logger 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from easydict import EasyDict as edict 11 | from termcolor import cprint 12 | from torch.cuda.amp import GradScaler 13 | from torch.nn import SyncBatchNorm 14 | from torch.utils.data import DataLoader 15 | from torch.utils.data.distributed import DistributedSampler 16 | from torch.utils.tensorboard import SummaryWriter 17 | 18 | from misc import tools 19 | from misc.saver_builder import Saver 20 | from misc.tools import MetricLogger, is_main_process 21 | from models.tri_cropper import build_model 22 | from tri_dataset import build_video_dataset as build_dataset 23 | from tri_dataset import inverse_normalize 24 | # from eingine.densemap_trainer import evaluate_counting, train_one_epoch 25 | from tri_eingine import evaluate_similarity, train_one_epoch 26 | from models.tri_sim_ot_b import similarity_cost 27 | import numpy as np 28 | from scipy.optimize import linear_sum_assignment 29 | from tqdm import tqdm 30 | import cv2 31 | 32 | torch.backends.cudnn.enabled = True 33 | torch.backends.cudnn.benchmark = True 34 | 35 | 36 | def read_pts(path): 37 | with open(path, "r") as f: 38 | lines = f.readlines() 39 | pts = [] 40 | for line in lines: 41 | line = line.strip().split(",") 42 | pts.append([float(line[0]), float(line[1])]) 43 | pts = np.array(pts) 44 | return pts 45 | 46 | 47 | def module2model(module_state_dict): 48 | state_dict = {} 49 | for k, v in module_state_dict.items(): 50 | while k.startswith("module."): 51 | k = k[7:] 52 | # while apply ema model 53 | if k == "n_averaged": 54 | print(f"{k}:{v}") 55 | continue 56 | state_dict[k] = v 57 | return state_dict 58 | 59 | 60 | def main(pair_cfg, pair_ckpt): 61 | tools.init_distributed_mode(pair_cfg,) 62 | tools.set_randomseed(42 + tools.get_rank()) 
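    # Outline of the steps below: (1) build the similarity model and load the
    # checkpoint (module2model strips "module." prefixes and skips the EMA
    # "n_averaged" entry); (2) optionally wrap it in SyncBatchNorm + DDP and
    # build the test video loader; (3) for each video, sample one frame every
    # `interval` frames, read its pre-detected points from
    # locater/results/<video>/<frame>.txt, extract per-point features, match
    # them against the feature memory with the Hungarian algorithm (cosine
    # similarity thresholded at `threshold`), count unmatched detections as
    # newly appearing people, and drop memory entries whose TTL reaches zero.
    # Per-video results are finally dumped to video_results_test.json.
    # A standalone sketch of the matching step is appended at the end of this
    # listing.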
63 | # initilize the model 64 | model = model_without_ddp = build_model() 65 | model.load_state_dict(module2model(torch.load(pair_ckpt)["model"])) 66 | model.cuda() 67 | if pair_cfg.distributed: 68 | sync_model = SyncBatchNorm.convert_sync_batchnorm(model) 69 | model = torch.nn.parallel.DistributedDataParallel( 70 | sync_model, device_ids=[pair_cfg.gpu], find_unused_parameters=False) 71 | model_without_ddp = model.module 72 | 73 | # build the dataset and dataloader 74 | dataset = build_dataset(pair_cfg.Dataset.test.root, 75 | pair_cfg.Dataset.test.ann_dir) 76 | sampler = DistributedSampler( 77 | dataset, shuffle=False) if pair_cfg.distributed else None 78 | loader = DataLoader(dataset, 79 | batch_size=pair_cfg.Dataset.val.batch_size, 80 | sampler=sampler, 81 | shuffle=False, 82 | num_workers=pair_cfg.Dataset.val.num_workers, 83 | pin_memory=True) 84 | model.eval() 85 | video_results = {} 86 | interval = 15 87 | ttl = 5 88 | max_mem=5 89 | threshold=0.4 90 | with torch.no_grad(): 91 | for imgs, labels in tqdm(loader): 92 | cnt_list = [] 93 | video_name = labels["video_name"][0] 94 | img_names = labels["img_names"] 95 | w, h = labels["w"][0], labels["h"][0] 96 | 97 | img_name0 = img_names[0][0] 98 | pos_path0 = os.path.join( 99 | "locater/results", video_name, img_name0+".txt") 100 | print(pos_path0) 101 | pos0 = read_pts(pos_path0) 102 | z0 = model.forward_single_image( 103 | imgs[0, 0].cuda().unsqueeze(0), [pos0], True)[0] 104 | cnt_0 = len(pos0) 105 | cum_cnt = cnt_0 106 | cnt_list.append(cnt_0) 107 | selected_idx = [v for v in range( 108 | interval, len(img_names), interval)] 109 | pos_lists = [] 110 | inflow_lists = [] 111 | pos_lists.append(pos0) 112 | inflow_lists.append([1 for _ in range(len(pos0))]) 113 | memory_features = [[z0[i]] for i in range(len(pos0))] 114 | ttl_list=[ttl for _ in range(len(pos0))] 115 | # if selected_idx[-1] != len(img_names)-1: 116 | # selected_idx.append(len(img_names)-1) 117 | for i in selected_idx: 118 | img_name = img_names[i][0] 119 | pos_path = os.path.join( 120 | "locater/results", video_name, img_name+".txt") 121 | pos = read_pts(pos_path) 122 | z = model.forward_single_image( 123 | imgs[0, i].cuda().unsqueeze(0), [pos], True)[0] 124 | z = F.normalize(z, dim=-1) 125 | C = np.zeros((len(pos), len(memory_features))) 126 | for idx, pre_z in enumerate(memory_features): 127 | pre_z = torch.stack(pre_z[-1:], dim=0).unsqueeze(0) 128 | pre_z = F.normalize(pre_z, dim=-1) 129 | sim_cost = torch.bmm(pre_z, z.unsqueeze(0).transpose(1, 2)) 130 | # sim_cost=1-similarity_cost(pre_z,z.unsqueeze(0)) 131 | sim_cost = sim_cost.cpu().numpy()[0] 132 | sim_cost = np.min(sim_cost, axis=0) 133 | C[:, idx] = sim_cost 134 | 135 | row_ind, col_ind = linear_sum_assignment(-C) 136 | sim_score = C[row_ind, col_ind] 137 | shared_mask = sim_score > threshold 138 | ori_shared_idx_list = col_ind[shared_mask] 139 | new_shared_idx_list = row_ind[shared_mask] 140 | outflow_idx_list = [i for i in range(len(pos)) if i not in row_ind[shared_mask]] 141 | 142 | for ori_idx, new_idx in zip(ori_shared_idx_list, new_shared_idx_list): 143 | memory_features[ori_idx].append(z[new_idx]) 144 | ttl_list[ori_idx]=ttl 145 | for idx in outflow_idx_list: 146 | memory_features.append([z[idx]]) 147 | ttl_list.append(ttl) 148 | pos_lists.append(pos) 149 | inflow_list = [] 150 | for j in range(len(pos)): 151 | if j in outflow_idx_list: 152 | inflow_list.append(1) 153 | else: 154 | inflow_list.append(0) 155 | inflow_lists.append(inflow_list) 156 | cum_cnt += len(outflow_idx_list) 157 | 
cnt_list.append(len(outflow_idx_list)) 158 | ttl_list=[ttl_list[idx]-1 for idx in range(len(ttl_list))] 159 | for idx in range(len(ttl_list)-1,-1,-1): 160 | if ttl_list[idx]==0: 161 | del memory_features[idx] 162 | del ttl_list[idx] 163 | 164 | # conver numpy to list 165 | pos_lists = [pos_lists[i].tolist() for i in range(len(pos_lists))] 166 | 167 | video_results[video_name] = { 168 | "video_num": cum_cnt, 169 | "first_frame_num": cnt_0, 170 | "cnt_list": cnt_list, 171 | "frame_num": len(img_names), 172 | "pos_lists": pos_lists, 173 | "inflow_lists": inflow_lists, 174 | 175 | } 176 | print(video_name, video_results[video_name]) 177 | with open("video_results_test.json", "w") as f: 178 | json.dump(video_results, f, indent=4) 179 | 180 | 181 | if __name__ == "__main__": 182 | parser = argparse.ArgumentParser("DenseMap Head ") 183 | parser.add_argument("--pair_config", default="configs/crowd_sense.json") 184 | parser.add_argument( 185 | "--pair_ckpt", default="outputs/weights/best.pth") 186 | 187 | parser.add_argument("--local_rank", type=int) 188 | args = parser.parse_args() 189 | 190 | if os.path.exists(args.pair_config): 191 | with open(args.pair_config, "r") as f: 192 | pair_configs = json.load(f) 193 | pair_cfg = edict(pair_configs) 194 | 195 | strtime = time.strftime('%Y%m%d%H%M') + "_" + os.path.basename( 196 | args.pair_config)[:-5] 197 | 198 | output_path = os.path.join(pair_cfg.Misc.tensorboard_dir, strtime) 199 | 200 | pair_cfg.Misc.tensorboard_dir = output_path 201 | pair_ckpt = args.pair_ckpt 202 | main(pair_cfg, pair_ckpt) 203 | -------------------------------------------------------------------------------- /misc/saver_builder.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import os 4 | from .tools import is_main_process 5 | class Saver(): 6 | def __init__(self, args) -> None: 7 | self.save_dir=args.save_dir 8 | self.save_interval=args.save_interval 9 | if not os.path.exists(self.save_dir): 10 | os.makedirs(self.save_dir,exist_ok=True) 11 | self.save_best=args.save_best 12 | self.save_start_epoch = args.save_start_epoch 13 | self.min_value=0 14 | self.max_value=1e10 15 | self.reverse=args.reverse 16 | self.metric=args.metric 17 | 18 | def save(self, model, optimizer, scheduler, filename, epoch, stats={}): 19 | if is_main_process(): 20 | torch.save({ 21 | 'model': model.state_dict(), 22 | 'optimizer': optimizer.state_dict(), 23 | 'scheduler': scheduler.state_dict(), 24 | 'epoch': epoch, 25 | 'states':stats 26 | },os.path.join(self.save_dir, filename)) 27 | 28 | def save_inter(self, model, optimizer, scheduler, name, epoch, stats={}): 29 | if epoch % self.save_interval == 0 and self.save_start_epoch <= epoch: 30 | self.save(model, optimizer, scheduler, name, epoch, stats) 31 | 32 | def save_on_master(self, model, optimizer, scheduler, epoch, stats={}): 33 | if is_main_process(): 34 | self.save_inter(model, optimizer, scheduler, f"checkpoint{epoch:04}.pth", epoch, stats) 35 | 36 | if self.save_best and self.save_start_epoch <= epoch: 37 | if self.reverse and stats.test_stats[self.metric] > self.min_value: 38 | self.min_value=max(self.min_value,stats.test_stats[self.metric]) 39 | self.save(model, optimizer, scheduler, f"best.pth", epoch, stats) 40 | elif not self.reverse and stats.test_stats[self.metric] < self.max_value: 41 | self.max_value=min(self.max_value,stats.test_stats[self.metric]) 42 | self.save(model, optimizer, scheduler, f"best.pth", epoch, stats) 43 | 44 | def save_last(self, model, optimizer, scheduler, 
epoch, stats={}): 45 | if is_main_process(): 46 | self.save(model, optimizer, scheduler, f"checkpoint_last.pth", epoch,stats) -------------------------------------------------------------------------------- /misc/tools.py: -------------------------------------------------------------------------------- 1 | 2 | import datetime 3 | import os 4 | import pickle 5 | import random 6 | import subprocess 7 | import time 8 | from collections import defaultdict, deque 9 | from typing import List, Optional 10 | 11 | import numpy as np 12 | import torch 13 | import torch.distributed as dist 14 | import torch.nn as nn 15 | # needed due to empty tensor bug in pytorch and torchvision 0.5 16 | import torchvision 17 | from torch import Tensor 18 | 19 | def set_randomseed(seed): 20 | torch.manual_seed(seed) 21 | np.random.seed(seed) 22 | random.seed(seed) 23 | 24 | 25 | class SmoothedValue(object): 26 | """Track a series of values and provide access to smoothed values over a 27 | window or the global series average. 28 | """ 29 | 30 | def __init__(self, window_size=20, fmt=None): 31 | if fmt is None: 32 | fmt = "{median:.4f} ({global_avg:.4f})" 33 | self.deque = deque(maxlen=window_size) 34 | self.total = 0.0 35 | self.count = 0 36 | self.fmt = fmt 37 | 38 | def update(self, value, n=1): 39 | self.deque.append(value) 40 | self.count += n 41 | self.total += value * n 42 | 43 | def synchronize_between_processes(self): 44 | """ 45 | Warning: does not synchronize the deque! 46 | """ 47 | if not is_dist_avail_and_initialized(): 48 | return 49 | t = torch.tensor([self.count, self.total], 50 | dtype=torch.float64, device='cuda') 51 | dist.barrier() 52 | dist.all_reduce(t) 53 | t = t.tolist() 54 | try: 55 | self.count = int(t[0]) 56 | self.total = t[1] 57 | except: 58 | print(f"Warning: {self.count} {self.total} is not an integer") 59 | 60 | self.count = 1 61 | self.total = -1 62 | 63 | @property 64 | def median(self): 65 | d = torch.tensor(list(self.deque)) 66 | return d.median().item() 67 | 68 | @property 69 | def avg(self): 70 | d = torch.tensor(list(self.deque), dtype=torch.float32) 71 | return d.mean().item() 72 | 73 | @property 74 | def global_avg(self): 75 | return self.total / self.count 76 | @property 77 | def max(self): 78 | return max(self.deque) 79 | 80 | @property 81 | def value(self): 82 | return self.deque[-1] 83 | 84 | def __str__(self): 85 | return self.fmt.format( 86 | median=self.median, 87 | avg=self.avg, 88 | global_avg=self.global_avg, 89 | max=self.max, 90 | value=self.value) 91 | 92 | 93 | def all_gather(data): 94 | """ 95 | Run all_gather on arbitrary picklable data (not necessarily tensors) 96 | Args: 97 | data: any picklable object 98 | Returns: 99 | list[data]: list of data gathered from each rank 100 | """ 101 | world_size = get_world_size() 102 | if world_size == 1: 103 | return [data] 104 | 105 | # serialized to a Tensor 106 | buffer = pickle.dumps(data) 107 | storage = torch.ByteStorage.from_buffer(buffer) 108 | tensor = torch.ByteTensor(storage).to("cuda") 109 | 110 | # obtain Tensor size of each rank 111 | local_size = torch.tensor([tensor.numel()], device="cuda") 112 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 113 | dist.all_gather(size_list, local_size) 114 | size_list = [int(size.item()) for size in size_list] 115 | max_size = max(size_list) 116 | 117 | # receiving Tensor from all ranks 118 | # we pad the tensor because torch all_gather does not support 119 | # gathering tensors of different shapes 120 | tensor_list = [] 121 | for _ in size_list: 
122 | tensor_list.append(torch.empty( 123 | (max_size,), dtype=torch.uint8, device="cuda")) 124 | if local_size != max_size: 125 | padding = torch.empty(size=(max_size - local_size,), 126 | dtype=torch.uint8, device="cuda") 127 | tensor = torch.cat((tensor, padding), dim=0) 128 | dist.all_gather(tensor_list, tensor) 129 | 130 | data_list = [] 131 | for size, tensor in zip(size_list, tensor_list): 132 | buffer = tensor.cpu().numpy().tobytes()[:size] 133 | data_list.append(pickle.loads(buffer)) 134 | 135 | return data_list 136 | 137 | 138 | def reduce_dict(input_dict, average=True): 139 | """ 140 | Args: 141 | input_dict (dict): all the values will be reduced 142 | average (bool): whether to do average or sum 143 | Reduce the values in the dictionary from all processes so that all processes 144 | have the averaged results. Returns a dict with the same fields as 145 | input_dict, after reduction. 146 | """ 147 | world_size = get_world_size() 148 | if world_size < 2: 149 | return input_dict 150 | with torch.no_grad(): 151 | names = [] 152 | values = [] 153 | # sort the keys so that they are consistent across processes 154 | for k in sorted(input_dict.keys()): 155 | names.append(k) 156 | values.append(input_dict[k]) 157 | values = torch.stack(values, dim=0) 158 | dist.all_reduce(values) 159 | if average: 160 | values /= world_size 161 | reduced_dict = {k: v for k, v in zip(names, values)} 162 | return reduced_dict 163 | 164 | 165 | class MetricLogger(object): 166 | def __init__(self, args): 167 | self.meters = defaultdict(SmoothedValue) 168 | self.delimiter = args.delimiter 169 | self.print_freq = args.print_freq 170 | self.header = args.header 171 | 172 | def update(self, **kwargs): 173 | for k, v in kwargs.items(): 174 | if isinstance(v, torch.Tensor): 175 | v = v.item() 176 | assert isinstance(v, (float, int)) 177 | self.meters[k].update(v) 178 | 179 | def __getattr__(self, attr): 180 | if attr in self.meters: 181 | return self.meters[attr] 182 | if attr in self.__dict__: 183 | return self.__dict__[attr] 184 | raise AttributeError("'{}' object has no attribute '{}'".format( 185 | type(self).__name__, attr)) 186 | 187 | def __str__(self): 188 | loss_str = [] 189 | for name, meter in self.meters.items(): 190 | loss_str.append( 191 | "{}: {}".format(name, str(meter)) 192 | ) 193 | return self.delimiter.join(loss_str) 194 | 195 | def synchronize_between_processes(self): 196 | for meter in self.meters.values(): 197 | meter.synchronize_between_processes() 198 | 199 | def add_meter(self, name, meter): 200 | self.meters[name] = meter 201 | 202 | def set_header(self, header): 203 | self.header = header 204 | 205 | def log_every(self, iterable): 206 | i = 1 207 | 208 | start_time = time.time() 209 | end = time.time() 210 | iter_time = SmoothedValue(fmt='{avg:.4f}') 211 | data_time = SmoothedValue(fmt='{avg:.4f}') 212 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 213 | if torch.cuda.is_available(): 214 | log_msg = self.delimiter.join([ 215 | self.header, 216 | '[{0' + space_fmt + '}/{1}]', 217 | 'eta: {eta}', 218 | '{meters}', 219 | 'time: {time}', 220 | 'data: {data}', 221 | 'max mem: {memory:.0f}' 222 | ]) 223 | else: 224 | log_msg = self.delimiter.join([ 225 | self.header, 226 | '[{0' + space_fmt + '}/{1}]', 227 | 'eta: {eta}', 228 | '{meters}', 229 | 'time: {time}', 230 | 'data: {data}' 231 | ]) 232 | MB = 1024.0 * 1024.0 233 | for obj in iterable: 234 | data_time.update(time.time() - end) 235 | yield obj 236 | iter_time.update(time.time() - end) 237 | if i % self.print_freq == 0 or i == 
len(iterable) - 1: 238 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 239 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 240 | if torch.cuda.is_available(): 241 | print(log_msg.format( 242 | i, len(iterable), eta=eta_string, 243 | meters=str(self), 244 | time=str(iter_time), data=str(data_time), 245 | memory=torch.cuda.max_memory_allocated() / MB)) 246 | else: 247 | print(log_msg.format( 248 | i, len(iterable), eta=eta_string, 249 | meters=str(self), 250 | time=str(iter_time), data=str(data_time))) 251 | i += 1 252 | end = time.time() 253 | total_time = time.time() - start_time 254 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 255 | print('{} Total time: {} ({:.4f} s / it)'.format( 256 | self.header, total_time_str, total_time / len(iterable))) 257 | 258 | 259 | def get_sha(): 260 | cwd = os.path.dirname(os.path.abspath(__file__)) 261 | 262 | def _run(command): 263 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 264 | sha = 'N/A' 265 | diff = "clean" 266 | branch = 'N/A' 267 | try: 268 | sha = _run(['git', 'rev-parse', 'HEAD']) 269 | subprocess.check_output(['git', 'diff'], cwd=cwd) 270 | diff = _run(['git', 'diff-index', 'HEAD']) 271 | diff = "has uncommited changes" if diff else "clean" 272 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 273 | except Exception: 274 | pass 275 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 276 | return message 277 | 278 | 279 | def collate_fn(batch): 280 | batch = list(zip(*batch)) 281 | batch[0] = nested_tensor_from_tensor_list(batch[0]) 282 | return tuple(batch) 283 | 284 | 285 | def _max_by_axis(the_list): 286 | # type: (List[List[int]]) -> List[int] 287 | maxes = the_list[0] 288 | for sublist in the_list[1:]: 289 | for index, item in enumerate(sublist): 290 | maxes[index] = max(maxes[index], item) 291 | return maxes 292 | 293 | 294 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 295 | # TODO make this more general 296 | if tensor_list[0].ndim == 3: 297 | # TODO make it support different-sized images 298 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 299 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 300 | batch_shape = [len(tensor_list)] + max_size 301 | b, c, h, w = batch_shape 302 | dtype = tensor_list[0].dtype 303 | device = tensor_list[0].device 304 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 305 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 306 | for img, pad_img, m in zip(tensor_list, tensor, mask): 307 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 308 | m[: img.shape[1], :img.shape[2]] = False 309 | else: 310 | raise ValueError('not supported') 311 | return NestedTensor(tensor, mask) 312 | 313 | 314 | class NestedTensor(object): 315 | def __init__(self, tensors, mask: Optional[Tensor]): 316 | self.tensors = tensors 317 | self.mask = mask 318 | 319 | def to(self, device, non_blocking=False): 320 | ## type: (Device) -> NestedTensor # noqa 321 | cast_tensor = self.tensors.to(device, non_blocking=non_blocking) 322 | mask = self.mask 323 | if mask is not None: 324 | assert mask is not None 325 | cast_mask = mask.to(device, non_blocking=non_blocking) 326 | else: 327 | cast_mask = None 328 | return NestedTensor(cast_tensor, cast_mask) 329 | 330 | def record_stream(self, *args, **kwargs): 331 | self.tensors.record_stream(*args, **kwargs) 332 | if self.mask is not None: 333 | self.mask.record_stream(*args, **kwargs) 334 
| 335 | def decompose(self): 336 | return self.tensors, self.mask 337 | 338 | def __repr__(self): 339 | return str(self.tensors) 340 | 341 | 342 | def setup_for_distributed(is_master): 343 | """ 344 | This function disables printing when not in master process 345 | """ 346 | import builtins as __builtin__ 347 | builtin_print = __builtin__.print 348 | 349 | def print(*args, **kwargs): 350 | force = kwargs.pop('force', False) 351 | if is_master or force: 352 | builtin_print(*args, **kwargs) 353 | 354 | __builtin__.print = print 355 | 356 | 357 | def is_dist_avail_and_initialized(): 358 | if not dist.is_available(): 359 | return False 360 | if not dist.is_initialized(): 361 | return False 362 | return True 363 | 364 | 365 | def get_world_size(): 366 | if not is_dist_avail_and_initialized(): 367 | return 1 368 | return dist.get_world_size() 369 | 370 | 371 | def get_rank(): 372 | if not is_dist_avail_and_initialized(): 373 | return 0 374 | return dist.get_rank() 375 | 376 | 377 | def get_local_size(): 378 | if not is_dist_avail_and_initialized(): 379 | return 1 380 | return int(os.environ['LOCAL_SIZE']) 381 | 382 | 383 | def get_local_rank(): 384 | if not is_dist_avail_and_initialized(): 385 | return 0 386 | return int(os.environ['LOCAL_RANK']) 387 | 388 | 389 | def is_main_process(): 390 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 391 | return int(os.environ['RANK']) == 0 392 | elif "SLURM_PROCID" in os.environ: 393 | return int(os.environ["SLURM_PROCID"]) == 0 394 | else: 395 | return True 396 | 397 | 398 | def save_on_master(*args, **kwargs): 399 | if is_main_process(): 400 | torch.save(*args, **kwargs) 401 | 402 | 403 | def init_distributedrun_mode(args): 404 | if "LOCAL_RANK" in os.environ: 405 | args.local_rank = int(os.environ["LOCAL_RANK"]) 406 | args.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) 407 | 408 | 409 | def init_distributed_mode(args): 410 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 411 | args.rank = int(os.environ["RANK"]) 412 | args.world_size = int(os.environ['WORLD_SIZE']) 413 | args.gpu = int(os.environ['LOCAL_RANK']) 414 | args.dist_url = 'env://' 415 | os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count()) 416 | elif 'SLURM_PROCID' in os.environ: 417 | proc_id = int(os.environ['SLURM_PROCID']) 418 | ntasks = int(os.environ['SLURM_NTASKS']) 419 | node_list = os.environ['SLURM_NODELIST'] 420 | num_gpus = torch.cuda.device_count() 421 | addr = subprocess.getoutput( 422 | 'scontrol show hostname {} | head -n1'.format(node_list)) 423 | os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500') 424 | os.environ['MASTER_ADDR'] = addr 425 | os.environ['WORLD_SIZE'] = str(ntasks) 426 | os.environ['RANK'] = str(proc_id) 427 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 428 | os.environ['LOCAL_SIZE'] = str(num_gpus) 429 | args.dist_url = 'env://' 430 | args.world_size = ntasks 431 | args.rank = proc_id 432 | args.gpu = proc_id % num_gpus 433 | else: 434 | print('Not using distributed mode') 435 | args.distributed = False 436 | args.gpu=0 437 | return 438 | 439 | args.distributed = True 440 | 441 | torch.cuda.set_device(args.gpu) 442 | args.dist_backend = 'nccl' 443 | print('| distributed init (rank {}): {}'.format( 444 | args.rank, args.dist_url), flush=True) 445 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 446 | world_size=args.world_size, rank=args.rank) 447 | torch.distributed.barrier() 448 | setup_for_distributed(args.rank == 0) 449 | 450 | 451 | @torch.no_grad() 452 | def 
accuracy(output, target, topk=(1,)): 453 | """Computes the precision@k for the specified values of k""" 454 | if target.numel() == 0: 455 | return [torch.zeros([], device=output.device)] 456 | maxk = max(topk) 457 | batch_size = target.size(0) 458 | 459 | _, pred = output.topk(maxk, 1, True, True) 460 | pred = pred.t() 461 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 462 | 463 | res = [] 464 | for k in topk: 465 | correct_k = correct[:k].view(-1).float().sum(0) 466 | res.append(correct_k.mul_(100.0 / batch_size)) 467 | return res 468 | 469 | 470 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): 471 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor 472 | """ 473 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes. 474 | This will eventually be supported natively by PyTorch, and this 475 | class can go away. 476 | """ 477 | # if float(torchvision.__version__[:3]) < 0.7: 478 | # if input.numel() > 0: 479 | # return torch.nn.functional.interpolate( 480 | # input, size, scale_factor, mode, align_corners 481 | # ) 482 | 483 | # output_shape = _output_size(2, input, size, scale_factor) 484 | # output_shape = list(input.shape[:-2]) + list(output_shape) 485 | # if float(torchvision.__version__[:3]) < 0.5: 486 | # return _NewEmptyTensorOp.apply(input, output_shape) 487 | # return _new_empty_tensor(input, output_shape) 488 | # else: 489 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) 490 | 491 | 492 | def get_total_grad_norm(parameters, norm_type=2): 493 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 494 | norm_type = float(norm_type) 495 | device = parameters[0].grad.device 496 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), 497 | norm_type) 498 | return total_norm 499 | 500 | 501 | def inverse_sigmoid(x, eps=1e-5): 502 | x = x.clamp(min=0, max=1) 503 | x1 = x.clamp(min=eps) 504 | x2 = (1 - x).clamp(min=eps) 505 | return torch.log(x1/x2) 506 | -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torchvision.models._utils import IntermediateLayerGetter 5 | from typing import Dict, List 6 | from timm import create_model 7 | from timm.models import features 8 | import torch 9 | 10 | class Backbone(nn.Module): 11 | def __init__(self, name: str,pretrained:bool,out_indices:List[int], train_backbone: bool): 12 | super(Backbone,self).__init__() 13 | backbone=create_model(name,pretrained=pretrained,features_only=True, out_indices=out_indices) 14 | self.train_backbone = train_backbone 15 | self.backbone=backbone 16 | self.out_indices=out_indices 17 | if not self.train_backbone: 18 | for name, parameter in self.backbone.named_parameters(): 19 | parameter.requires_grad_(False) 20 | def forward(self,x): 21 | x=self.backbone(x) 22 | for i in range(len(x)): 23 | x[i]=F.relu(x[i]) 24 | 25 | return x 26 | 27 | @property 28 | def feature_info(self): 29 | return features._get_feature_info(self.backbone,out_indices=self.out_indices) 30 | 31 | 32 | def build_backbone(): 33 | backbone = Backbone("convnext_small_384_in22ft1k", True, [3], True) 34 | # backbone = Backbone("convnext_small_384_in22ft1k", True, [2], True) 35 | 36 | # backbone = 
Backbone("resnet50.a3_in1k", True, [2], True) 37 | 38 | return backbone 39 | 40 | if __name__=="__main__": 41 | model=build_backbone() 42 | x=torch.randn(1,3,224,224) 43 | z=model(x) 44 | for f in z: 45 | print(f.shape) 46 | -------------------------------------------------------------------------------- /models/cropper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from .backbone import build_backbone 5 | from .tri_sim_ot_b import GML 6 | import math 7 | import copy 8 | from typing import Optional, List, Dict, Tuple, Set, Union, Iterable, Any 9 | from torch import Tensor 10 | 11 | 12 | class Model(nn.Module): 13 | def __init__(self, stride=16, num_feature_levels=3, num_channels=[96, 192, 384], hidden_dim=256, freeze_backbone=False) -> None: 14 | super().__init__() 15 | self.backbone = build_backbone() 16 | if freeze_backbone: 17 | for name, parameter in self.backbone.named_parameters(): 18 | parameter.requires_grad_(False) 19 | self.stride = stride 20 | self.num_feature_levels = num_feature_levels 21 | self.num_channels = num_channels 22 | input_proj_list = [] 23 | 24 | self.ot_loss = GML() 25 | 26 | def forward(self, input): 27 | x = input["image_pair"] 28 | ref_points = input["ref_pts"][:, :, :input["ref_num"], :] 29 | x1 = x[:, 0:3, :, :] 30 | x2 = x[:, 3:6, :, :] 31 | ref_point1 = ref_points[:, 0, ...] 32 | ref_point2 = ref_points[:, 1, ...] 33 | 34 | z1_lists=[] 35 | z2_lists=[] 36 | for b in range(x1.shape[0]): 37 | z1_list=[] 38 | z2_list=[] 39 | for pt1,pt2 in zip(ref_point1[b],ref_point2[b]): 40 | z1=self.get_crops(x1[b].unsqueeze(0),pt1) 41 | z2=self.get_crops(x2[b].unsqueeze(0),pt2) 42 | z1=F.interpolate(z1,(224,224)) 43 | z2=F.interpolate(z2,(224,224)) 44 | z1_list.append(z1) 45 | z2_list.append(z2) 46 | z1=torch.cat(z1_list,dim=0) 47 | z2=torch.cat(z2_list,dim=0) 48 | z1=self.backbone(z1)[0].flatten(2).flatten(1) 49 | z2=self.backbone(z2)[0].flatten(2).flatten(1) 50 | z1_lists.append(z1) 51 | z2_lists.append(z2) 52 | z1=torch.stack(z1_lists,dim=0) 53 | z2=torch.stack(z2_lists,dim=0) 54 | return z1,z2 55 | 56 | def get_crops(self, z, pt, window_size=64): 57 | h, w = z.shape[-2], z.shape[-1] 58 | x_min = pt[0]*w-window_size//2 59 | x_max = pt[0]*w+window_size//2 60 | y_min = pt[1]*h-window_size//2 61 | y_max = pt[1]*h+window_size//2 62 | x_min, x_max, y_min, y_max = int(x_min), int( 63 | x_max), int(y_min), int(y_max) 64 | x_min = max(0, x_min) 65 | x_max = min(w, x_max) 66 | y_min = max(0, y_min) 67 | y_max = min(h, y_max) 68 | z = z[:, :, y_min:y_max, x_min:x_max] 69 | # pos_emb = self.pos2posemb2d(pt, num_pos_feats=z.shape[1]//2).unsqueeze(0).unsqueeze(2).unsqueeze(3) 70 | # z = z + pos_emb 71 | # z=F.adaptive_avg_pool2d(z,(1,1)) 72 | # z=z.squeeze(3).squeeze(2) 73 | return z 74 | 75 | def loss(self, features1, features2): 76 | loss = self.ot_loss(features1, features2) 77 | loss_dict = {} 78 | loss_dict["all"] = loss 79 | return loss_dict 80 | 81 | 82 | def build_model(): 83 | model = Model() 84 | return model 85 | -------------------------------------------------------------------------------- /models/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from .backbone import build_backbone 5 | from .sim_ot import ot_similarity 6 | import math 7 | import copy 8 | from typing import Optional, List, Dict, Tuple, Set, Union, 
Iterable, Any 9 | from torch import Tensor 10 | 11 | def pos2posemb2d(pos, num_pos_feats=128, temperature=1000): 12 | scale = 2 * math.pi 13 | pos = pos * scale 14 | dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device) 15 | dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) 16 | pos_x = pos[..., 0, None] / dim_t 17 | pos_y = pos[..., 1, None] / dim_t 18 | pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2) 19 | pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2) 20 | posemb = torch.cat((pos_y, pos_x), dim=-1) 21 | return posemb 22 | 23 | class PositionEmbeddingSine(nn.Module): 24 | """ 25 | This is a more standard version of the position embedding, very similar to the one 26 | used by the Attention is all you need paper, generalized to work on images. 27 | """ 28 | def __init__(self, num_pos_feats=128, temperature=1000, normalize=False, scale=None): 29 | super().__init__() 30 | self.num_pos_feats = num_pos_feats 31 | self.temperature = temperature 32 | self.normalize = normalize 33 | if scale is not None and normalize is False: 34 | raise ValueError("normalize should be True if scale is passed") 35 | if scale is None: 36 | scale = 2 * math.pi 37 | self.scale = scale 38 | 39 | def forward(self, x): 40 | mask = torch.ones(x.shape[0],x.shape[1],x.shape[2]).cuda().bool() 41 | assert mask is not None 42 | not_mask = ~mask 43 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 44 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 45 | if self.normalize: 46 | eps = 1e-6 47 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 48 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 49 | 50 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 51 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 52 | 53 | pos_x = x_embed[:, :, :, None] / dim_t 54 | pos_y = y_embed[:, :, :, None] / dim_t 55 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 56 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 57 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 58 | return pos 59 | 60 | def _get_activation_fn(activation): 61 | """Return an activation function given a string""" 62 | if activation == "relu": 63 | return F.relu 64 | if activation == "gelu": 65 | return F.gelu 66 | if activation == "glu": 67 | return F.glu 68 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 69 | 70 | def _get_clones(module, N): 71 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 72 | 73 | 74 | class FFN(nn.Module): 75 | 76 | def __init__(self, d_model=256, d_ffn=1024, dropout=0.): 77 | super().__init__() 78 | self.linear1 = nn.Linear(d_model, d_ffn) 79 | self.activation = nn.ReLU() 80 | self.dropout2 = nn.Dropout(dropout) 81 | self.linear2 = nn.Linear(d_ffn, d_model) 82 | self.dropout3 = nn.Dropout(dropout) 83 | self.norm2 = nn.LayerNorm(d_model) 84 | 85 | def forward(self, src): 86 | src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) 87 | src = src + self.dropout3(src2) 88 | src = self.norm2(src) 89 | return src 90 | 91 | class Transformer(nn.Module): 92 | 93 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 94 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 95 | activation="relu", normalize_before=False, 96 | return_intermediate_dec=False): 97 | super().__init__() 98 | 99 | 
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 100 | dropout, activation, normalize_before) 101 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 102 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 103 | 104 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 105 | dropout, activation, normalize_before) 106 | decoder_norm = nn.LayerNorm(d_model) 107 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 108 | return_intermediate=return_intermediate_dec) 109 | 110 | self._reset_parameters() 111 | 112 | self.d_model = d_model 113 | self.nhead = nhead 114 | 115 | def _reset_parameters(self): 116 | for p in self.parameters(): 117 | if p.dim() > 1: 118 | nn.init.xavier_uniform_(p) 119 | 120 | def forward(self, src, mask, query_embed, pos_embed): 121 | # flatten NxCxHxW to HWxNxC 122 | bs, c, h, w = src.shape 123 | src = src.flatten(2).permute(2, 0, 1) 124 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 125 | mask = mask.flatten(1) 126 | print(src.dtype) 127 | tgt = torch.zeros_like(query_embed) 128 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 129 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, 130 | pos=pos_embed, query_pos=query_embed) 131 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) 132 | 133 | 134 | class TransformerEncoder(nn.Module): 135 | 136 | def __init__(self, encoder_layer, num_layers, norm=None): 137 | super().__init__() 138 | self.layers = _get_clones(encoder_layer, num_layers) 139 | self.num_layers = num_layers 140 | self.norm = norm 141 | 142 | def forward(self, src, 143 | mask: Optional[Tensor] = None, 144 | src_key_padding_mask: Optional[Tensor] = None, 145 | pos: Optional[Tensor] = None): 146 | output = src 147 | 148 | for layer in self.layers: 149 | output = layer(output, src_mask=mask, 150 | src_key_padding_mask=src_key_padding_mask, pos=pos) 151 | 152 | if self.norm is not None: 153 | output = self.norm(output) 154 | 155 | return output 156 | 157 | 158 | class TransformerDecoder(nn.Module): 159 | 160 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 161 | super().__init__() 162 | self.layers = _get_clones(decoder_layer, num_layers) 163 | self.num_layers = num_layers 164 | self.norm = norm 165 | self.return_intermediate = return_intermediate 166 | 167 | def forward(self, tgt, memory, 168 | tgt_mask: Optional[Tensor] = None, 169 | memory_mask: Optional[Tensor] = None, 170 | tgt_key_padding_mask: Optional[Tensor] = None, 171 | memory_key_padding_mask: Optional[Tensor] = None, 172 | pos: Optional[Tensor] = None, 173 | query_pos: Optional[Tensor] = None): 174 | output = tgt 175 | 176 | intermediate = [] 177 | 178 | for layer in self.layers: 179 | output = layer(output, memory, tgt_mask=tgt_mask, 180 | memory_mask=memory_mask, 181 | tgt_key_padding_mask=tgt_key_padding_mask, 182 | memory_key_padding_mask=memory_key_padding_mask, 183 | pos=pos, query_pos=query_pos) 184 | if self.return_intermediate: 185 | intermediate.append(self.norm(output)) 186 | 187 | if self.norm is not None: 188 | output = self.norm(output) 189 | if self.return_intermediate: 190 | intermediate.pop() 191 | intermediate.append(output) 192 | 193 | if self.return_intermediate: 194 | return torch.stack(intermediate) 195 | 196 | return output.unsqueeze(0) 197 | 198 | 199 | class TransformerEncoderLayer(nn.Module): 200 | 201 | def __init__(self, d_model, nhead, 
dim_feedforward=2048, dropout=0.1, 202 | activation="relu", normalize_before=False): 203 | super().__init__() 204 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 205 | # Implementation of Feedforward model 206 | self.linear1 = nn.Linear(d_model, dim_feedforward) 207 | self.dropout = nn.Dropout(dropout) 208 | self.linear2 = nn.Linear(dim_feedforward, d_model) 209 | 210 | self.norm1 = nn.LayerNorm(d_model) 211 | self.norm2 = nn.LayerNorm(d_model) 212 | self.dropout1 = nn.Dropout(dropout) 213 | self.dropout2 = nn.Dropout(dropout) 214 | 215 | self.activation = _get_activation_fn(activation) 216 | self.normalize_before = normalize_before 217 | 218 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 219 | return tensor if pos is None else tensor + pos 220 | 221 | def forward_post(self, 222 | src, 223 | src_mask: Optional[Tensor] = None, 224 | src_key_padding_mask: Optional[Tensor] = None, 225 | pos: Optional[Tensor] = None): 226 | print(src.shape,pos.shape) 227 | q = k = self.with_pos_embed(src, pos) 228 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 229 | key_padding_mask=src_key_padding_mask)[0] 230 | src = src + self.dropout1(src2) 231 | src = self.norm1(src) 232 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 233 | src = src + self.dropout2(src2) 234 | src = self.norm2(src) 235 | return src 236 | 237 | def forward_pre(self, src, 238 | src_mask: Optional[Tensor] = None, 239 | src_key_padding_mask: Optional[Tensor] = None, 240 | pos: Optional[Tensor] = None): 241 | src2 = self.norm1(src) 242 | q = k = self.with_pos_embed(src2, pos) 243 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 244 | key_padding_mask=src_key_padding_mask)[0] 245 | src = src + self.dropout1(src2) 246 | src2 = self.norm2(src) 247 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 248 | src = src + self.dropout2(src2) 249 | return src 250 | 251 | def forward(self, src, 252 | src_mask: Optional[Tensor] = None, 253 | src_key_padding_mask: Optional[Tensor] = None, 254 | pos: Optional[Tensor] = None): 255 | if self.normalize_before: 256 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 257 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 258 | 259 | 260 | class TransformerDecoderLayer(nn.Module): 261 | 262 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 263 | activation="relu", normalize_before=False): 264 | super().__init__() 265 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 266 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 267 | # Implementation of Feedforward model 268 | self.linear1 = nn.Linear(d_model, dim_feedforward) 269 | self.dropout = nn.Dropout(dropout) 270 | self.linear2 = nn.Linear(dim_feedforward, d_model) 271 | 272 | self.norm1 = nn.LayerNorm(d_model) 273 | self.norm2 = nn.LayerNorm(d_model) 274 | self.norm3 = nn.LayerNorm(d_model) 275 | self.dropout1 = nn.Dropout(dropout) 276 | self.dropout2 = nn.Dropout(dropout) 277 | self.dropout3 = nn.Dropout(dropout) 278 | 279 | self.activation = _get_activation_fn(activation) 280 | self.normalize_before = normalize_before 281 | 282 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 283 | return tensor if pos is None else tensor + pos 284 | 285 | def forward_post(self, tgt, memory, 286 | tgt_mask: Optional[Tensor] = None, 287 | memory_mask: Optional[Tensor] = None, 288 | tgt_key_padding_mask: Optional[Tensor] = None, 289 | 
memory_key_padding_mask: Optional[Tensor] = None, 290 | pos: Optional[Tensor] = None, 291 | query_pos: Optional[Tensor] = None): 292 | q = k = self.with_pos_embed(tgt, query_pos) 293 | print(q.dtype,k.dtype,tgt.dtype) 294 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 295 | key_padding_mask=tgt_key_padding_mask)[0] 296 | tgt = tgt + self.dropout1(tgt2) 297 | tgt = self.norm1(tgt) 298 | print(tgt.shape,query_pos.shape,memory.shape,pos.shape) 299 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 300 | key=self.with_pos_embed(memory, pos), 301 | value=memory, attn_mask=memory_mask, 302 | key_padding_mask=memory_key_padding_mask)[0] 303 | tgt = tgt + self.dropout2(tgt2) 304 | tgt = self.norm2(tgt) 305 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 306 | tgt = tgt + self.dropout3(tgt2) 307 | tgt = self.norm3(tgt) 308 | return tgt 309 | 310 | def forward_pre(self, tgt, memory, 311 | tgt_mask: Optional[Tensor] = None, 312 | memory_mask: Optional[Tensor] = None, 313 | tgt_key_padding_mask: Optional[Tensor] = None, 314 | memory_key_padding_mask: Optional[Tensor] = None, 315 | pos: Optional[Tensor] = None, 316 | query_pos: Optional[Tensor] = None): 317 | tgt2 = self.norm1(tgt) 318 | q = k = self.with_pos_embed(tgt2, query_pos) 319 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 320 | key_padding_mask=tgt_key_padding_mask)[0] 321 | tgt = tgt + self.dropout1(tgt2) 322 | tgt2 = self.norm2(tgt) 323 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 324 | key=self.with_pos_embed(memory, pos), 325 | value=memory, attn_mask=memory_mask, 326 | key_padding_mask=memory_key_padding_mask)[0] 327 | tgt = tgt + self.dropout2(tgt2) 328 | tgt2 = self.norm3(tgt) 329 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 330 | tgt = tgt + self.dropout3(tgt2) 331 | return tgt 332 | 333 | def forward(self, tgt, memory, 334 | tgt_mask: Optional[Tensor] = None, 335 | memory_mask: Optional[Tensor] = None, 336 | tgt_key_padding_mask: Optional[Tensor] = None, 337 | memory_key_padding_mask: Optional[Tensor] = None, 338 | pos: Optional[Tensor] = None, 339 | query_pos: Optional[Tensor] = None): 340 | if self.normalize_before: 341 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 342 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 343 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 344 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 345 | 346 | 347 | def _get_clones(module, N): 348 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 349 | class Model(nn.Module): 350 | def __init__(self,height=720,width=1280,stride=16,num_feature_levels=3,num_channels=[96,192,384],hidden_dim=128) -> None: 351 | super().__init__() 352 | self.backbone = build_backbone() 353 | self.transformer=Transformer(d_model=hidden_dim,nhead=8,num_encoder_layers=6,num_decoder_layers=6,dim_feedforward=2048,dropout=0.1,normalize_before=False,return_intermediate_dec=False) 354 | self.ot_loss = ot_similarity() 355 | 356 | input_proj_list=[] 357 | for i in range(num_feature_levels): 358 | input_proj_list.append(nn.Sequential( 359 | nn.Conv2d(num_channels[i],hidden_dim,kernel_size=1,stride=1,padding=0), 360 | nn.GroupNorm(32,hidden_dim), 361 | nn.GELU(), 362 | )) 363 | for j in range(num_feature_levels-1-i): 364 | input_proj_list[i]+=nn.Sequential( 365 | nn.Conv2d(hidden_dim,hidden_dim,kernel_size=3,stride=2,padding=1), 366 | nn.GroupNorm(32,hidden_dim), 367 | nn.GELU(), 368 | ) 
369 | self.input_fuse=nn.Sequential( 370 | nn.Conv2d(hidden_dim*num_feature_levels,hidden_dim,kernel_size=1,stride=1,padding=0), 371 | nn.GroupNorm(32,hidden_dim), 372 | nn.GELU(), 373 | ) 374 | self.input_proj=nn.ModuleList(input_proj_list) 375 | for proj in self.input_proj: 376 | nn.init.xavier_uniform_(proj[0].weight, gain=nn.init.calculate_gain('relu')) 377 | nn.init.constant_(proj[0].bias, 0) 378 | self.hidden_dim=hidden_dim 379 | self.pos_embedder=PositionEmbeddingSine(hidden_dim//2,normalize=True) 380 | def forward(self, x, ref_points): 381 | x1=x[:,0:3,:,:] 382 | x2=x[:,3:6,:,:] 383 | ref_point1=ref_points[:,0,...] 384 | ref_point2=ref_points[:,1,...] 385 | 386 | z1 = self.backbone(x1) 387 | z2 = self.backbone(x2) 388 | z1_list=[] 389 | z2_list=[] 390 | for i in range(len(z1)): 391 | z1_list.append(self.input_proj[i](z1[i])) 392 | z2_list.append(self.input_proj[i](z2[i])) 393 | z1=torch.cat(z1_list,dim=1) 394 | z1=self.input_fuse(z1) 395 | z2=torch.cat(z2_list,dim=1) 396 | z2=self.input_fuse(z2) 397 | query_embed1 = pos2posemb2d(ref_point1, num_pos_feats=self.hidden_dim//2).float() 398 | query_embed2 = pos2posemb2d(ref_point2, num_pos_feats=self.hidden_dim//2).float() 399 | pos_emd_1=self.pos_embedder(z1_list[0].permute(0,2,3,1)) 400 | pos_emd_2=self.pos_embedder(z2_list[0].permute(0,2,3,1)) 401 | mask1=mask2=torch.ones((z1.shape[0],z1.shape[2],z1.shape[3])).cuda().bool() 402 | print(z1.shape,pos_emd_1.shape,query_embed1.shape) 403 | features1=self.transformer(z1,mask1,query_embed1,pos_emd_1) 404 | features2=self.transformer(z2,mask2,query_embed2,pos_emd_2) 405 | features1=F.relu(features1) 406 | features2=F.relu(features2) 407 | return features1,features2 408 | def loss(self,features1,features2): 409 | loss=self.ot_loss(features1,features2) 410 | return loss 411 | def build_model(): 412 | model=Model() 413 | return model 414 | 415 | if __name__=="__main__": 416 | transformer=Transformer(64,64).cuda().half() 417 | x=torch.rand(1,256,64,64).cuda().half() 418 | ref_pts=torch.rand(1,300,2).cuda().half() 419 | y=transformer(x,ref_pts) 420 | print(y.shape) -------------------------------------------------------------------------------- /models/fuse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class MS_CAM(nn.Module): 6 | ''' 7 | 单特征 进行通道加权,作用类似SE模块 8 | ''' 9 | 10 | def __init__(self, channels=64, r=4): 11 | super(MS_CAM, self).__init__() 12 | inter_channels = int(channels // r) 13 | 14 | self.local_att = nn.Sequential( 15 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 16 | nn.BatchNorm2d(inter_channels), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 19 | nn.BatchNorm2d(channels), 20 | ) 21 | 22 | self.global_att = nn.Sequential( 23 | nn.AdaptiveAvgPool2d(1), 24 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 25 | nn.BatchNorm2d(inter_channels), 26 | nn.ReLU(inplace=True), 27 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 28 | nn.BatchNorm2d(channels), 29 | ) 30 | 31 | self.sigmoid = nn.Sigmoid() 32 | 33 | def forward(self, x): 34 | xl = self.local_att(x) 35 | xg = self.global_att(x) 36 | xlg = xl + xg 37 | wei = self.sigmoid(xlg) 38 | return x * wei 39 | 40 | class FFN(nn.Module): 41 | def __init__(self, in_chans, mid_chans, out_chans): 42 | super().__init__() 43 | self.conv1 = nn.Conv2d(in_chans, mid_chans, 1) 44 | 
self.act1 = nn.GELU() 45 | self.conv2 = nn.Conv2d(mid_chans, out_chans, 1) 46 | self.act2 = nn.GELU() 47 | 48 | def forward(self, x): 49 | x = self.act1(self.conv1(x)) 50 | x = self.act2(self.conv2(x)) 51 | return x 52 | 53 | 54 | class Head(nn.Module): 55 | def __init__(self): 56 | super(Head, self).__init__() 57 | self.conv1 =nn.Sequential( 58 | nn.Conv2d(192,96,1), 59 | nn.BatchNorm2d(96), 60 | nn.GELU(), 61 | ) 62 | self.conv2 = nn.Sequential( 63 | nn.Conv2d(384, 96, 1), 64 | nn.BatchNorm2d(96), 65 | nn.GELU(), 66 | ) 67 | self.conv3 = nn.Sequential( 68 | nn.Conv2d(768, 96, 1), 69 | nn.BatchNorm2d(96), 70 | nn.GELU(), 71 | ) 72 | self.ms_cam=MS_CAM(96+96+96+2) 73 | self.out1 = nn.Sequential( 74 | nn.Conv2d(96+96+96+2, 256, 1), 75 | nn.BatchNorm2d(256), 76 | nn.GELU(), 77 | nn.Conv2d(256, 128, 3,dilation=2,padding=2), 78 | nn.GELU(), 79 | nn.Conv2d(128, 64, 3,dilation=2,padding=2), 80 | nn.GELU(), 81 | nn.Conv2d(64, 1, 1), 82 | nn.GELU(), 83 | ) 84 | self._initialize_weights() 85 | 86 | def _initialize_weights(self): 87 | for m in self.modules(): 88 | if isinstance(m, nn.Conv2d): 89 | nn.init.normal_(m.weight, std=0.01) 90 | if m.bias is not None: 91 | nn.init.constant_(m.bias, 0) 92 | elif isinstance(m, nn.BatchNorm2d): 93 | nn.init.constant_(m.weight, 1) 94 | nn.init.constant_(m.bias, 0) 95 | 96 | def forward(self, z,c1,c2): 97 | # print(x.shape) 98 | # torch.Size([1, 96, 256, 256]) 99 | # torch.Size([1, 192, 128, 128]) 100 | # torch.Size([1, 384, 64, 64]) 101 | # torch.Size([1, 768, 32, 32]) 102 | z1,z2=z 103 | x1, x2 , x3 = z1 104 | y1, y2 , y3 = z2 105 | x1 = self.conv1(torch.cat([x1,y1],dim=1)) 106 | x2 = self.conv2(torch.cat([x2,y2],dim=1)) 107 | x3 = self.conv3(torch.cat([x3,y3],dim=1)) 108 | # x3 = self.conv3(x3) 109 | # x4 = self.conv4(x4) 110 | 111 | x3= F.interpolate(x3, scale_factor=4) 112 | x2 = F.interpolate(x2, scale_factor=2) 113 | 114 | z=torch.cat([x1,x2,x3,c1,c2],dim=1) 115 | z=self.ms_cam(z) 116 | 117 | out1 = self.out1(z) 118 | return out1 119 | 120 | 121 | def build_head(): 122 | return Head() 123 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | from mmcv.cnn import get_model_complexity_info 5 | from torch import nn 6 | from torch.amp import autocast 7 | from torch.nn import functional as F 8 | 9 | from .backbone import build_backbone 10 | from .head import build_head as build_locating_head 11 | from .fuse import build_head as build_fuse_head 12 | from .local_bl import build_loss, DrawDenseMap 13 | class Model(nn.Module): 14 | def __init__(self) -> None: 15 | super().__init__() 16 | self.backbone = build_backbone() 17 | self.locating_head = build_locating_head() 18 | self.fuse_head=build_fuse_head() 19 | self.pt2dmap = DrawDenseMap(1,0,2) 20 | # self.get_model_complexity(input_shape=(6, 1280, 720)) 21 | 22 | @autocast("cuda") 23 | def forward(self, x): 24 | x1=x[:,0:3,:,:] 25 | x2=x[:,3:6,:,:] 26 | z1 = self.backbone(x1) 27 | z2 = self.backbone(x2) 28 | 29 | default_counting_map = self.locating_head([z1,z2]) 30 | duplicate_counting_map = self.locating_head([z2,z1]) 31 | fuse_counting_map=self.fuse_head([z1,z2],default_counting_map,duplicate_counting_map) 32 | return { 33 | "default_counting_map": F.interpolate(default_counting_map,scale_factor=2), 34 | "duplicate_counting_map": F.interpolate(duplicate_counting_map,scale_factor=2), 35 | 
"fuse_counting_map":F.interpolate(fuse_counting_map,scale_factor=2) 36 | } 37 | 38 | def get_model_complexity(self, input_shape): 39 | flops, params = get_model_complexity_info(self, input_shape) 40 | return flops, params 41 | 42 | 43 | @torch.no_grad() 44 | def _map2points(self,predict_counting_map,kernel,threshold,loc_kernel_size,loc_padding): 45 | device=predict_counting_map.device 46 | max_m=torch.max(predict_counting_map) 47 | threshold=max(0.1,threshold*max_m) 48 | low_resolution_map=F.interpolate(F.relu(predict_counting_map),scale_factor=0.5) 49 | H,W=low_resolution_map.shape[-2],low_resolution_map.shape[-1] 50 | 51 | unfolded_map=F.unfold(low_resolution_map,kernel_size=loc_kernel_size,padding=loc_padding) 52 | unfolded_max_idx=unfolded_map.max(dim=1,keepdim=True)[1] 53 | unfolded_max_mask=(unfolded_max_idx==loc_kernel_size**2//2).reshape(1,1,H,W) 54 | 55 | predict_cnt=F.conv2d(low_resolution_map,kernel,padding=loc_padding) 56 | predict_filter=(predict_cnt>threshold).float() 57 | predict_filter=predict_filter*unfolded_max_mask 58 | predict_filter=predict_filter.detach().cpu().numpy().astype(bool).reshape(H,W) 59 | 60 | pred_coord_weight=F.normalize(unfolded_map,p=1,dim=1) 61 | 62 | coord_h=torch.arange(H).reshape(-1,1).repeat(1,W).to(device).float() 63 | coord_w=torch.arange(W).reshape(1,-1).repeat(H,1).to(device).float() 64 | coord_h=coord_h.unsqueeze(0).unsqueeze(0) 65 | coord_w=coord_w.unsqueeze(0).unsqueeze(0) 66 | unfolded_coord_h=F.unfold(coord_h,kernel_size=loc_kernel_size,padding=loc_padding) 67 | pred_coord_h=(unfolded_coord_h*pred_coord_weight).sum(dim=1,keepdim=True).reshape(H,W).detach().cpu().numpy() 68 | unfolded_coord_w=F.unfold(coord_w,kernel_size=loc_kernel_size,padding=loc_padding) 69 | pred_coord_w=(unfolded_coord_w*pred_coord_weight).sum(dim=1,keepdim=True).reshape(H,W).detach().cpu().numpy() 70 | coord_h=pred_coord_h[predict_filter].reshape(-1,1) 71 | coord_w=pred_coord_w[predict_filter].reshape(-1,1) 72 | coord=np.concatenate([coord_w,coord_h],axis=1) 73 | 74 | pred_points=[[4*coord_w+loc_kernel_size/2.,4*coord_h+loc_kernel_size/2.] 
for coord_w,coord_h in coord] 75 | return pred_points 76 | 77 | @torch.no_grad() 78 | def forward_points(self, x, threshold=0.8,loc_kernel_size=3): 79 | assert loc_kernel_size%2==1 80 | assert x.shape[0]==1 81 | out_dict=self.forward(x) 82 | 83 | loc_padding=loc_kernel_size//2 84 | kernel=torch.ones(1,1,loc_kernel_size,loc_kernel_size).to(x.device).float() 85 | default_counting_map=out_dict["default_counting_map"].detach().float() 86 | duplicate_counting_map=out_dict["duplicate_counting_map"].detach().float() 87 | fuse_counting_map=out_dict["fuse_counting_map"].detach().float() 88 | default_points=self._map2points(default_counting_map,kernel,threshold,loc_kernel_size,loc_padding) 89 | duplicate_points=self._map2points(duplicate_counting_map,kernel,threshold,loc_kernel_size,loc_padding) 90 | fuse_points=self._map2points(fuse_counting_map,kernel,threshold,loc_kernel_size,loc_padding) 91 | out_dict["default_pts"]=torch.tensor(default_points).to(x.device).float() 92 | out_dict["duplicate_pts"]=torch.tensor(duplicate_points).to(x.device).float() 93 | out_dict["fuse_pts"]=torch.tensor(fuse_points).to(x.device).float() 94 | return out_dict 95 | 96 | def loss(self, out_dict, targets): 97 | map_loss=build_loss() 98 | loss_dict = {} 99 | device=out_dict["default_counting_map"].device 100 | gt_default_maps,gt_duplicate_maps,gt_fuse_maps=[],[],[] 101 | map_loss=map_loss.to(device) 102 | with torch.no_grad(): 103 | for idx in range(targets["gt_default_pts"].shape[0]): 104 | gt_default_map=self.pt2dmap(targets["gt_default_pts"][idx],targets["gt_default_num"][idx],targets["h"][idx],targets["w"][idx]).to(device) 105 | gt_duplicate_map=self.pt2dmap(targets["gt_duplicate_pts"][idx],targets["gt_duplicate_num"][idx],targets["h"][idx],targets["w"][idx]).to(device) 106 | gt_fuse_map=self.pt2dmap(targets["gt_fuse_pts"][idx],targets["gt_fuse_num"][idx],targets["h"][idx],targets["w"][idx]).to(device) 107 | gt_default_maps.append(gt_default_map) 108 | gt_duplicate_maps.append(gt_duplicate_map) 109 | gt_fuse_maps.append(gt_fuse_map) 110 | gt_default_maps=torch.stack(gt_default_maps,dim=0) 111 | gt_duplicate_maps=torch.stack(gt_duplicate_maps,dim=0) 112 | gt_fuse_maps=torch.stack(gt_fuse_maps,dim=0) 113 | 114 | loss_dict["default"] = map_loss(out_dict["default_counting_map"],gt_default_maps) 115 | loss_dict["duplicate"] = map_loss(out_dict["duplicate_counting_map"],gt_duplicate_maps) 116 | loss_dict["fuse"] = map_loss(out_dict["fuse_counting_map"],gt_fuse_maps) 117 | loss_dict["all"] = loss_dict["default"] + loss_dict["duplicate"] + loss_dict["fuse"] 118 | return loss_dict 119 | 120 | def build_model(): 121 | return Model() 122 | -------------------------------------------------------------------------------- /models/roi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from .backbone import build_backbone 5 | from .tri_sim_ot_b import GML 6 | import math 7 | import copy 8 | from typing import Optional, List, Dict, Tuple, Set, Union, Iterable, Any 9 | from torch import Tensor 10 | 11 | 12 | class Model(nn.Module): 13 | def __init__(self, stride=16, num_feature_levels=3, num_channels=[96, 192, 384], hidden_dim=256, freeze_backbone=False) -> None: 14 | super().__init__() 15 | self.backbone = build_backbone() 16 | if freeze_backbone: 17 | for name, parameter in self.backbone.named_parameters(): 18 | parameter.requires_grad_(False) 19 | self.stride = stride 20 | self.num_feature_levels = 
num_feature_levels 21 | self.num_channels = num_channels 22 | input_proj_list = [] 23 | for i in range(num_feature_levels): 24 | input_proj_list.append(nn.Sequential( 25 | nn.Conv2d(num_channels[i], hidden_dim, 26 | kernel_size=1, stride=1, padding=0), 27 | nn.GELU(), 28 | )) 29 | for j in range(num_feature_levels-1-i): 30 | input_proj_list[i] += nn.Sequential( 31 | nn.Conv2d(hidden_dim, hidden_dim, 32 | kernel_size=3, stride=2, padding=1), 33 | nn.GELU(), 34 | ) 35 | self.input_fuse = nn.Sequential( 36 | nn.Conv2d(hidden_dim*num_feature_levels, hidden_dim * 37 | num_feature_levels, kernel_size=1, stride=1, padding=0), 38 | nn.ReLU(), 39 | ) 40 | self.fc=nn.Sequential( 41 | nn.Linear(hidden_dim*num_feature_levels*8*8,hidden_dim*num_feature_levels), 42 | nn.ReLU(), 43 | nn.Linear(hidden_dim*num_feature_levels,hidden_dim), 44 | nn.ReLU(), 45 | nn.Linear(hidden_dim,hidden_dim*num_feature_levels), 46 | nn.ReLU(), 47 | ) 48 | self.input_proj = nn.ModuleList(input_proj_list) 49 | self.ot_loss = GML() 50 | 51 | def forward(self, input): 52 | x = input["image_pair"] 53 | ref_points = input["ref_pts"][:, :, :input["ref_num"], :] 54 | x1 = x[:, 0:3, :, :] 55 | x2 = x[:, 3:6, :, :] 56 | ref_point1 = ref_points[:, 0, ...] 57 | ref_point2 = ref_points[:, 1, ...] 58 | z2 = self.backbone(x2) 59 | z2_list = [] 60 | for i in range(len(z2)): 61 | z2_list.append(self.input_proj[i](z2[i])) 62 | z2 = torch.cat(z2_list, dim=1) 63 | z2 = self.input_fuse(z2) 64 | features_2 = [] 65 | for pt in ref_point2[0]: 66 | features_2.append(self.get_feature(z2, pt)) 67 | z1 = self.backbone(x1) 68 | z1_list = [] 69 | for i in range(len(z1)): 70 | z1_list.append(self.input_proj[i](z1[i])) 71 | 72 | z1 = torch.cat(z1_list, dim=1) 73 | z1 = self.input_fuse(z1) 74 | 75 | features_1 = [] 76 | 77 | for pt in ref_point1[0]: 78 | features_1.append(self.get_feature(z1, pt)) 79 | 80 | features_1 = torch.stack(features_1, dim=1).flatten(2) 81 | features_2 = torch.stack(features_2, dim=1).flatten(2) 82 | return features_1, features_2 83 | 84 | def get_feature(self, z, pt, window_size=8): 85 | h, w = z.shape[-2], z.shape[-1] 86 | x_min = pt[0]*w-window_size//2 87 | x_max = pt[0]*w+window_size//2 88 | y_min = pt[1]*h-window_size//2 89 | y_max = pt[1]*h+window_size//2 90 | x_min, x_max, y_min, y_max = int(x_min), int( 91 | x_max), int(y_min), int(y_max) 92 | z = z[:, :, y_min:y_max, x_min:x_max] 93 | x_pad_left = 0 94 | x_pad_right = window_size-z.shape[-1] 95 | y_pad_top = 0 96 | y_pad_bottom = window_size-z.shape[-2] 97 | z = F.pad(z, (x_pad_left, x_pad_right, y_pad_top, y_pad_bottom)) 98 | # pos_emb = self.pos2posemb2d(pt, num_pos_feats=z.shape[1]//2).unsqueeze(0).unsqueeze(2).unsqueeze(3) 99 | # z = z + pos_emb 100 | # z=F.adaptive_avg_pool2d(z,(1,1)) 101 | # z=z.squeeze(3).squeeze(2) 102 | z = z.flatten(2).flatten(1) 103 | return z 104 | 105 | def loss(self, features1, features2): 106 | loss = self.ot_loss(features1, features2) 107 | loss_dict = {} 108 | loss_dict["all"] = loss 109 | return loss_dict 110 | def pos2posemb2d(self, pos, num_pos_feats=128, temperature=1000): 111 | scale = 2 * math.pi 112 | pos = pos * scale 113 | dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device) 114 | dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) 115 | pos_x = pos[..., 0, None] / dim_t 116 | pos_y = pos[..., 1, None] / dim_t 117 | pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2) 118 | pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2) 119 
| posemb = torch.cat((pos_y, pos_x), dim=-1) 120 | return posemb 121 | 122 | 123 | def build_model(): 124 | model = Model() 125 | return model 126 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional, List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn, Tensor 7 | import math 8 | 9 | 10 | 11 | def pos2posemb2d(pos, num_pos_feats=128, temperature=10000): 12 | scale = 2 * math.pi 13 | pos = pos * scale 14 | dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device) 15 | dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) 16 | pos_x = pos[..., 0, None] / dim_t 17 | pos_y = pos[..., 1, None] / dim_t 18 | pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2) 19 | pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2) 20 | posemb = torch.cat((pos_y, pos_x), dim=-1) 21 | return posemb 22 | 23 | class Transformer(nn.Module): 24 | 25 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 26 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 27 | activation="relu", normalize_before=False, 28 | return_intermediate_dec=False): 29 | super().__init__() 30 | 31 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 32 | dropout, activation, normalize_before) 33 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 34 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 35 | 36 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 37 | dropout, activation, normalize_before) 38 | decoder_norm = nn.LayerNorm(d_model) 39 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 40 | return_intermediate=return_intermediate_dec) 41 | 42 | self._reset_parameters() 43 | 44 | self.d_model = d_model 45 | self.nhead = nhead 46 | 47 | def _reset_parameters(self): 48 | for p in self.parameters(): 49 | if p.dim() > 1: 50 | nn.init.xavier_uniform_(p) 51 | 52 | def forward(self, src, mask, query_embed, pos_embed): 53 | # flatten NxCxHxW to HWxNxC 54 | bs, c, h, w = src.shape 55 | src = src.flatten(2).permute(2, 0, 1) 56 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 57 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 58 | mask = mask.flatten(1) 59 | 60 | tgt = torch.zeros_like(query_embed) 61 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 62 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, 63 | pos=pos_embed, query_pos=query_embed) 64 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) 65 | 66 | 67 | class TransformerEncoder(nn.Module): 68 | 69 | def __init__(self, encoder_layer, num_layers, norm=None): 70 | super().__init__() 71 | self.layers = _get_clones(encoder_layer, num_layers) 72 | self.num_layers = num_layers 73 | self.norm = norm 74 | 75 | def forward(self, src, 76 | mask: Optional[Tensor] = None, 77 | src_key_padding_mask: Optional[Tensor] = None, 78 | pos: Optional[Tensor] = None): 79 | output = src 80 | 81 | for layer in self.layers: 82 | output = layer(output, src_mask=mask, 83 | src_key_padding_mask=src_key_padding_mask, pos=pos) 84 | 85 | if self.norm is not None: 86 | output = self.norm(output) 87 | 88 | return output 89 | 90 | 91 | class TransformerDecoder(nn.Module): 92 | 93 | def __init__(self, 
decoder_layer, num_layers, norm=None, return_intermediate=False): 94 | super().__init__() 95 | self.layers = _get_clones(decoder_layer, num_layers) 96 | self.num_layers = num_layers 97 | self.norm = norm 98 | self.return_intermediate = return_intermediate 99 | 100 | def forward(self, tgt, memory, 101 | tgt_mask: Optional[Tensor] = None, 102 | memory_mask: Optional[Tensor] = None, 103 | tgt_key_padding_mask: Optional[Tensor] = None, 104 | memory_key_padding_mask: Optional[Tensor] = None, 105 | pos: Optional[Tensor] = None, 106 | query_pos: Optional[Tensor] = None): 107 | output = tgt 108 | 109 | intermediate = [] 110 | 111 | for layer in self.layers: 112 | output = layer(output, memory, tgt_mask=tgt_mask, 113 | memory_mask=memory_mask, 114 | tgt_key_padding_mask=tgt_key_padding_mask, 115 | memory_key_padding_mask=memory_key_padding_mask, 116 | pos=pos, query_pos=query_pos) 117 | if self.return_intermediate: 118 | intermediate.append(self.norm(output)) 119 | 120 | if self.norm is not None: 121 | output = self.norm(output) 122 | if self.return_intermediate: 123 | intermediate.pop() 124 | intermediate.append(output) 125 | 126 | if self.return_intermediate: 127 | return torch.stack(intermediate) 128 | 129 | return output.unsqueeze(0) 130 | 131 | 132 | class TransformerEncoderLayer(nn.Module): 133 | 134 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 135 | activation="relu", normalize_before=False): 136 | super().__init__() 137 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 138 | # Implementation of Feedforward model 139 | self.linear1 = nn.Linear(d_model, dim_feedforward) 140 | self.dropout = nn.Dropout(dropout) 141 | self.linear2 = nn.Linear(dim_feedforward, d_model) 142 | 143 | self.norm1 = nn.LayerNorm(d_model) 144 | self.norm2 = nn.LayerNorm(d_model) 145 | self.dropout1 = nn.Dropout(dropout) 146 | self.dropout2 = nn.Dropout(dropout) 147 | 148 | self.activation = _get_activation_fn(activation) 149 | self.normalize_before = normalize_before 150 | 151 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 152 | return tensor if pos is None else tensor + pos 153 | 154 | def forward_post(self, 155 | src, 156 | src_mask: Optional[Tensor] = None, 157 | src_key_padding_mask: Optional[Tensor] = None, 158 | pos: Optional[Tensor] = None): 159 | q = k = self.with_pos_embed(src, pos) 160 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 161 | key_padding_mask=src_key_padding_mask)[0] 162 | src = src + self.dropout1(src2) 163 | src = self.norm1(src) 164 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 165 | src = src + self.dropout2(src2) 166 | src = self.norm2(src) 167 | return src 168 | 169 | def forward_pre(self, src, 170 | src_mask: Optional[Tensor] = None, 171 | src_key_padding_mask: Optional[Tensor] = None, 172 | pos: Optional[Tensor] = None): 173 | src2 = self.norm1(src) 174 | q = k = self.with_pos_embed(src2, pos) 175 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 176 | key_padding_mask=src_key_padding_mask)[0] 177 | src = src + self.dropout1(src2) 178 | src2 = self.norm2(src) 179 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 180 | src = src + self.dropout2(src2) 181 | return src 182 | 183 | def forward(self, src, 184 | src_mask: Optional[Tensor] = None, 185 | src_key_padding_mask: Optional[Tensor] = None, 186 | pos: Optional[Tensor] = None): 187 | if self.normalize_before: 188 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 189 | return 
self.forward_post(src, src_mask, src_key_padding_mask, pos) 190 | 191 | 192 | class TransformerDecoderLayer(nn.Module): 193 | 194 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 195 | activation="relu", normalize_before=False): 196 | super().__init__() 197 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 198 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 199 | # Implementation of Feedforward model 200 | self.linear1 = nn.Linear(d_model, dim_feedforward) 201 | self.dropout = nn.Dropout(dropout) 202 | self.linear2 = nn.Linear(dim_feedforward, d_model) 203 | 204 | self.norm1 = nn.LayerNorm(d_model) 205 | self.norm2 = nn.LayerNorm(d_model) 206 | self.norm3 = nn.LayerNorm(d_model) 207 | self.dropout1 = nn.Dropout(dropout) 208 | self.dropout2 = nn.Dropout(dropout) 209 | self.dropout3 = nn.Dropout(dropout) 210 | 211 | self.activation = _get_activation_fn(activation) 212 | self.normalize_before = normalize_before 213 | 214 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 215 | return tensor if pos is None else tensor + pos 216 | 217 | def forward_post(self, tgt, memory, 218 | tgt_mask: Optional[Tensor] = None, 219 | memory_mask: Optional[Tensor] = None, 220 | tgt_key_padding_mask: Optional[Tensor] = None, 221 | memory_key_padding_mask: Optional[Tensor] = None, 222 | pos: Optional[Tensor] = None, 223 | query_pos: Optional[Tensor] = None): 224 | q = k = self.with_pos_embed(tgt, query_pos) 225 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 226 | key_padding_mask=tgt_key_padding_mask)[0] 227 | tgt = tgt + self.dropout1(tgt2) 228 | tgt = self.norm1(tgt) 229 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 230 | key=self.with_pos_embed(memory, pos), 231 | value=memory, attn_mask=memory_mask, 232 | key_padding_mask=memory_key_padding_mask)[0] 233 | tgt = tgt + self.dropout2(tgt2) 234 | tgt = self.norm2(tgt) 235 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 236 | tgt = tgt + self.dropout3(tgt2) 237 | tgt = self.norm3(tgt) 238 | return tgt 239 | 240 | def forward_pre(self, tgt, memory, 241 | tgt_mask: Optional[Tensor] = None, 242 | memory_mask: Optional[Tensor] = None, 243 | tgt_key_padding_mask: Optional[Tensor] = None, 244 | memory_key_padding_mask: Optional[Tensor] = None, 245 | pos: Optional[Tensor] = None, 246 | query_pos: Optional[Tensor] = None): 247 | tgt2 = self.norm1(tgt) 248 | q = k = self.with_pos_embed(tgt2, query_pos) 249 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 250 | key_padding_mask=tgt_key_padding_mask)[0] 251 | tgt = tgt + self.dropout1(tgt2) 252 | tgt2 = self.norm2(tgt) 253 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 254 | key=self.with_pos_embed(memory, pos), 255 | value=memory, attn_mask=memory_mask, 256 | key_padding_mask=memory_key_padding_mask)[0] 257 | tgt = tgt + self.dropout2(tgt2) 258 | tgt2 = self.norm3(tgt) 259 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 260 | tgt = tgt + self.dropout3(tgt2) 261 | return tgt 262 | 263 | def forward(self, tgt, memory, 264 | tgt_mask: Optional[Tensor] = None, 265 | memory_mask: Optional[Tensor] = None, 266 | tgt_key_padding_mask: Optional[Tensor] = None, 267 | memory_key_padding_mask: Optional[Tensor] = None, 268 | pos: Optional[Tensor] = None, 269 | query_pos: Optional[Tensor] = None): 270 | if self.normalize_before: 271 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 272 | tgt_key_padding_mask, 
memory_key_padding_mask, pos, query_pos) 273 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 274 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 275 | 276 | 277 | def _get_clones(module, N): 278 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 279 | 280 | 281 | def build_transformer(args): 282 | return Transformer( 283 | d_model=args.hidden_dim, 284 | dropout=args.dropout, 285 | nhead=args.nheads, 286 | dim_feedforward=args.dim_feedforward, 287 | num_encoder_layers=args.enc_layers, 288 | num_decoder_layers=args.dec_layers, 289 | normalize_before=args.pre_norm, 290 | return_intermediate_dec=True, 291 | ) 292 | 293 | 294 | def _get_activation_fn(activation): 295 | """Return an activation function given a string""" 296 | if activation == "relu": 297 | return F.relu 298 | if activation == "gelu": 299 | return F.gelu 300 | if activation == "glu": 301 | return F.glu 302 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") -------------------------------------------------------------------------------- /models/tri_cropper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from .backbone import build_backbone 5 | from .tri_sim_ot_b import GML 6 | import math 7 | import copy 8 | from typing import Optional, List, Dict, Tuple, Set, Union, Iterable, Any 9 | from torch import Tensor 10 | 11 | 12 | class Model(nn.Module): 13 | def __init__(self, stride=16, num_feature_levels=3, num_channels=[96, 192, 384], hidden_dim=256, freeze_backbone=False) -> None: 14 | super().__init__() 15 | self.backbone = build_backbone() 16 | if freeze_backbone: 17 | for name, parameter in self.backbone.named_parameters(): 18 | parameter.requires_grad_(False) 19 | self.stride = stride 20 | self.num_feature_levels = num_feature_levels 21 | self.num_channels = num_channels 22 | input_proj_list = [] 23 | 24 | self.ot_loss = GML() 25 | 26 | def forward(self, input): 27 | x = input["image_pair"] 28 | ref_points = input["ref_pts"][:, :, :input["ref_num"], :] 29 | ind_points0=input["independ_pts0"][:,:input["independ_num0"],...] 30 | ind_points1=input["independ_pts1"][:,:input["independ_num1"],...] 31 | x1 = x[:, 0:3, :, :] 32 | x2 = x[:, 3:6, :, :] 33 | ref_point1 = ref_points[:, 0, ...] 34 | ref_point2 = ref_points[:, 1, ...] 
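        # For every matched reference point, the loop below crops a fixed window
        # around the point in frame t (x1) and in frame t+1 (x2), resizes each crop
        # to 224x224, and encodes it with the shared backbone; the frame-specific
        # (unmatched) points are cropped and encoded the same way, and all embeddings
        # are L2-normalized before being scored by the GML loss.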
35 | 36 | z1_lists=[] 37 | z2_lists=[] 38 | y1_lists=[] 39 | y2_lists=[] 40 | for b in range(x1.shape[0]): 41 | z1_list=[] 42 | z2_list=[] 43 | y1_list=[] 44 | y2_list=[] 45 | for pt1,pt2 in zip(ref_point1[b],ref_point2[b]): 46 | z1=self.get_crops(x1[b].unsqueeze(0),pt1) 47 | z2=self.get_crops(x2[b].unsqueeze(0),pt2) 48 | z1=F.interpolate(z1,(224,224)) 49 | z2=F.interpolate(z2,(224,224)) 50 | z1_list.append(z1) 51 | z2_list.append(z2) 52 | z1=torch.cat(z1_list,dim=0) 53 | z2=torch.cat(z2_list,dim=0) 54 | z1=self.backbone(z1)[0].flatten(2).flatten(1) 55 | z2=self.backbone(z2)[0].flatten(2).flatten(1) 56 | z1_lists.append(z1) 57 | z2_lists.append(z2) 58 | for pt in ind_points0[b]: 59 | y1=self.get_crops(x1[b].unsqueeze(0),pt) 60 | y1=F.interpolate(y1,(224,224)) 61 | y1_list.append(y1) 62 | for pt in ind_points1[b]: 63 | y2=self.get_crops(x2[b].unsqueeze(0),pt) 64 | y2=F.interpolate(y2,(224,224)) 65 | y2_list.append(y2) 66 | y1=torch.cat(y1_list,dim=0) 67 | y2=torch.cat(y2_list,dim=0) 68 | y1=self.backbone(y1)[0].flatten(2).flatten(1) 69 | y2=self.backbone(y2)[0].flatten(2).flatten(1) 70 | y1_lists.append(y1) 71 | y2_lists.append(y2) 72 | y1=torch.stack(y1_lists,dim=0) 73 | y2=torch.stack(y2_lists,dim=0) 74 | z1=torch.stack(z1_lists,dim=0) 75 | z2=torch.stack(z2_lists,dim=0) 76 | z1=F.normalize(z1,dim=-1) 77 | z2=F.normalize(z2,dim=-1) 78 | y1=F.normalize(y1,dim=-1) 79 | y2=F.normalize(y2,dim=-1) 80 | return z1,z2,y1,y2 81 | def forward_single_image(self,x,pts,absolute=False): 82 | z_lists=[] 83 | max_batch=20 84 | for b in range(x.shape[0]): 85 | z_list=[] 86 | z_lists.append([]) 87 | for pt in pts[b]: 88 | z=self.get_crops(x[b].unsqueeze(0),pt,absolute=absolute) 89 | z=F.interpolate(z,(224,224)) 90 | z_list.append(z) 91 | z_list=[z_list[i:i+max_batch] for i in range(0,len(z_list),max_batch)] 92 | 93 | for z in z_list: 94 | 95 | z=torch.cat(z,dim=0) 96 | 97 | z=self.backbone(z)[0].flatten(2).flatten(1) 98 | z_lists[-1].append(z) 99 | z_lists[-1]=torch.cat(z_lists[-1],dim=0) 100 | z=torch.stack(z_lists,dim=0) 101 | z=F.normalize(z,dim=-1) 102 | return z 103 | def get_crops(self, z, pt, window_size=[32,32,32,64],absolute=False): 104 | h, w = z.shape[-2], z.shape[-1] 105 | if absolute: 106 | x_min = pt[0]-window_size[0] 107 | x_max = pt[0]+window_size[1] 108 | y_min = pt[1]-window_size[2] 109 | y_max = pt[1]+window_size[3] 110 | else: 111 | 112 | x_min = pt[0]*w-window_size[0] 113 | x_max = pt[0]*w+window_size[1] 114 | y_min = pt[1]*h-window_size[2] 115 | y_max = pt[1]*h+window_size[3] 116 | x_min, x_max, y_min, y_max = int(x_min), int( 117 | x_max), int(y_min), int(y_max) 118 | x_min = max(0, x_min) 119 | x_max = min(w, x_max) 120 | y_min = max(0, y_min) 121 | y_max = min(h, y_max) 122 | z = z[..., y_min:y_max, x_min:x_max] 123 | # pos_emb = self.pos2posemb2d(pt, num_pos_feats=z.shape[1]//2).unsqueeze(0).unsqueeze(2).unsqueeze(3) 124 | # z = z + pos_emb 125 | # z=F.adaptive_avg_pool2d(z,(1,1)) 126 | # z=z.squeeze(3).squeeze(2) 127 | return z 128 | 129 | def loss(self, z1,z2,y1,y2): 130 | loss_dict = self.ot_loss([z1,y1], [z2,y2]) 131 | # loss = self.ot_loss(z1, z2) 132 | 133 | loss_dict["all"] = loss_dict["scon_cost"]+loss_dict["hinge_cost"]*0.1 134 | return loss_dict 135 | 136 | 137 | def build_model(): 138 | model = Model() 139 | return model 140 | -------------------------------------------------------------------------------- /models/tri_sim_ot_b.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import 
torch.nn.functional as F 4 | from geomloss import SamplesLoss 5 | from scipy.optimize import linear_sum_assignment 6 | from sklearn.manifold import TSNE 7 | import scienceplots 8 | def similarity_cost(x1, x2, gamma=10): 9 | ''' 10 | :param x1: [B N D] 11 | :param x2: [B M D] 12 | :return: [B N M] 13 | ''' 14 | N = x1.shape[1] 15 | x1 = F.normalize(x1, dim=-1) 16 | x2 = F.normalize(x2, dim=-1) 17 | sim = torch.bmm(x1, x2.transpose(1, 2)) * gamma 18 | sim = sim.exp() 19 | sim_c = sim.sum(dim=1, keepdim=True) 20 | sim_r = sim.sum(dim=2, keepdim=True) 21 | sim = sim / (sim_c + sim_r - sim) 22 | return 1 - sim 23 | 24 | def get_group_ones(N, b, B, device): 25 | alpha = [] 26 | for i in range(B): 27 | if i == b: 28 | alpha.append(torch.ones(1, N[i], device=device).float()) 29 | else: 30 | alpha.append(torch.zeros(1, N[i], device=device).float()) 31 | 32 | alpha = torch.cat(alpha, dim=1) 33 | return alpha 34 | 35 | class GML(nn.Module): 36 | def __init__(self): 37 | super().__init__() 38 | self.ot = SamplesLoss(backend="tensorized", cost=similarity_cost, debias=False,diameter=3) 39 | self.margin=0.1 40 | def forward(self, x_pre, x_cur): 41 | ''' 42 | :param x1: B tensor of size [N_b D] 43 | :param x2: B tensor of size [N_b D] 44 | :param x3: B tensor of size [N1_b D] negative exapmles 45 | :param x4: B tensor of size [M1_b D] negative exapmles 46 | :return: ot loss 47 | ''' 48 | x1, x3 = x_pre 49 | x2, x4 = x_cur 50 | B = len(x1) # assert B == 1 51 | N = [x.shape[0] for x in x1] 52 | M = [x.shape[0] for x in x2] 53 | N1 = [x.shape[0] for x in x3] 54 | M1 = [x.shape[0] for x in x4] 55 | 56 | device = x1[0].device 57 | ot_loss = [] 58 | 59 | for b in range(B): 60 | assert N[b] == M[b] 61 | alpha = torch.cat((torch.ones(1, N[b], device=device), torch.zeros(1, N1[b], device=device)), dim=1) 62 | beta = torch.cat((torch.ones(1, N[b], device=device), torch.zeros(1, M1[b], device=device)), dim=1) 63 | f1 = torch.cat((x1[b], x3[b]), dim=0).unsqueeze(0) 64 | f2 = torch.cat((x2[b], x4[b]), dim=0).unsqueeze(0) 65 | loss = self.ot(alpha, f1, beta, f2) 66 | ot_loss.append(loss) 67 | loss=0 68 | loss+=torch.relu(self.margin-torch.bmm(x3,x4.transpose(1,2))).sum() 69 | loss_dict={ 70 | "hinge_cost":loss.unsqueeze(0), 71 | "scon_cost":sum(ot_loss) / len(ot_loss) 72 | } 73 | return loss_dict 74 | 75 | 76 | if __name__ == "__main__": 77 | # x1 = torch.ones(2, 2) 78 | # x2 = torch.ones(2, 2) 79 | from matplotlib import pyplot as plt 80 | Dim = 512 81 | 82 | x1 = torch.randn(1,25, Dim) 83 | x2 = torch.randn(1,25, Dim) 84 | x1.requires_grad = True 85 | x2.requires_grad = True 86 | 87 | x3 = torch.randn(1,6, Dim) 88 | x4 = torch.randn(1,9, Dim) 89 | x3.requires_grad = True 90 | x4.requires_grad = True 91 | 92 | # x1 = [torch.eye(10)] 93 | # x2 = [torch.eye(10)] 94 | 95 | lr = 0.5 96 | loss = GML() 97 | for i in range(50): 98 | x1=F.normalize(x1,dim=-1) 99 | x2=F.normalize(x2,dim=-1) 100 | x3=F.normalize(x3,dim=-1) 101 | x4=F.normalize(x4,dim=-1) 102 | 103 | l = loss([x1, x3], [x2, x4]) 104 | t=l["scon_cost"]+0.2*l["hinge_cost"] 105 | if (i+1) % 1 == 0 or i == 0: 106 | 107 | s1=torch.bmm(x1,x2.transpose(1,2)) 108 | s2= torch.bmm(x1,x4.transpose(1,2)) 109 | s3 = torch.bmm(x3,x2.transpose(1,2)) 110 | s4 = torch.bmm(x3,x4.transpose(1,2)) 111 | s1_s2=torch.cat((s1,s2),dim=2) 112 | s3_s4=torch.cat((s3,s4),dim=2) 113 | img=torch.cat((s1_s2,s3_s4),dim=1) 114 | 115 | [g1, g2, g3, g4] = torch.autograd.grad(t, [x1, x2, x3, x4]) 116 | x1.data -= lr * g1 117 | x1.data = F.relu(x1.data) 118 | x2.data -= lr * g2 119 | x2.data = 
F.relu(x2.data) 120 | x3.data -= lr * g3 121 | x3.data = F.relu(x3.data) 122 | x4.data -= lr * g4 123 | x4.data = F.relu(x4.data) 124 | 125 | pos_sim_cost=1- similarity_cost(x1, x2) 126 | neg_sim_cost=1- similarity_cost(x1, x4) 127 | all_cost=1- similarity_cost(torch.cat((x1,x3),dim=1), torch.cat((x2,x4),dim=1)) 128 | print("pos_sim_cost") 129 | print(pos_sim_cost.cpu().detach().numpy()[0]) 130 | print("neg_sim_cost") 131 | print(neg_sim_cost.cpu().detach().numpy()[0]) 132 | print("all_cost") 133 | print(all_cost.cpu().detach().numpy()[0]) 134 | x1=F.normalize(x1,dim=-1) 135 | x2=F.normalize(x2,dim=-1) 136 | x3=F.normalize(x3,dim=-1) 137 | x4=F.normalize(x4,dim=-1) 138 | match_matrix=torch.bmm(torch.cat((x1,x3),dim=1),torch.cat((x2,x4),dim=1).transpose(1,2)) 139 | 140 | print("match_matrix") 141 | print(match_matrix.cpu().detach().numpy()[0]) 142 | match=linear_sum_assignment(1-match_matrix.cpu().detach().numpy()[0]) 143 | print("match") 144 | print(match) 145 | print("x1") 146 | print(x1.cpu().detach().numpy()) 147 | print("x2") 148 | print(x2.cpu().detach().numpy()) 149 | print("x3") 150 | print(x3.cpu().detach().numpy()) 151 | print("x4") 152 | print(x4.cpu().detach().numpy()) -------------------------------------------------------------------------------- /optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .optimizer_builder import optimizer_builder, optimizer_finetune_builder 2 | from .scheduler_builder import scheduler_builder -------------------------------------------------------------------------------- /optimizer/optimizer_builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def optimizer_builder(args,model_without_ddp): 4 | params = [ 5 | { 6 | "params": 7 | [p for n, p in model_without_ddp.named_parameters() 8 | if p.requires_grad], 9 | "lr": args.lr, 10 | }, 11 | ] 12 | if args.type.lower()=="sgd": 13 | return torch.optim.SGD(params, args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 14 | elif args.type.lower()=="adam": 15 | return torch.optim.Adam(params, args.lr, weight_decay=args.weight_decay) 16 | elif args.type.lower()=="adamw": 17 | return torch.optim.AdamW(params, args.lr, weight_decay=args.weight_decay) 18 | 19 | def optimizer_finetune_builder(args,model_without_ddp): 20 | param_backbone = [p for p in model_without_ddp.backbone.parameters() if p.requires_grad] + \ 21 | [p for p in model_without_ddp.decoder_layers.decoder.parameters() if p.requires_grad] 22 | param_predict = [p for p in model_without_ddp.decoder_layers.output_layer.parameters() if p.requires_grad] 23 | params = [ 24 | { 25 | "params": param_backbone, 26 | "lr": args.lr * 0.1 , 27 | }, 28 | { 29 | "params": param_predict, 30 | "lr": args.lr, 31 | } 32 | ] 33 | if args.type.lower()=="sgd": 34 | return torch.optim.SGD(params, args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 35 | elif args.type.lower()=="adam": 36 | return torch.optim.Adam(params, args.lr, weight_decay=args.weight_decay) 37 | elif args.type.lower()=="adamw": 38 | return torch.optim.AdamW(params, args.lr, weight_decay=args.weight_decay) -------------------------------------------------------------------------------- /optimizer/scheduler_builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def scheduler_builder(args,optimizer): 4 | print(optimizer) 5 | if args.type=="step": 6 | return 
torch.optim.lr_scheduler.MultiStepLR(optimizer, args.milestones, args.gamma) 7 | elif args.type=="cosine": 8 | return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.T_max, args.eta_min) 9 | elif args.type=="multi_step": 10 | return torch.optim.lr_scheduler.MultiStepLR(optimizer, args.milestones, args.gamma) 11 | elif args.type=="consine_warm_restart": 12 | return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=args.T_max, eta_min=args.eta_min) -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Weakly Supervised Video Individual Counting 2 | 3 | Official PyTorch implementation of "Weakly Supervised Video Individual Counting" as presented at CVPR 2024. 4 | 5 | 📄 [Read the Paper](https://openaccess.thecvf.com/content/CVPR2024/html/Liu_Weakly_Supervised_Video_Individual_Counting_CVPR_2024_paper.html) 6 | 7 | **Authors:** Xinyan Liu, Guorong Li, Yuankai Qi, Ziheng Yan, Zhenjun Han, Anton van den Hengel, Ming-Hsuan Yang, Qingming Huang 8 | 9 | ## Overview 10 | 11 | The Video Individual Counting (VIC) task focuses on predicting the count of unique individuals in videos. Traditional methods, which rely on costly individual trajectory annotations, are impractical for large-scale applications. This work introduces a novel approach to VIC under a weakly supervised framework, utilizing less restrictive inflow and outflow annotations. We propose a baseline method employing weakly supervised contrastive learning for group-level matching, enhanced by a custom soft contrastive loss, facilitating the distinction between different crowd dynamics. We also contribute two augmented datasets, SenseCrowd and CroHD, and introduce a new dataset, UAVVIC, to foster further research in this area. Our method demonstrates superior performance compared to fully supervised counterparts, making a strong case for its practical applicability. 12 | 13 | ## Inference Pipeline 14 | 15 |  16 | 17 | The CGNet architecture includes: 18 | - **Frame-level Crowd Locator**: Detects pedestrian coordinates. 19 | - **Encoder**: Generates unique representations for each detected individual. 20 | - **Memory-based Individual Count Predictor (MCP)**: Estimates inflow counts and maintains a memory of individual templates. 21 | 22 | The Weakly Supervised Representation Learning (WSRL) method utilizes both inflow and outflow labels to refine the encoder through a novel Group-Level Matching Loss (GML), integrating soft contrastive and hinge losses to optimize performance. 23 | 24 | ## Demo 25 | 26 | Our model processes video inputs to predict individual counts, operating over 3-second intervals. 27 | 28 |
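
## GML Loss: Standalone Sketch

The snippet below is a minimal sketch, not part of the repository, showing how the Group-Level Matching Loss implemented in `models/tri_sim_ot_b.py` can be exercised on its own. The embedding dimension, group sizes, and the 0.1 weight on the hinge term mirror the self-test in `tri_sim_ot_b.py` and `Model.loss()` in `models/tri_cropper.py`; the random tensors stand in for encoder outputs and are purely illustrative. It assumes the script is run from the repository root with the dependencies in `requirements.txt` (notably `geomloss`) installed.

```python
# Minimal sketch: feeding matched and unmatched embedding groups to the GML loss.
import torch
import torch.nn.functional as F

from models.tri_sim_ot_b import GML

gml = GML()

# Embeddings of individuals visible in both frames (matched groups) ...
z1 = F.normalize(torch.randn(1, 25, 512, requires_grad=True), dim=-1)  # frame t
z2 = F.normalize(torch.randn(1, 25, 512, requires_grad=True), dim=-1)  # frame t+1
# ... and embeddings of individuals visible in only one frame (outflow / inflow).
y1 = F.normalize(torch.randn(1, 6, 512, requires_grad=True), dim=-1)
y2 = F.normalize(torch.randn(1, 9, 512, requires_grad=True), dim=-1)

loss_dict = gml([z1, y1], [z2, y2])  # {"scon_cost": ..., "hinge_cost": ...}
loss = loss_dict["scon_cost"] + 0.1 * loss_dict["hinge_cost"]
loss.backward()
```

In the implementation above, `scon_cost` is an optimal-transport matching cost that places unit mass on the shared-group embeddings and zero mass on the frame-specific ones, while `hinge_cost` sums `relu(0.1 - sim)` over the cross-pair similarities of the two frame-specific groups.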