├── .gitignore ├── configs └── crowd_sense.json ├── dist_train.sh ├── eval.py ├── inference.py ├── misc ├── saver_builder.py └── tools.py ├── models ├── backbone.py ├── cropper.py ├── extractor.py ├── fuse.py ├── model.py ├── roi.py ├── transformer.py ├── tri_cropper.py └── tri_sim_ot_b.py ├── optimizer ├── __init__.py ├── optimizer_builder.py └── scheduler_builder.py ├── readme.md ├── requirements.txt ├── result └── video_results_test.json ├── statics └── imgs │ ├── c3.gif │ ├── c4.gif │ ├── p2.gif │ ├── p3.gif │ ├── pipeline.png │ ├── sim.mp4 │ └── tinywow_sim_58930123.gif ├── train.py ├── tri_dataset.py ├── tri_eingine.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # python gitignore 2 | *.pyc 3 | outputs/* 4 | datasets/* -------------------------------------------------------------------------------- /configs/crowd_sense.json: -------------------------------------------------------------------------------- 1 | { 2 | "Dataset": { 3 | "train": { 4 | "root": "datasets/sensecrowd/train", 5 | "ann_dir": "datasets/sensecrowd/new_annotations", 6 | "size_divisor": 32, 7 | "batch_size": 1, 8 | "num_workers": 8, 9 | "shuffle": true, 10 | "drop_last": true, 11 | "cache_mode": false, 12 | "max_len": 3000 13 | }, 14 | "val": { 15 | "root": "datasets/sensecrowd/val", 16 | "ann_dir": "datasets/sensecrowd/new_annotations", 17 | "size_divisor": 32, 18 | "batch_size": 1, 19 | "num_workers": 8, 20 | "shuffle": false, 21 | "drop_last": false, 22 | "cache_mode": false, 23 | "max_len": 3000 24 | }, 25 | "test": { 26 | "root": "datasets/sensecrowd/test", 27 | "ann_dir": "datasets/sensecrowd/new_annotations", 28 | "size_divisor": 32, 29 | "batch_size": 1, 30 | "num_workers": 8, 31 | "shuffle": false, 32 | "drop_last": false, 33 | "cache_mode": false, 34 | "max_len": 3000 35 | } 36 | }, 37 | "Optimizer": { 38 | "type": "AdamW", 39 | "lr": 0.0001, 40 | "betas": [ 41 | 0.9, 42 | 0.999 43 | ], 44 | "eps": 1e-08, 45 | "weight_decay": 0.000001 46 | }, 47 | "Scheduler": { 48 | "type": "cosine", 49 | "T_max": 200, 50 | "eta_min":0.000000001, 51 | "ema":false, 52 | "ema_annel_strategy": "cos", 53 | "ema_annel_epochs":10, 54 | "ema_lr":0.000000001, 55 | "ema_weight":0.9, 56 | "ema_start_epoch":90 57 | }, 58 | "Saver": { 59 | "save_dir": "./outputs", 60 | "save_interval": 1, 61 | "save_start_epoch": 0, 62 | "save_num_per_epoch": 2, 63 | "max_save_num": 20, 64 | "save_best": true, 65 | "metric":"all_match_acc", 66 | "reverse": true 67 | }, 68 | "Logger": { 69 | "delimiter": "\t", 70 | "print_freq": 25, 71 | "header": "" 72 | }, 73 | "Misc": { 74 | "epochs":201, 75 | "use_tensorboard": true, 76 | "tensorboard_dir": "./outputs", 77 | "clip_max_norm": 10, 78 | "val_freq":1 79 | }, 80 | "Drawer": { 81 | "draw_freq": 25, 82 | "output_dir": "./outputs", 83 | "draw_original": true, 84 | "draw_denseMap": true, 85 | "draw_output": true, 86 | "mean": [ 87 | 0.485, 88 | 0.456, 89 | 0.406 90 | ], 91 | "std": [ 92 | 0.229, 93 | 0.224, 94 | 0.225 95 | ] 96 | } 97 | } -------------------------------------------------------------------------------- /dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_PROC=$1 3 | shift 4 | torchrun --master_port 29505 --nproc_per_node=$NUM_PROC train.py -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import 
os 4 | with open("result/video_results_test.json","r") as f: 5 | video_results=json.load(f) 6 | 7 | anno_root="datasets/sensecrowd/new_annotations" 8 | 9 | 10 | gt_video_num_list=[] 11 | gt_video_len_list=[] 12 | pred_video_num_list=[] 13 | pred_matched_num_list=[] 14 | gt_matched_num_list=[] 15 | for video_name in video_results: 16 | video_len=0 17 | anno_path=os.path.join(anno_root,video_name+".txt") 18 | with open(anno_path,"r") as f: 19 | lines=f.readlines() 20 | all_ids=set() 21 | 22 | for line in lines: 23 | line=line.strip().split(" ") 24 | data=[float(x) for x in line[3:] if x!=""] 25 | if len(data)>0: 26 | data=np.array(data) 27 | data=np.reshape(data,(-1,7)) 28 | ids=data[:,6].reshape(-1,1) 29 | for id in ids: 30 | all_ids.add(int(id[0])) 31 | info=video_results[video_name] 32 | gt_video_num=len(all_ids) 33 | pred_video_num=info["video_num"] 34 | pred_video_num_list.append(pred_video_num) 35 | gt_video_num_list.append(gt_video_num) 36 | gt_video_len_list.append(info["frame_num"]) 37 | 38 | 39 | 40 | MAE=np.mean(np.abs(np.array(gt_video_num_list)-np.array(pred_video_num_list))) 41 | MSE=np.mean(np.square(np.array(gt_video_num_list)-np.array(pred_video_num_list))) 42 | WRAE=np.sum(np.abs(np.array(gt_video_num_list)-np.array(pred_video_num_list))*np.array(gt_video_len_list)/np.array(gt_video_num_list)/np.sum(gt_video_len_list)) 43 | RMSE=np.sqrt(MSE) 44 | 45 | print(f"MAE:{MAE:.2f}, MSE:{MSE:.2f}, WRAE:{WRAE*100:.2f}%, RMSE:{RMSE:.2f}") -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import shutil 5 | import time 6 | from asyncio.log import logger 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from easydict import EasyDict as edict 11 | from termcolor import cprint 12 | from torch.cuda.amp import GradScaler 13 | from torch.nn import SyncBatchNorm 14 | from torch.utils.data import DataLoader 15 | from torch.utils.data.distributed import DistributedSampler 16 | from torch.utils.tensorboard import SummaryWriter 17 | 18 | from misc import tools 19 | from misc.saver_builder import Saver 20 | from misc.tools import MetricLogger, is_main_process 21 | from models.tri_cropper import build_model 22 | from tri_dataset import build_video_dataset as build_dataset 23 | from tri_dataset import inverse_normalize 24 | # from eingine.densemap_trainer import evaluate_counting, train_one_epoch 25 | from tri_eingine import evaluate_similarity, train_one_epoch 26 | from models.tri_sim_ot_b import similarity_cost 27 | import numpy as np 28 | from scipy.optimize import linear_sum_assignment 29 | from tqdm import tqdm 30 | import cv2 31 | 32 | torch.backends.cudnn.enabled = True 33 | torch.backends.cudnn.benchmark = True 34 | 35 | 36 | def read_pts(path): 37 | with open(path, "r") as f: 38 | lines = f.readlines() 39 | pts = [] 40 | for line in lines: 41 | line = line.strip().split(",") 42 | pts.append([float(line[0]), float(line[1])]) 43 | pts = np.array(pts) 44 | return pts 45 | 46 | 47 | def module2model(module_state_dict): 48 | state_dict = {} 49 | for k, v in module_state_dict.items(): 50 | while k.startswith("module."): 51 | k = k[7:] 52 | # while apply ema model 53 | if k == "n_averaged": 54 | print(f"{k}:{v}") 55 | continue 56 | state_dict[k] = v 57 | return state_dict 58 | 59 | 60 | def main(pair_cfg, pair_ckpt): 61 | tools.init_distributed_mode(pair_cfg,) 62 | tools.set_randomseed(42 + tools.get_rank()) 
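    # Overview of the inference pipeline implemented below:
    #   1. Build the embedding model and load the checkpoint; module2model() strips the
    #      "module." prefix written by DistributedDataParallel and skips the EMA
    #      bookkeeping key "n_averaged".
    #   2. For each video, embed the detected points of the first frame (read from
    #      locater/results/<video_name>/<img_name>.txt) and open one memory slot per
    #      point, each with a time-to-live of `ttl` rounds.
    #   3. Every `interval` frames, embed the new detections, compute cosine
    #      similarities against the stored memory features, and solve a Hungarian
    #      assignment (scipy.optimize.linear_sum_assignment). Matches scoring above
    #      `threshold` update their memory slot and refresh its ttl; unmatched
    #      detections are counted as inflow and open new slots.
    #   4. All ttls are decremented once per sampled frame and slots that reach zero
    #      are evicted. The per-video total is the first-frame count plus the
    #      accumulated inflow.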
63 | # initilize the model 64 | model = model_without_ddp = build_model() 65 | model.load_state_dict(module2model(torch.load(pair_ckpt)["model"])) 66 | model.cuda() 67 | if pair_cfg.distributed: 68 | sync_model = SyncBatchNorm.convert_sync_batchnorm(model) 69 | model = torch.nn.parallel.DistributedDataParallel( 70 | sync_model, device_ids=[pair_cfg.gpu], find_unused_parameters=False) 71 | model_without_ddp = model.module 72 | 73 | # build the dataset and dataloader 74 | dataset = build_dataset(pair_cfg.Dataset.test.root, 75 | pair_cfg.Dataset.test.ann_dir) 76 | sampler = DistributedSampler( 77 | dataset, shuffle=False) if pair_cfg.distributed else None 78 | loader = DataLoader(dataset, 79 | batch_size=pair_cfg.Dataset.val.batch_size, 80 | sampler=sampler, 81 | shuffle=False, 82 | num_workers=pair_cfg.Dataset.val.num_workers, 83 | pin_memory=True) 84 | model.eval() 85 | video_results = {} 86 | interval = 15 87 | ttl = 5 88 | max_mem=5 89 | threshold=0.4 90 | with torch.no_grad(): 91 | for imgs, labels in tqdm(loader): 92 | cnt_list = [] 93 | video_name = labels["video_name"][0] 94 | img_names = labels["img_names"] 95 | w, h = labels["w"][0], labels["h"][0] 96 | 97 | img_name0 = img_names[0][0] 98 | pos_path0 = os.path.join( 99 | "locater/results", video_name, img_name0+".txt") 100 | print(pos_path0) 101 | pos0 = read_pts(pos_path0) 102 | z0 = model.forward_single_image( 103 | imgs[0, 0].cuda().unsqueeze(0), [pos0], True)[0] 104 | cnt_0 = len(pos0) 105 | cum_cnt = cnt_0 106 | cnt_list.append(cnt_0) 107 | selected_idx = [v for v in range( 108 | interval, len(img_names), interval)] 109 | pos_lists = [] 110 | inflow_lists = [] 111 | pos_lists.append(pos0) 112 | inflow_lists.append([1 for _ in range(len(pos0))]) 113 | memory_features = [[z0[i]] for i in range(len(pos0))] 114 | ttl_list=[ttl for _ in range(len(pos0))] 115 | # if selected_idx[-1] != len(img_names)-1: 116 | # selected_idx.append(len(img_names)-1) 117 | for i in selected_idx: 118 | img_name = img_names[i][0] 119 | pos_path = os.path.join( 120 | "locater/results", video_name, img_name+".txt") 121 | pos = read_pts(pos_path) 122 | z = model.forward_single_image( 123 | imgs[0, i].cuda().unsqueeze(0), [pos], True)[0] 124 | z = F.normalize(z, dim=-1) 125 | C = np.zeros((len(pos), len(memory_features))) 126 | for idx, pre_z in enumerate(memory_features): 127 | pre_z = torch.stack(pre_z[-1:], dim=0).unsqueeze(0) 128 | pre_z = F.normalize(pre_z, dim=-1) 129 | sim_cost = torch.bmm(pre_z, z.unsqueeze(0).transpose(1, 2)) 130 | # sim_cost=1-similarity_cost(pre_z,z.unsqueeze(0)) 131 | sim_cost = sim_cost.cpu().numpy()[0] 132 | sim_cost = np.min(sim_cost, axis=0) 133 | C[:, idx] = sim_cost 134 | 135 | row_ind, col_ind = linear_sum_assignment(-C) 136 | sim_score = C[row_ind, col_ind] 137 | shared_mask = sim_score > threshold 138 | ori_shared_idx_list = col_ind[shared_mask] 139 | new_shared_idx_list = row_ind[shared_mask] 140 | outflow_idx_list = [i for i in range(len(pos)) if i not in row_ind[shared_mask]] 141 | 142 | for ori_idx, new_idx in zip(ori_shared_idx_list, new_shared_idx_list): 143 | memory_features[ori_idx].append(z[new_idx]) 144 | ttl_list[ori_idx]=ttl 145 | for idx in outflow_idx_list: 146 | memory_features.append([z[idx]]) 147 | ttl_list.append(ttl) 148 | pos_lists.append(pos) 149 | inflow_list = [] 150 | for j in range(len(pos)): 151 | if j in outflow_idx_list: 152 | inflow_list.append(1) 153 | else: 154 | inflow_list.append(0) 155 | inflow_lists.append(inflow_list) 156 | cum_cnt += len(outflow_idx_list) 157 | 
cnt_list.append(len(outflow_idx_list)) 158 | ttl_list=[ttl_list[idx]-1 for idx in range(len(ttl_list))] 159 | for idx in range(len(ttl_list)-1,-1,-1): 160 | if ttl_list[idx]==0: 161 | del memory_features[idx] 162 | del ttl_list[idx] 163 | 164 | # conver numpy to list 165 | pos_lists = [pos_lists[i].tolist() for i in range(len(pos_lists))] 166 | 167 | video_results[video_name] = { 168 | "video_num": cum_cnt, 169 | "first_frame_num": cnt_0, 170 | "cnt_list": cnt_list, 171 | "frame_num": len(img_names), 172 | "pos_lists": pos_lists, 173 | "inflow_lists": inflow_lists, 174 | 175 | } 176 | print(video_name, video_results[video_name]) 177 | with open("video_results_test.json", "w") as f: 178 | json.dump(video_results, f, indent=4) 179 | 180 | 181 | if __name__ == "__main__": 182 | parser = argparse.ArgumentParser("DenseMap Head ") 183 | parser.add_argument("--pair_config", default="configs/crowd_sense.json") 184 | parser.add_argument( 185 | "--pair_ckpt", default="outputs/weights/best.pth") 186 | 187 | parser.add_argument("--local_rank", type=int) 188 | args = parser.parse_args() 189 | 190 | if os.path.exists(args.pair_config): 191 | with open(args.pair_config, "r") as f: 192 | pair_configs = json.load(f) 193 | pair_cfg = edict(pair_configs) 194 | 195 | strtime = time.strftime('%Y%m%d%H%M') + "_" + os.path.basename( 196 | args.pair_config)[:-5] 197 | 198 | output_path = os.path.join(pair_cfg.Misc.tensorboard_dir, strtime) 199 | 200 | pair_cfg.Misc.tensorboard_dir = output_path 201 | pair_ckpt = args.pair_ckpt 202 | main(pair_cfg, pair_ckpt) 203 | -------------------------------------------------------------------------------- /misc/saver_builder.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import os 4 | from .tools import is_main_process 5 | class Saver(): 6 | def __init__(self, args) -> None: 7 | self.save_dir=args.save_dir 8 | self.save_interval=args.save_interval 9 | if not os.path.exists(self.save_dir): 10 | os.makedirs(self.save_dir,exist_ok=True) 11 | self.save_best=args.save_best 12 | self.save_start_epoch = args.save_start_epoch 13 | self.min_value=0 14 | self.max_value=1e10 15 | self.reverse=args.reverse 16 | self.metric=args.metric 17 | 18 | def save(self, model, optimizer, scheduler, filename, epoch, stats={}): 19 | if is_main_process(): 20 | torch.save({ 21 | 'model': model.state_dict(), 22 | 'optimizer': optimizer.state_dict(), 23 | 'scheduler': scheduler.state_dict(), 24 | 'epoch': epoch, 25 | 'states':stats 26 | },os.path.join(self.save_dir, filename)) 27 | 28 | def save_inter(self, model, optimizer, scheduler, name, epoch, stats={}): 29 | if epoch % self.save_interval == 0 and self.save_start_epoch <= epoch: 30 | self.save(model, optimizer, scheduler, name, epoch, stats) 31 | 32 | def save_on_master(self, model, optimizer, scheduler, epoch, stats={}): 33 | if is_main_process(): 34 | self.save_inter(model, optimizer, scheduler, f"checkpoint{epoch:04}.pth", epoch, stats) 35 | 36 | if self.save_best and self.save_start_epoch <= epoch: 37 | if self.reverse and stats.test_stats[self.metric] > self.min_value: 38 | self.min_value=max(self.min_value,stats.test_stats[self.metric]) 39 | self.save(model, optimizer, scheduler, f"best.pth", epoch, stats) 40 | elif not self.reverse and stats.test_stats[self.metric] < self.max_value: 41 | self.max_value=min(self.max_value,stats.test_stats[self.metric]) 42 | self.save(model, optimizer, scheduler, f"best.pth", epoch, stats) 43 | 44 | def save_last(self, model, optimizer, scheduler, 
epoch, stats={}): 45 | if is_main_process(): 46 | self.save(model, optimizer, scheduler, f"checkpoint_last.pth", epoch,stats) -------------------------------------------------------------------------------- /misc/tools.py: -------------------------------------------------------------------------------- 1 | 2 | import datetime 3 | import os 4 | import pickle 5 | import random 6 | import subprocess 7 | import time 8 | from collections import defaultdict, deque 9 | from typing import List, Optional 10 | 11 | import numpy as np 12 | import torch 13 | import torch.distributed as dist 14 | import torch.nn as nn 15 | # needed due to empty tensor bug in pytorch and torchvision 0.5 16 | import torchvision 17 | from torch import Tensor 18 | 19 | def set_randomseed(seed): 20 | torch.manual_seed(seed) 21 | np.random.seed(seed) 22 | random.seed(seed) 23 | 24 | 25 | class SmoothedValue(object): 26 | """Track a series of values and provide access to smoothed values over a 27 | window or the global series average. 28 | """ 29 | 30 | def __init__(self, window_size=20, fmt=None): 31 | if fmt is None: 32 | fmt = "{median:.4f} ({global_avg:.4f})" 33 | self.deque = deque(maxlen=window_size) 34 | self.total = 0.0 35 | self.count = 0 36 | self.fmt = fmt 37 | 38 | def update(self, value, n=1): 39 | self.deque.append(value) 40 | self.count += n 41 | self.total += value * n 42 | 43 | def synchronize_between_processes(self): 44 | """ 45 | Warning: does not synchronize the deque! 46 | """ 47 | if not is_dist_avail_and_initialized(): 48 | return 49 | t = torch.tensor([self.count, self.total], 50 | dtype=torch.float64, device='cuda') 51 | dist.barrier() 52 | dist.all_reduce(t) 53 | t = t.tolist() 54 | try: 55 | self.count = int(t[0]) 56 | self.total = t[1] 57 | except: 58 | print(f"Warning: {self.count} {self.total} is not an integer") 59 | 60 | self.count = 1 61 | self.total = -1 62 | 63 | @property 64 | def median(self): 65 | d = torch.tensor(list(self.deque)) 66 | return d.median().item() 67 | 68 | @property 69 | def avg(self): 70 | d = torch.tensor(list(self.deque), dtype=torch.float32) 71 | return d.mean().item() 72 | 73 | @property 74 | def global_avg(self): 75 | return self.total / self.count 76 | @property 77 | def max(self): 78 | return max(self.deque) 79 | 80 | @property 81 | def value(self): 82 | return self.deque[-1] 83 | 84 | def __str__(self): 85 | return self.fmt.format( 86 | median=self.median, 87 | avg=self.avg, 88 | global_avg=self.global_avg, 89 | max=self.max, 90 | value=self.value) 91 | 92 | 93 | def all_gather(data): 94 | """ 95 | Run all_gather on arbitrary picklable data (not necessarily tensors) 96 | Args: 97 | data: any picklable object 98 | Returns: 99 | list[data]: list of data gathered from each rank 100 | """ 101 | world_size = get_world_size() 102 | if world_size == 1: 103 | return [data] 104 | 105 | # serialized to a Tensor 106 | buffer = pickle.dumps(data) 107 | storage = torch.ByteStorage.from_buffer(buffer) 108 | tensor = torch.ByteTensor(storage).to("cuda") 109 | 110 | # obtain Tensor size of each rank 111 | local_size = torch.tensor([tensor.numel()], device="cuda") 112 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 113 | dist.all_gather(size_list, local_size) 114 | size_list = [int(size.item()) for size in size_list] 115 | max_size = max(size_list) 116 | 117 | # receiving Tensor from all ranks 118 | # we pad the tensor because torch all_gather does not support 119 | # gathering tensors of different shapes 120 | tensor_list = [] 121 | for _ in size_list: 
122 | tensor_list.append(torch.empty( 123 | (max_size,), dtype=torch.uint8, device="cuda")) 124 | if local_size != max_size: 125 | padding = torch.empty(size=(max_size - local_size,), 126 | dtype=torch.uint8, device="cuda") 127 | tensor = torch.cat((tensor, padding), dim=0) 128 | dist.all_gather(tensor_list, tensor) 129 | 130 | data_list = [] 131 | for size, tensor in zip(size_list, tensor_list): 132 | buffer = tensor.cpu().numpy().tobytes()[:size] 133 | data_list.append(pickle.loads(buffer)) 134 | 135 | return data_list 136 | 137 | 138 | def reduce_dict(input_dict, average=True): 139 | """ 140 | Args: 141 | input_dict (dict): all the values will be reduced 142 | average (bool): whether to do average or sum 143 | Reduce the values in the dictionary from all processes so that all processes 144 | have the averaged results. Returns a dict with the same fields as 145 | input_dict, after reduction. 146 | """ 147 | world_size = get_world_size() 148 | if world_size < 2: 149 | return input_dict 150 | with torch.no_grad(): 151 | names = [] 152 | values = [] 153 | # sort the keys so that they are consistent across processes 154 | for k in sorted(input_dict.keys()): 155 | names.append(k) 156 | values.append(input_dict[k]) 157 | values = torch.stack(values, dim=0) 158 | dist.all_reduce(values) 159 | if average: 160 | values /= world_size 161 | reduced_dict = {k: v for k, v in zip(names, values)} 162 | return reduced_dict 163 | 164 | 165 | class MetricLogger(object): 166 | def __init__(self, args): 167 | self.meters = defaultdict(SmoothedValue) 168 | self.delimiter = args.delimiter 169 | self.print_freq = args.print_freq 170 | self.header = args.header 171 | 172 | def update(self, **kwargs): 173 | for k, v in kwargs.items(): 174 | if isinstance(v, torch.Tensor): 175 | v = v.item() 176 | assert isinstance(v, (float, int)) 177 | self.meters[k].update(v) 178 | 179 | def __getattr__(self, attr): 180 | if attr in self.meters: 181 | return self.meters[attr] 182 | if attr in self.__dict__: 183 | return self.__dict__[attr] 184 | raise AttributeError("'{}' object has no attribute '{}'".format( 185 | type(self).__name__, attr)) 186 | 187 | def __str__(self): 188 | loss_str = [] 189 | for name, meter in self.meters.items(): 190 | loss_str.append( 191 | "{}: {}".format(name, str(meter)) 192 | ) 193 | return self.delimiter.join(loss_str) 194 | 195 | def synchronize_between_processes(self): 196 | for meter in self.meters.values(): 197 | meter.synchronize_between_processes() 198 | 199 | def add_meter(self, name, meter): 200 | self.meters[name] = meter 201 | 202 | def set_header(self, header): 203 | self.header = header 204 | 205 | def log_every(self, iterable): 206 | i = 1 207 | 208 | start_time = time.time() 209 | end = time.time() 210 | iter_time = SmoothedValue(fmt='{avg:.4f}') 211 | data_time = SmoothedValue(fmt='{avg:.4f}') 212 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 213 | if torch.cuda.is_available(): 214 | log_msg = self.delimiter.join([ 215 | self.header, 216 | '[{0' + space_fmt + '}/{1}]', 217 | 'eta: {eta}', 218 | '{meters}', 219 | 'time: {time}', 220 | 'data: {data}', 221 | 'max mem: {memory:.0f}' 222 | ]) 223 | else: 224 | log_msg = self.delimiter.join([ 225 | self.header, 226 | '[{0' + space_fmt + '}/{1}]', 227 | 'eta: {eta}', 228 | '{meters}', 229 | 'time: {time}', 230 | 'data: {data}' 231 | ]) 232 | MB = 1024.0 * 1024.0 233 | for obj in iterable: 234 | data_time.update(time.time() - end) 235 | yield obj 236 | iter_time.update(time.time() - end) 237 | if i % self.print_freq == 0 or i == 
len(iterable) - 1: 238 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 239 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 240 | if torch.cuda.is_available(): 241 | print(log_msg.format( 242 | i, len(iterable), eta=eta_string, 243 | meters=str(self), 244 | time=str(iter_time), data=str(data_time), 245 | memory=torch.cuda.max_memory_allocated() / MB)) 246 | else: 247 | print(log_msg.format( 248 | i, len(iterable), eta=eta_string, 249 | meters=str(self), 250 | time=str(iter_time), data=str(data_time))) 251 | i += 1 252 | end = time.time() 253 | total_time = time.time() - start_time 254 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 255 | print('{} Total time: {} ({:.4f} s / it)'.format( 256 | self.header, total_time_str, total_time / len(iterable))) 257 | 258 | 259 | def get_sha(): 260 | cwd = os.path.dirname(os.path.abspath(__file__)) 261 | 262 | def _run(command): 263 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 264 | sha = 'N/A' 265 | diff = "clean" 266 | branch = 'N/A' 267 | try: 268 | sha = _run(['git', 'rev-parse', 'HEAD']) 269 | subprocess.check_output(['git', 'diff'], cwd=cwd) 270 | diff = _run(['git', 'diff-index', 'HEAD']) 271 | diff = "has uncommited changes" if diff else "clean" 272 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 273 | except Exception: 274 | pass 275 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 276 | return message 277 | 278 | 279 | def collate_fn(batch): 280 | batch = list(zip(*batch)) 281 | batch[0] = nested_tensor_from_tensor_list(batch[0]) 282 | return tuple(batch) 283 | 284 | 285 | def _max_by_axis(the_list): 286 | # type: (List[List[int]]) -> List[int] 287 | maxes = the_list[0] 288 | for sublist in the_list[1:]: 289 | for index, item in enumerate(sublist): 290 | maxes[index] = max(maxes[index], item) 291 | return maxes 292 | 293 | 294 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 295 | # TODO make this more general 296 | if tensor_list[0].ndim == 3: 297 | # TODO make it support different-sized images 298 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 299 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 300 | batch_shape = [len(tensor_list)] + max_size 301 | b, c, h, w = batch_shape 302 | dtype = tensor_list[0].dtype 303 | device = tensor_list[0].device 304 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 305 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 306 | for img, pad_img, m in zip(tensor_list, tensor, mask): 307 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 308 | m[: img.shape[1], :img.shape[2]] = False 309 | else: 310 | raise ValueError('not supported') 311 | return NestedTensor(tensor, mask) 312 | 313 | 314 | class NestedTensor(object): 315 | def __init__(self, tensors, mask: Optional[Tensor]): 316 | self.tensors = tensors 317 | self.mask = mask 318 | 319 | def to(self, device, non_blocking=False): 320 | ## type: (Device) -> NestedTensor # noqa 321 | cast_tensor = self.tensors.to(device, non_blocking=non_blocking) 322 | mask = self.mask 323 | if mask is not None: 324 | assert mask is not None 325 | cast_mask = mask.to(device, non_blocking=non_blocking) 326 | else: 327 | cast_mask = None 328 | return NestedTensor(cast_tensor, cast_mask) 329 | 330 | def record_stream(self, *args, **kwargs): 331 | self.tensors.record_stream(*args, **kwargs) 332 | if self.mask is not None: 333 | self.mask.record_stream(*args, **kwargs) 334 
| 335 | def decompose(self): 336 | return self.tensors, self.mask 337 | 338 | def __repr__(self): 339 | return str(self.tensors) 340 | 341 | 342 | def setup_for_distributed(is_master): 343 | """ 344 | This function disables printing when not in master process 345 | """ 346 | import builtins as __builtin__ 347 | builtin_print = __builtin__.print 348 | 349 | def print(*args, **kwargs): 350 | force = kwargs.pop('force', False) 351 | if is_master or force: 352 | builtin_print(*args, **kwargs) 353 | 354 | __builtin__.print = print 355 | 356 | 357 | def is_dist_avail_and_initialized(): 358 | if not dist.is_available(): 359 | return False 360 | if not dist.is_initialized(): 361 | return False 362 | return True 363 | 364 | 365 | def get_world_size(): 366 | if not is_dist_avail_and_initialized(): 367 | return 1 368 | return dist.get_world_size() 369 | 370 | 371 | def get_rank(): 372 | if not is_dist_avail_and_initialized(): 373 | return 0 374 | return dist.get_rank() 375 | 376 | 377 | def get_local_size(): 378 | if not is_dist_avail_and_initialized(): 379 | return 1 380 | return int(os.environ['LOCAL_SIZE']) 381 | 382 | 383 | def get_local_rank(): 384 | if not is_dist_avail_and_initialized(): 385 | return 0 386 | return int(os.environ['LOCAL_RANK']) 387 | 388 | 389 | def is_main_process(): 390 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 391 | return int(os.environ['RANK']) == 0 392 | elif "SLURM_PROCID" in os.environ: 393 | return int(os.environ["SLURM_PROCID"]) == 0 394 | else: 395 | return True 396 | 397 | 398 | def save_on_master(*args, **kwargs): 399 | if is_main_process(): 400 | torch.save(*args, **kwargs) 401 | 402 | 403 | def init_distributedrun_mode(args): 404 | if "LOCAL_RANK" in os.environ: 405 | args.local_rank = int(os.environ["LOCAL_RANK"]) 406 | args.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) 407 | 408 | 409 | def init_distributed_mode(args): 410 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 411 | args.rank = int(os.environ["RANK"]) 412 | args.world_size = int(os.environ['WORLD_SIZE']) 413 | args.gpu = int(os.environ['LOCAL_RANK']) 414 | args.dist_url = 'env://' 415 | os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count()) 416 | elif 'SLURM_PROCID' in os.environ: 417 | proc_id = int(os.environ['SLURM_PROCID']) 418 | ntasks = int(os.environ['SLURM_NTASKS']) 419 | node_list = os.environ['SLURM_NODELIST'] 420 | num_gpus = torch.cuda.device_count() 421 | addr = subprocess.getoutput( 422 | 'scontrol show hostname {} | head -n1'.format(node_list)) 423 | os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500') 424 | os.environ['MASTER_ADDR'] = addr 425 | os.environ['WORLD_SIZE'] = str(ntasks) 426 | os.environ['RANK'] = str(proc_id) 427 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 428 | os.environ['LOCAL_SIZE'] = str(num_gpus) 429 | args.dist_url = 'env://' 430 | args.world_size = ntasks 431 | args.rank = proc_id 432 | args.gpu = proc_id % num_gpus 433 | else: 434 | print('Not using distributed mode') 435 | args.distributed = False 436 | args.gpu=0 437 | return 438 | 439 | args.distributed = True 440 | 441 | torch.cuda.set_device(args.gpu) 442 | args.dist_backend = 'nccl' 443 | print('| distributed init (rank {}): {}'.format( 444 | args.rank, args.dist_url), flush=True) 445 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 446 | world_size=args.world_size, rank=args.rank) 447 | torch.distributed.barrier() 448 | setup_for_distributed(args.rank == 0) 449 | 450 | 451 | @torch.no_grad() 452 | def 
accuracy(output, target, topk=(1,)): 453 | """Computes the precision@k for the specified values of k""" 454 | if target.numel() == 0: 455 | return [torch.zeros([], device=output.device)] 456 | maxk = max(topk) 457 | batch_size = target.size(0) 458 | 459 | _, pred = output.topk(maxk, 1, True, True) 460 | pred = pred.t() 461 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 462 | 463 | res = [] 464 | for k in topk: 465 | correct_k = correct[:k].view(-1).float().sum(0) 466 | res.append(correct_k.mul_(100.0 / batch_size)) 467 | return res 468 | 469 | 470 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): 471 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor 472 | """ 473 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes. 474 | This will eventually be supported natively by PyTorch, and this 475 | class can go away. 476 | """ 477 | # if float(torchvision.__version__[:3]) < 0.7: 478 | # if input.numel() > 0: 479 | # return torch.nn.functional.interpolate( 480 | # input, size, scale_factor, mode, align_corners 481 | # ) 482 | 483 | # output_shape = _output_size(2, input, size, scale_factor) 484 | # output_shape = list(input.shape[:-2]) + list(output_shape) 485 | # if float(torchvision.__version__[:3]) < 0.5: 486 | # return _NewEmptyTensorOp.apply(input, output_shape) 487 | # return _new_empty_tensor(input, output_shape) 488 | # else: 489 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) 490 | 491 | 492 | def get_total_grad_norm(parameters, norm_type=2): 493 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 494 | norm_type = float(norm_type) 495 | device = parameters[0].grad.device 496 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), 497 | norm_type) 498 | return total_norm 499 | 500 | 501 | def inverse_sigmoid(x, eps=1e-5): 502 | x = x.clamp(min=0, max=1) 503 | x1 = x.clamp(min=eps) 504 | x2 = (1 - x).clamp(min=eps) 505 | return torch.log(x1/x2) 506 | -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torchvision.models._utils import IntermediateLayerGetter 5 | from typing import Dict, List 6 | from timm import create_model 7 | from timm.models import features 8 | import torch 9 | 10 | class Backbone(nn.Module): 11 | def __init__(self, name: str,pretrained:bool,out_indices:List[int], train_backbone: bool): 12 | super(Backbone,self).__init__() 13 | backbone=create_model(name,pretrained=pretrained,features_only=True, out_indices=out_indices) 14 | self.train_backbone = train_backbone 15 | self.backbone=backbone 16 | self.out_indices=out_indices 17 | if not self.train_backbone: 18 | for name, parameter in self.backbone.named_parameters(): 19 | parameter.requires_grad_(False) 20 | def forward(self,x): 21 | x=self.backbone(x) 22 | for i in range(len(x)): 23 | x[i]=F.relu(x[i]) 24 | 25 | return x 26 | 27 | @property 28 | def feature_info(self): 29 | return features._get_feature_info(self.backbone,out_indices=self.out_indices) 30 | 31 | 32 | def build_backbone(): 33 | backbone = Backbone("convnext_small_384_in22ft1k", True, [3], True) 34 | # backbone = Backbone("convnext_small_384_in22ft1k", True, [2], True) 35 | 36 | # backbone = 
Backbone("resnet50.a3_in1k", True, [2], True) 37 | 38 | return backbone 39 | 40 | if __name__=="__main__": 41 | model=build_backbone() 42 | x=torch.randn(1,3,224,224) 43 | z=model(x) 44 | for f in z: 45 | print(f.shape) 46 | -------------------------------------------------------------------------------- /models/cropper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from .backbone import build_backbone 5 | from .tri_sim_ot_b import GML 6 | import math 7 | import copy 8 | from typing import Optional, List, Dict, Tuple, Set, Union, Iterable, Any 9 | from torch import Tensor 10 | 11 | 12 | class Model(nn.Module): 13 | def __init__(self, stride=16, num_feature_levels=3, num_channels=[96, 192, 384], hidden_dim=256, freeze_backbone=False) -> None: 14 | super().__init__() 15 | self.backbone = build_backbone() 16 | if freeze_backbone: 17 | for name, parameter in self.backbone.named_parameters(): 18 | parameter.requires_grad_(False) 19 | self.stride = stride 20 | self.num_feature_levels = num_feature_levels 21 | self.num_channels = num_channels 22 | input_proj_list = [] 23 | 24 | self.ot_loss = GML() 25 | 26 | def forward(self, input): 27 | x = input["image_pair"] 28 | ref_points = input["ref_pts"][:, :, :input["ref_num"], :] 29 | x1 = x[:, 0:3, :, :] 30 | x2 = x[:, 3:6, :, :] 31 | ref_point1 = ref_points[:, 0, ...] 32 | ref_point2 = ref_points[:, 1, ...] 33 | 34 | z1_lists=[] 35 | z2_lists=[] 36 | for b in range(x1.shape[0]): 37 | z1_list=[] 38 | z2_list=[] 39 | for pt1,pt2 in zip(ref_point1[b],ref_point2[b]): 40 | z1=self.get_crops(x1[b].unsqueeze(0),pt1) 41 | z2=self.get_crops(x2[b].unsqueeze(0),pt2) 42 | z1=F.interpolate(z1,(224,224)) 43 | z2=F.interpolate(z2,(224,224)) 44 | z1_list.append(z1) 45 | z2_list.append(z2) 46 | z1=torch.cat(z1_list,dim=0) 47 | z2=torch.cat(z2_list,dim=0) 48 | z1=self.backbone(z1)[0].flatten(2).flatten(1) 49 | z2=self.backbone(z2)[0].flatten(2).flatten(1) 50 | z1_lists.append(z1) 51 | z2_lists.append(z2) 52 | z1=torch.stack(z1_lists,dim=0) 53 | z2=torch.stack(z2_lists,dim=0) 54 | return z1,z2 55 | 56 | def get_crops(self, z, pt, window_size=64): 57 | h, w = z.shape[-2], z.shape[-1] 58 | x_min = pt[0]*w-window_size//2 59 | x_max = pt[0]*w+window_size//2 60 | y_min = pt[1]*h-window_size//2 61 | y_max = pt[1]*h+window_size//2 62 | x_min, x_max, y_min, y_max = int(x_min), int( 63 | x_max), int(y_min), int(y_max) 64 | x_min = max(0, x_min) 65 | x_max = min(w, x_max) 66 | y_min = max(0, y_min) 67 | y_max = min(h, y_max) 68 | z = z[:, :, y_min:y_max, x_min:x_max] 69 | # pos_emb = self.pos2posemb2d(pt, num_pos_feats=z.shape[1]//2).unsqueeze(0).unsqueeze(2).unsqueeze(3) 70 | # z = z + pos_emb 71 | # z=F.adaptive_avg_pool2d(z,(1,1)) 72 | # z=z.squeeze(3).squeeze(2) 73 | return z 74 | 75 | def loss(self, features1, features2): 76 | loss = self.ot_loss(features1, features2) 77 | loss_dict = {} 78 | loss_dict["all"] = loss 79 | return loss_dict 80 | 81 | 82 | def build_model(): 83 | model = Model() 84 | return model 85 | -------------------------------------------------------------------------------- /models/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from .backbone import build_backbone 5 | from .sim_ot import ot_similarity 6 | import math 7 | import copy 8 | from typing import Optional, List, Dict, Tuple, Set, Union, 
Iterable, Any 9 | from torch import Tensor 10 | 11 | def pos2posemb2d(pos, num_pos_feats=128, temperature=1000): 12 | scale = 2 * math.pi 13 | pos = pos * scale 14 | dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device) 15 | dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) 16 | pos_x = pos[..., 0, None] / dim_t 17 | pos_y = pos[..., 1, None] / dim_t 18 | pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2) 19 | pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2) 20 | posemb = torch.cat((pos_y, pos_x), dim=-1) 21 | return posemb 22 | 23 | class PositionEmbeddingSine(nn.Module): 24 | """ 25 | This is a more standard version of the position embedding, very similar to the one 26 | used by the Attention is all you need paper, generalized to work on images. 27 | """ 28 | def __init__(self, num_pos_feats=128, temperature=1000, normalize=False, scale=None): 29 | super().__init__() 30 | self.num_pos_feats = num_pos_feats 31 | self.temperature = temperature 32 | self.normalize = normalize 33 | if scale is not None and normalize is False: 34 | raise ValueError("normalize should be True if scale is passed") 35 | if scale is None: 36 | scale = 2 * math.pi 37 | self.scale = scale 38 | 39 | def forward(self, x): 40 | mask = torch.ones(x.shape[0],x.shape[1],x.shape[2]).cuda().bool() 41 | assert mask is not None 42 | not_mask = ~mask 43 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 44 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 45 | if self.normalize: 46 | eps = 1e-6 47 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 48 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 49 | 50 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 51 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 52 | 53 | pos_x = x_embed[:, :, :, None] / dim_t 54 | pos_y = y_embed[:, :, :, None] / dim_t 55 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 56 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 57 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 58 | return pos 59 | 60 | def _get_activation_fn(activation): 61 | """Return an activation function given a string""" 62 | if activation == "relu": 63 | return F.relu 64 | if activation == "gelu": 65 | return F.gelu 66 | if activation == "glu": 67 | return F.glu 68 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 69 | 70 | def _get_clones(module, N): 71 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 72 | 73 | 74 | class FFN(nn.Module): 75 | 76 | def __init__(self, d_model=256, d_ffn=1024, dropout=0.): 77 | super().__init__() 78 | self.linear1 = nn.Linear(d_model, d_ffn) 79 | self.activation = nn.ReLU() 80 | self.dropout2 = nn.Dropout(dropout) 81 | self.linear2 = nn.Linear(d_ffn, d_model) 82 | self.dropout3 = nn.Dropout(dropout) 83 | self.norm2 = nn.LayerNorm(d_model) 84 | 85 | def forward(self, src): 86 | src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) 87 | src = src + self.dropout3(src2) 88 | src = self.norm2(src) 89 | return src 90 | 91 | class Transformer(nn.Module): 92 | 93 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 94 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 95 | activation="relu", normalize_before=False, 96 | return_intermediate_dec=False): 97 | super().__init__() 98 | 99 | 
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 100 | dropout, activation, normalize_before) 101 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 102 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 103 | 104 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 105 | dropout, activation, normalize_before) 106 | decoder_norm = nn.LayerNorm(d_model) 107 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 108 | return_intermediate=return_intermediate_dec) 109 | 110 | self._reset_parameters() 111 | 112 | self.d_model = d_model 113 | self.nhead = nhead 114 | 115 | def _reset_parameters(self): 116 | for p in self.parameters(): 117 | if p.dim() > 1: 118 | nn.init.xavier_uniform_(p) 119 | 120 | def forward(self, src, mask, query_embed, pos_embed): 121 | # flatten NxCxHxW to HWxNxC 122 | bs, c, h, w = src.shape 123 | src = src.flatten(2).permute(2, 0, 1) 124 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 125 | mask = mask.flatten(1) 126 | print(src.dtype) 127 | tgt = torch.zeros_like(query_embed) 128 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 129 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, 130 | pos=pos_embed, query_pos=query_embed) 131 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) 132 | 133 | 134 | class TransformerEncoder(nn.Module): 135 | 136 | def __init__(self, encoder_layer, num_layers, norm=None): 137 | super().__init__() 138 | self.layers = _get_clones(encoder_layer, num_layers) 139 | self.num_layers = num_layers 140 | self.norm = norm 141 | 142 | def forward(self, src, 143 | mask: Optional[Tensor] = None, 144 | src_key_padding_mask: Optional[Tensor] = None, 145 | pos: Optional[Tensor] = None): 146 | output = src 147 | 148 | for layer in self.layers: 149 | output = layer(output, src_mask=mask, 150 | src_key_padding_mask=src_key_padding_mask, pos=pos) 151 | 152 | if self.norm is not None: 153 | output = self.norm(output) 154 | 155 | return output 156 | 157 | 158 | class TransformerDecoder(nn.Module): 159 | 160 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 161 | super().__init__() 162 | self.layers = _get_clones(decoder_layer, num_layers) 163 | self.num_layers = num_layers 164 | self.norm = norm 165 | self.return_intermediate = return_intermediate 166 | 167 | def forward(self, tgt, memory, 168 | tgt_mask: Optional[Tensor] = None, 169 | memory_mask: Optional[Tensor] = None, 170 | tgt_key_padding_mask: Optional[Tensor] = None, 171 | memory_key_padding_mask: Optional[Tensor] = None, 172 | pos: Optional[Tensor] = None, 173 | query_pos: Optional[Tensor] = None): 174 | output = tgt 175 | 176 | intermediate = [] 177 | 178 | for layer in self.layers: 179 | output = layer(output, memory, tgt_mask=tgt_mask, 180 | memory_mask=memory_mask, 181 | tgt_key_padding_mask=tgt_key_padding_mask, 182 | memory_key_padding_mask=memory_key_padding_mask, 183 | pos=pos, query_pos=query_pos) 184 | if self.return_intermediate: 185 | intermediate.append(self.norm(output)) 186 | 187 | if self.norm is not None: 188 | output = self.norm(output) 189 | if self.return_intermediate: 190 | intermediate.pop() 191 | intermediate.append(output) 192 | 193 | if self.return_intermediate: 194 | return torch.stack(intermediate) 195 | 196 | return output.unsqueeze(0) 197 | 198 | 199 | class TransformerEncoderLayer(nn.Module): 200 | 201 | def __init__(self, d_model, nhead, 
dim_feedforward=2048, dropout=0.1, 202 | activation="relu", normalize_before=False): 203 | super().__init__() 204 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 205 | # Implementation of Feedforward model 206 | self.linear1 = nn.Linear(d_model, dim_feedforward) 207 | self.dropout = nn.Dropout(dropout) 208 | self.linear2 = nn.Linear(dim_feedforward, d_model) 209 | 210 | self.norm1 = nn.LayerNorm(d_model) 211 | self.norm2 = nn.LayerNorm(d_model) 212 | self.dropout1 = nn.Dropout(dropout) 213 | self.dropout2 = nn.Dropout(dropout) 214 | 215 | self.activation = _get_activation_fn(activation) 216 | self.normalize_before = normalize_before 217 | 218 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 219 | return tensor if pos is None else tensor + pos 220 | 221 | def forward_post(self, 222 | src, 223 | src_mask: Optional[Tensor] = None, 224 | src_key_padding_mask: Optional[Tensor] = None, 225 | pos: Optional[Tensor] = None): 226 | print(src.shape,pos.shape) 227 | q = k = self.with_pos_embed(src, pos) 228 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 229 | key_padding_mask=src_key_padding_mask)[0] 230 | src = src + self.dropout1(src2) 231 | src = self.norm1(src) 232 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 233 | src = src + self.dropout2(src2) 234 | src = self.norm2(src) 235 | return src 236 | 237 | def forward_pre(self, src, 238 | src_mask: Optional[Tensor] = None, 239 | src_key_padding_mask: Optional[Tensor] = None, 240 | pos: Optional[Tensor] = None): 241 | src2 = self.norm1(src) 242 | q = k = self.with_pos_embed(src2, pos) 243 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 244 | key_padding_mask=src_key_padding_mask)[0] 245 | src = src + self.dropout1(src2) 246 | src2 = self.norm2(src) 247 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 248 | src = src + self.dropout2(src2) 249 | return src 250 | 251 | def forward(self, src, 252 | src_mask: Optional[Tensor] = None, 253 | src_key_padding_mask: Optional[Tensor] = None, 254 | pos: Optional[Tensor] = None): 255 | if self.normalize_before: 256 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 257 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 258 | 259 | 260 | class TransformerDecoderLayer(nn.Module): 261 | 262 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 263 | activation="relu", normalize_before=False): 264 | super().__init__() 265 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 266 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 267 | # Implementation of Feedforward model 268 | self.linear1 = nn.Linear(d_model, dim_feedforward) 269 | self.dropout = nn.Dropout(dropout) 270 | self.linear2 = nn.Linear(dim_feedforward, d_model) 271 | 272 | self.norm1 = nn.LayerNorm(d_model) 273 | self.norm2 = nn.LayerNorm(d_model) 274 | self.norm3 = nn.LayerNorm(d_model) 275 | self.dropout1 = nn.Dropout(dropout) 276 | self.dropout2 = nn.Dropout(dropout) 277 | self.dropout3 = nn.Dropout(dropout) 278 | 279 | self.activation = _get_activation_fn(activation) 280 | self.normalize_before = normalize_before 281 | 282 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 283 | return tensor if pos is None else tensor + pos 284 | 285 | def forward_post(self, tgt, memory, 286 | tgt_mask: Optional[Tensor] = None, 287 | memory_mask: Optional[Tensor] = None, 288 | tgt_key_padding_mask: Optional[Tensor] = None, 289 | 
memory_key_padding_mask: Optional[Tensor] = None, 290 | pos: Optional[Tensor] = None, 291 | query_pos: Optional[Tensor] = None): 292 | q = k = self.with_pos_embed(tgt, query_pos) 293 | print(q.dtype,k.dtype,tgt.dtype) 294 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 295 | key_padding_mask=tgt_key_padding_mask)[0] 296 | tgt = tgt + self.dropout1(tgt2) 297 | tgt = self.norm1(tgt) 298 | print(tgt.shape,query_pos.shape,memory.shape,pos.shape) 299 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 300 | key=self.with_pos_embed(memory, pos), 301 | value=memory, attn_mask=memory_mask, 302 | key_padding_mask=memory_key_padding_mask)[0] 303 | tgt = tgt + self.dropout2(tgt2) 304 | tgt = self.norm2(tgt) 305 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 306 | tgt = tgt + self.dropout3(tgt2) 307 | tgt = self.norm3(tgt) 308 | return tgt 309 | 310 | def forward_pre(self, tgt, memory, 311 | tgt_mask: Optional[Tensor] = None, 312 | memory_mask: Optional[Tensor] = None, 313 | tgt_key_padding_mask: Optional[Tensor] = None, 314 | memory_key_padding_mask: Optional[Tensor] = None, 315 | pos: Optional[Tensor] = None, 316 | query_pos: Optional[Tensor] = None): 317 | tgt2 = self.norm1(tgt) 318 | q = k = self.with_pos_embed(tgt2, query_pos) 319 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 320 | key_padding_mask=tgt_key_padding_mask)[0] 321 | tgt = tgt + self.dropout1(tgt2) 322 | tgt2 = self.norm2(tgt) 323 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 324 | key=self.with_pos_embed(memory, pos), 325 | value=memory, attn_mask=memory_mask, 326 | key_padding_mask=memory_key_padding_mask)[0] 327 | tgt = tgt + self.dropout2(tgt2) 328 | tgt2 = self.norm3(tgt) 329 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 330 | tgt = tgt + self.dropout3(tgt2) 331 | return tgt 332 | 333 | def forward(self, tgt, memory, 334 | tgt_mask: Optional[Tensor] = None, 335 | memory_mask: Optional[Tensor] = None, 336 | tgt_key_padding_mask: Optional[Tensor] = None, 337 | memory_key_padding_mask: Optional[Tensor] = None, 338 | pos: Optional[Tensor] = None, 339 | query_pos: Optional[Tensor] = None): 340 | if self.normalize_before: 341 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 342 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 343 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 344 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 345 | 346 | 347 | def _get_clones(module, N): 348 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 349 | class Model(nn.Module): 350 | def __init__(self,height=720,width=1280,stride=16,num_feature_levels=3,num_channels=[96,192,384],hidden_dim=128) -> None: 351 | super().__init__() 352 | self.backbone = build_backbone() 353 | self.transformer=Transformer(d_model=hidden_dim,nhead=8,num_encoder_layers=6,num_decoder_layers=6,dim_feedforward=2048,dropout=0.1,normalize_before=False,return_intermediate_dec=False) 354 | self.ot_loss = ot_similarity() 355 | 356 | input_proj_list=[] 357 | for i in range(num_feature_levels): 358 | input_proj_list.append(nn.Sequential( 359 | nn.Conv2d(num_channels[i],hidden_dim,kernel_size=1,stride=1,padding=0), 360 | nn.GroupNorm(32,hidden_dim), 361 | nn.GELU(), 362 | )) 363 | for j in range(num_feature_levels-1-i): 364 | input_proj_list[i]+=nn.Sequential( 365 | nn.Conv2d(hidden_dim,hidden_dim,kernel_size=3,stride=2,padding=1), 366 | nn.GroupNorm(32,hidden_dim), 367 | nn.GELU(), 368 | ) 
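        # Each entry of input_proj_list maps feature level i to hidden_dim with a
        # 1x1 conv + GroupNorm + GELU, then appends (num_feature_levels-1-i)
        # stride-2 3x3 conv blocks so every level is downsampled to the resolution
        # of the coarsest level; the per-level maps are concatenated along the
        # channel dimension and mixed by input_fuse below.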
369 | self.input_fuse=nn.Sequential( 370 | nn.Conv2d(hidden_dim*num_feature_levels,hidden_dim,kernel_size=1,stride=1,padding=0), 371 | nn.GroupNorm(32,hidden_dim), 372 | nn.GELU(), 373 | ) 374 | self.input_proj=nn.ModuleList(input_proj_list) 375 | for proj in self.input_proj: 376 | nn.init.xavier_uniform_(proj[0].weight, gain=nn.init.calculate_gain('relu')) 377 | nn.init.constant_(proj[0].bias, 0) 378 | self.hidden_dim=hidden_dim 379 | self.pos_embedder=PositionEmbeddingSine(hidden_dim//2,normalize=True) 380 | def forward(self, x, ref_points): 381 | x1=x[:,0:3,:,:] 382 | x2=x[:,3:6,:,:] 383 | ref_point1=ref_points[:,0,...] 384 | ref_point2=ref_points[:,1,...] 385 | 386 | z1 = self.backbone(x1) 387 | z2 = self.backbone(x2) 388 | z1_list=[] 389 | z2_list=[] 390 | for i in range(len(z1)): 391 | z1_list.append(self.input_proj[i](z1[i])) 392 | z2_list.append(self.input_proj[i](z2[i])) 393 | z1=torch.cat(z1_list,dim=1) 394 | z1=self.input_fuse(z1) 395 | z2=torch.cat(z2_list,dim=1) 396 | z2=self.input_fuse(z2) 397 | query_embed1 = pos2posemb2d(ref_point1, num_pos_feats=self.hidden_dim//2).float() 398 | query_embed2 = pos2posemb2d(ref_point2, num_pos_feats=self.hidden_dim//2).float() 399 | pos_emd_1=self.pos_embedder(z1_list[0].permute(0,2,3,1)) 400 | pos_emd_2=self.pos_embedder(z2_list[0].permute(0,2,3,1)) 401 | mask1=mask2=torch.ones((z1.shape[0],z1.shape[2],z1.shape[3])).cuda().bool() 402 | print(z1.shape,pos_emd_1.shape,query_embed1.shape) 403 | features1=self.transformer(z1,mask1,query_embed1,pos_emd_1) 404 | features2=self.transformer(z2,mask2,query_embed2,pos_emd_2) 405 | features1=F.relu(features1) 406 | features2=F.relu(features2) 407 | return features1,features2 408 | def loss(self,features1,features2): 409 | loss=self.ot_loss(features1,features2) 410 | return loss 411 | def build_model(): 412 | model=Model() 413 | return model 414 | 415 | if __name__=="__main__": 416 | transformer=Transformer(64,64).cuda().half() 417 | x=torch.rand(1,256,64,64).cuda().half() 418 | ref_pts=torch.rand(1,300,2).cuda().half() 419 | y=transformer(x,ref_pts) 420 | print(y.shape) -------------------------------------------------------------------------------- /models/fuse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class MS_CAM(nn.Module): 6 | ''' 7 | 单特征 进行通道加权,作用类似SE模块 8 | ''' 9 | 10 | def __init__(self, channels=64, r=4): 11 | super(MS_CAM, self).__init__() 12 | inter_channels = int(channels // r) 13 | 14 | self.local_att = nn.Sequential( 15 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 16 | nn.BatchNorm2d(inter_channels), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 19 | nn.BatchNorm2d(channels), 20 | ) 21 | 22 | self.global_att = nn.Sequential( 23 | nn.AdaptiveAvgPool2d(1), 24 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 25 | nn.BatchNorm2d(inter_channels), 26 | nn.ReLU(inplace=True), 27 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 28 | nn.BatchNorm2d(channels), 29 | ) 30 | 31 | self.sigmoid = nn.Sigmoid() 32 | 33 | def forward(self, x): 34 | xl = self.local_att(x) 35 | xg = self.global_att(x) 36 | xlg = xl + xg 37 | wei = self.sigmoid(xlg) 38 | return x * wei 39 | 40 | class FFN(nn.Module): 41 | def __init__(self, in_chans, mid_chans, out_chans): 42 | super().__init__() 43 | self.conv1 = nn.Conv2d(in_chans, mid_chans, 1) 44 | 
self.act1 = nn.GELU() 45 | self.conv2 = nn.Conv2d(mid_chans, out_chans, 1) 46 | self.act2 = nn.GELU() 47 | 48 | def forward(self, x): 49 | x = self.act1(self.conv1(x)) 50 | x = self.act2(self.conv2(x)) 51 | return x 52 | 53 | 54 | class Head(nn.Module): 55 | def __init__(self): 56 | super(Head, self).__init__() 57 | self.conv1 =nn.Sequential( 58 | nn.Conv2d(192,96,1), 59 | nn.BatchNorm2d(96), 60 | nn.GELU(), 61 | ) 62 | self.conv2 = nn.Sequential( 63 | nn.Conv2d(384, 96, 1), 64 | nn.BatchNorm2d(96), 65 | nn.GELU(), 66 | ) 67 | self.conv3 = nn.Sequential( 68 | nn.Conv2d(768, 96, 1), 69 | nn.BatchNorm2d(96), 70 | nn.GELU(), 71 | ) 72 | self.ms_cam=MS_CAM(96+96+96+2) 73 | self.out1 = nn.Sequential( 74 | nn.Conv2d(96+96+96+2, 256, 1), 75 | nn.BatchNorm2d(256), 76 | nn.GELU(), 77 | nn.Conv2d(256, 128, 3,dilation=2,padding=2), 78 | nn.GELU(), 79 | nn.Conv2d(128, 64, 3,dilation=2,padding=2), 80 | nn.GELU(), 81 | nn.Conv2d(64, 1, 1), 82 | nn.GELU(), 83 | ) 84 | self._initialize_weights() 85 | 86 | def _initialize_weights(self): 87 | for m in self.modules(): 88 | if isinstance(m, nn.Conv2d): 89 | nn.init.normal_(m.weight, std=0.01) 90 | if m.bias is not None: 91 | nn.init.constant_(m.bias, 0) 92 | elif isinstance(m, nn.BatchNorm2d): 93 | nn.init.constant_(m.weight, 1) 94 | nn.init.constant_(m.bias, 0) 95 | 96 | def forward(self, z,c1,c2): 97 | # print(x.shape) 98 | # torch.Size([1, 96, 256, 256]) 99 | # torch.Size([1, 192, 128, 128]) 100 | # torch.Size([1, 384, 64, 64]) 101 | # torch.Size([1, 768, 32, 32]) 102 | z1,z2=z 103 | x1, x2 , x3 = z1 104 | y1, y2 , y3 = z2 105 | x1 = self.conv1(torch.cat([x1,y1],dim=1)) 106 | x2 = self.conv2(torch.cat([x2,y2],dim=1)) 107 | x3 = self.conv3(torch.cat([x3,y3],dim=1)) 108 | # x3 = self.conv3(x3) 109 | # x4 = self.conv4(x4) 110 | 111 | x3= F.interpolate(x3, scale_factor=4) 112 | x2 = F.interpolate(x2, scale_factor=2) 113 | 114 | z=torch.cat([x1,x2,x3,c1,c2],dim=1) 115 | z=self.ms_cam(z) 116 | 117 | out1 = self.out1(z) 118 | return out1 119 | 120 | 121 | def build_head(): 122 | return Head() 123 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | from mmcv.cnn import get_model_complexity_info 5 | from torch import nn 6 | from torch.amp import autocast 7 | from torch.nn import functional as F 8 | 9 | from .backbone import build_backbone 10 | from .head import build_head as build_locating_head 11 | from .fuse import build_head as build_fuse_head 12 | from .local_bl import build_loss, DrawDenseMap 13 | class Model(nn.Module): 14 | def __init__(self) -> None: 15 | super().__init__() 16 | self.backbone = build_backbone() 17 | self.locating_head = build_locating_head() 18 | self.fuse_head=build_fuse_head() 19 | self.pt2dmap = DrawDenseMap(1,0,2) 20 | # self.get_model_complexity(input_shape=(6, 1280, 720)) 21 | 22 | @autocast("cuda") 23 | def forward(self, x): 24 | x1=x[:,0:3,:,:] 25 | x2=x[:,3:6,:,:] 26 | z1 = self.backbone(x1) 27 | z2 = self.backbone(x2) 28 | 29 | default_counting_map = self.locating_head([z1,z2]) 30 | duplicate_counting_map = self.locating_head([z2,z1]) 31 | fuse_counting_map=self.fuse_head([z1,z2],default_counting_map,duplicate_counting_map) 32 | return { 33 | "default_counting_map": F.interpolate(default_counting_map,scale_factor=2), 34 | "duplicate_counting_map": F.interpolate(duplicate_counting_map,scale_factor=2), 35 | 
"fuse_counting_map":F.interpolate(fuse_counting_map,scale_factor=2) 36 | } 37 | 38 | def get_model_complexity(self, input_shape): 39 | flops, params = get_model_complexity_info(self, input_shape) 40 | return flops, params 41 | 42 | 43 | @torch.no_grad() 44 | def _map2points(self,predict_counting_map,kernel,threshold,loc_kernel_size,loc_padding): 45 | device=predict_counting_map.device 46 | max_m=torch.max(predict_counting_map) 47 | threshold=max(0.1,threshold*max_m) 48 | low_resolution_map=F.interpolate(F.relu(predict_counting_map),scale_factor=0.5) 49 | H,W=low_resolution_map.shape[-2],low_resolution_map.shape[-1] 50 | 51 | unfolded_map=F.unfold(low_resolution_map,kernel_size=loc_kernel_size,padding=loc_padding) 52 | unfolded_max_idx=unfolded_map.max(dim=1,keepdim=True)[1] 53 | unfolded_max_mask=(unfolded_max_idx==loc_kernel_size**2//2).reshape(1,1,H,W) 54 | 55 | predict_cnt=F.conv2d(low_resolution_map,kernel,padding=loc_padding) 56 | predict_filter=(predict_cnt>threshold).float() 57 | predict_filter=predict_filter*unfolded_max_mask 58 | predict_filter=predict_filter.detach().cpu().numpy().astype(bool).reshape(H,W) 59 | 60 | pred_coord_weight=F.normalize(unfolded_map,p=1,dim=1) 61 | 62 | coord_h=torch.arange(H).reshape(-1,1).repeat(1,W).to(device).float() 63 | coord_w=torch.arange(W).reshape(1,-1).repeat(H,1).to(device).float() 64 | coord_h=coord_h.unsqueeze(0).unsqueeze(0) 65 | coord_w=coord_w.unsqueeze(0).unsqueeze(0) 66 | unfolded_coord_h=F.unfold(coord_h,kernel_size=loc_kernel_size,padding=loc_padding) 67 | pred_coord_h=(unfolded_coord_h*pred_coord_weight).sum(dim=1,keepdim=True).reshape(H,W).detach().cpu().numpy() 68 | unfolded_coord_w=F.unfold(coord_w,kernel_size=loc_kernel_size,padding=loc_padding) 69 | pred_coord_w=(unfolded_coord_w*pred_coord_weight).sum(dim=1,keepdim=True).reshape(H,W).detach().cpu().numpy() 70 | coord_h=pred_coord_h[predict_filter].reshape(-1,1) 71 | coord_w=pred_coord_w[predict_filter].reshape(-1,1) 72 | coord=np.concatenate([coord_w,coord_h],axis=1) 73 | 74 | pred_points=[[4*coord_w+loc_kernel_size/2.,4*coord_h+loc_kernel_size/2.] 
for coord_w,coord_h in coord] 75 | return pred_points 76 | 77 | @torch.no_grad() 78 | def forward_points(self, x, threshold=0.8,loc_kernel_size=3): 79 | assert loc_kernel_size%2==1 80 | assert x.shape[0]==1 81 | out_dict=self.forward(x) 82 | 83 | loc_padding=loc_kernel_size//2 84 | kernel=torch.ones(1,1,loc_kernel_size,loc_kernel_size).to(x.device).float() 85 | default_counting_map=out_dict["default_counting_map"].detach().float() 86 | duplicate_counting_map=out_dict["duplicate_counting_map"].detach().float() 87 | fuse_counting_map=out_dict["fuse_counting_map"].detach().float() 88 | default_points=self._map2points(default_counting_map,kernel,threshold,loc_kernel_size,loc_padding) 89 | duplicate_points=self._map2points(duplicate_counting_map,kernel,threshold,loc_kernel_size,loc_padding) 90 | fuse_points=self._map2points(fuse_counting_map,kernel,threshold,loc_kernel_size,loc_padding) 91 | out_dict["default_pts"]=torch.tensor(default_points).to(x.device).float() 92 | out_dict["duplicate_pts"]=torch.tensor(duplicate_points).to(x.device).float() 93 | out_dict["fuse_pts"]=torch.tensor(fuse_points).to(x.device).float() 94 | return out_dict 95 | 96 | def loss(self, out_dict, targets): 97 | map_loss=build_loss() 98 | loss_dict = {} 99 | device=out_dict["default_counting_map"].device 100 | gt_default_maps,gt_duplicate_maps,gt_fuse_maps=[],[],[] 101 | map_loss=map_loss.to(device) 102 | with torch.no_grad(): 103 | for idx in range(targets["gt_default_pts"].shape[0]): 104 | gt_default_map=self.pt2dmap(targets["gt_default_pts"][idx],targets["gt_default_num"][idx],targets["h"][idx],targets["w"][idx]).to(device) 105 | gt_duplicate_map=self.pt2dmap(targets["gt_duplicate_pts"][idx],targets["gt_duplicate_num"][idx],targets["h"][idx],targets["w"][idx]).to(device) 106 | gt_fuse_map=self.pt2dmap(targets["gt_fuse_pts"][idx],targets["gt_fuse_num"][idx],targets["h"][idx],targets["w"][idx]).to(device) 107 | gt_default_maps.append(gt_default_map) 108 | gt_duplicate_maps.append(gt_duplicate_map) 109 | gt_fuse_maps.append(gt_fuse_map) 110 | gt_default_maps=torch.stack(gt_default_maps,dim=0) 111 | gt_duplicate_maps=torch.stack(gt_duplicate_maps,dim=0) 112 | gt_fuse_maps=torch.stack(gt_fuse_maps,dim=0) 113 | 114 | loss_dict["default"] = map_loss(out_dict["default_counting_map"],gt_default_maps) 115 | loss_dict["duplicate"] = map_loss(out_dict["duplicate_counting_map"],gt_duplicate_maps) 116 | loss_dict["fuse"] = map_loss(out_dict["fuse_counting_map"],gt_fuse_maps) 117 | loss_dict["all"] = loss_dict["default"] + loss_dict["duplicate"] + loss_dict["fuse"] 118 | return loss_dict 119 | 120 | def build_model(): 121 | return Model() 122 | -------------------------------------------------------------------------------- /models/roi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from .backbone import build_backbone 5 | from .tri_sim_ot_b import GML 6 | import math 7 | import copy 8 | from typing import Optional, List, Dict, Tuple, Set, Union, Iterable, Any 9 | from torch import Tensor 10 | 11 | 12 | class Model(nn.Module): 13 | def __init__(self, stride=16, num_feature_levels=3, num_channels=[96, 192, 384], hidden_dim=256, freeze_backbone=False) -> None: 14 | super().__init__() 15 | self.backbone = build_backbone() 16 | if freeze_backbone: 17 | for name, parameter in self.backbone.named_parameters(): 18 | parameter.requires_grad_(False) 19 | self.stride = stride 20 | self.num_feature_levels = 
num_feature_levels 21 | self.num_channels = num_channels 22 | input_proj_list = [] 23 | for i in range(num_feature_levels): 24 | input_proj_list.append(nn.Sequential( 25 | nn.Conv2d(num_channels[i], hidden_dim, 26 | kernel_size=1, stride=1, padding=0), 27 | nn.GELU(), 28 | )) 29 | for j in range(num_feature_levels-1-i): 30 | input_proj_list[i] += nn.Sequential( 31 | nn.Conv2d(hidden_dim, hidden_dim, 32 | kernel_size=3, stride=2, padding=1), 33 | nn.GELU(), 34 | ) 35 | self.input_fuse = nn.Sequential( 36 | nn.Conv2d(hidden_dim*num_feature_levels, hidden_dim * 37 | num_feature_levels, kernel_size=1, stride=1, padding=0), 38 | nn.ReLU(), 39 | ) 40 | self.fc=nn.Sequential( 41 | nn.Linear(hidden_dim*num_feature_levels*8*8,hidden_dim*num_feature_levels), 42 | nn.ReLU(), 43 | nn.Linear(hidden_dim*num_feature_levels,hidden_dim), 44 | nn.ReLU(), 45 | nn.Linear(hidden_dim,hidden_dim*num_feature_levels), 46 | nn.ReLU(), 47 | ) 48 | self.input_proj = nn.ModuleList(input_proj_list) 49 | self.ot_loss = GML() 50 | 51 | def forward(self, input): 52 | x = input["image_pair"] 53 | ref_points = input["ref_pts"][:, :, :input["ref_num"], :] 54 | x1 = x[:, 0:3, :, :] 55 | x2 = x[:, 3:6, :, :] 56 | ref_point1 = ref_points[:, 0, ...] 57 | ref_point2 = ref_points[:, 1, ...] 58 | z2 = self.backbone(x2) 59 | z2_list = [] 60 | for i in range(len(z2)): 61 | z2_list.append(self.input_proj[i](z2[i])) 62 | z2 = torch.cat(z2_list, dim=1) 63 | z2 = self.input_fuse(z2) 64 | features_2 = [] 65 | for pt in ref_point2[0]: 66 | features_2.append(self.get_feature(z2, pt)) 67 | z1 = self.backbone(x1) 68 | z1_list = [] 69 | for i in range(len(z1)): 70 | z1_list.append(self.input_proj[i](z1[i])) 71 | 72 | z1 = torch.cat(z1_list, dim=1) 73 | z1 = self.input_fuse(z1) 74 | 75 | features_1 = [] 76 | 77 | for pt in ref_point1[0]: 78 | features_1.append(self.get_feature(z1, pt)) 79 | 80 | features_1 = torch.stack(features_1, dim=1).flatten(2) 81 | features_2 = torch.stack(features_2, dim=1).flatten(2) 82 | return features_1, features_2 83 | 84 | def get_feature(self, z, pt, window_size=8): 85 | h, w = z.shape[-2], z.shape[-1] 86 | x_min = pt[0]*w-window_size//2 87 | x_max = pt[0]*w+window_size//2 88 | y_min = pt[1]*h-window_size//2 89 | y_max = pt[1]*h+window_size//2 90 | x_min, x_max, y_min, y_max = int(x_min), int( 91 | x_max), int(y_min), int(y_max) 92 | z = z[:, :, y_min:y_max, x_min:x_max] 93 | x_pad_left = 0 94 | x_pad_right = window_size-z.shape[-1] 95 | y_pad_top = 0 96 | y_pad_bottom = window_size-z.shape[-2] 97 | z = F.pad(z, (x_pad_left, x_pad_right, y_pad_top, y_pad_bottom)) 98 | # pos_emb = self.pos2posemb2d(pt, num_pos_feats=z.shape[1]//2).unsqueeze(0).unsqueeze(2).unsqueeze(3) 99 | # z = z + pos_emb 100 | # z=F.adaptive_avg_pool2d(z,(1,1)) 101 | # z=z.squeeze(3).squeeze(2) 102 | z = z.flatten(2).flatten(1) 103 | return z 104 | 105 | def loss(self, features1, features2): 106 | loss = self.ot_loss(features1, features2) 107 | loss_dict = {} 108 | loss_dict["all"] = loss 109 | return loss_dict 110 | def pos2posemb2d(self, pos, num_pos_feats=128, temperature=1000): 111 | scale = 2 * math.pi 112 | pos = pos * scale 113 | dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device) 114 | dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) 115 | pos_x = pos[..., 0, None] / dim_t 116 | pos_y = pos[..., 1, None] / dim_t 117 | pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2) 118 | pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2) 119 
| posemb = torch.cat((pos_y, pos_x), dim=-1) 120 | return posemb 121 | 122 | 123 | def build_model(): 124 | model = Model() 125 | return model 126 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional, List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn, Tensor 7 | import math 8 | 9 | 10 | 11 | def pos2posemb2d(pos, num_pos_feats=128, temperature=10000): 12 | scale = 2 * math.pi 13 | pos = pos * scale 14 | dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device) 15 | dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) 16 | pos_x = pos[..., 0, None] / dim_t 17 | pos_y = pos[..., 1, None] / dim_t 18 | pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2) 19 | pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2) 20 | posemb = torch.cat((pos_y, pos_x), dim=-1) 21 | return posemb 22 | 23 | class Transformer(nn.Module): 24 | 25 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 26 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 27 | activation="relu", normalize_before=False, 28 | return_intermediate_dec=False): 29 | super().__init__() 30 | 31 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 32 | dropout, activation, normalize_before) 33 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 34 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 35 | 36 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 37 | dropout, activation, normalize_before) 38 | decoder_norm = nn.LayerNorm(d_model) 39 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 40 | return_intermediate=return_intermediate_dec) 41 | 42 | self._reset_parameters() 43 | 44 | self.d_model = d_model 45 | self.nhead = nhead 46 | 47 | def _reset_parameters(self): 48 | for p in self.parameters(): 49 | if p.dim() > 1: 50 | nn.init.xavier_uniform_(p) 51 | 52 | def forward(self, src, mask, query_embed, pos_embed): 53 | # flatten NxCxHxW to HWxNxC 54 | bs, c, h, w = src.shape 55 | src = src.flatten(2).permute(2, 0, 1) 56 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 57 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 58 | mask = mask.flatten(1) 59 | 60 | tgt = torch.zeros_like(query_embed) 61 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 62 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, 63 | pos=pos_embed, query_pos=query_embed) 64 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) 65 | 66 | 67 | class TransformerEncoder(nn.Module): 68 | 69 | def __init__(self, encoder_layer, num_layers, norm=None): 70 | super().__init__() 71 | self.layers = _get_clones(encoder_layer, num_layers) 72 | self.num_layers = num_layers 73 | self.norm = norm 74 | 75 | def forward(self, src, 76 | mask: Optional[Tensor] = None, 77 | src_key_padding_mask: Optional[Tensor] = None, 78 | pos: Optional[Tensor] = None): 79 | output = src 80 | 81 | for layer in self.layers: 82 | output = layer(output, src_mask=mask, 83 | src_key_padding_mask=src_key_padding_mask, pos=pos) 84 | 85 | if self.norm is not None: 86 | output = self.norm(output) 87 | 88 | return output 89 | 90 | 91 | class TransformerDecoder(nn.Module): 92 | 93 | def __init__(self, 
decoder_layer, num_layers, norm=None, return_intermediate=False): 94 | super().__init__() 95 | self.layers = _get_clones(decoder_layer, num_layers) 96 | self.num_layers = num_layers 97 | self.norm = norm 98 | self.return_intermediate = return_intermediate 99 | 100 | def forward(self, tgt, memory, 101 | tgt_mask: Optional[Tensor] = None, 102 | memory_mask: Optional[Tensor] = None, 103 | tgt_key_padding_mask: Optional[Tensor] = None, 104 | memory_key_padding_mask: Optional[Tensor] = None, 105 | pos: Optional[Tensor] = None, 106 | query_pos: Optional[Tensor] = None): 107 | output = tgt 108 | 109 | intermediate = [] 110 | 111 | for layer in self.layers: 112 | output = layer(output, memory, tgt_mask=tgt_mask, 113 | memory_mask=memory_mask, 114 | tgt_key_padding_mask=tgt_key_padding_mask, 115 | memory_key_padding_mask=memory_key_padding_mask, 116 | pos=pos, query_pos=query_pos) 117 | if self.return_intermediate: 118 | intermediate.append(self.norm(output)) 119 | 120 | if self.norm is not None: 121 | output = self.norm(output) 122 | if self.return_intermediate: 123 | intermediate.pop() 124 | intermediate.append(output) 125 | 126 | if self.return_intermediate: 127 | return torch.stack(intermediate) 128 | 129 | return output.unsqueeze(0) 130 | 131 | 132 | class TransformerEncoderLayer(nn.Module): 133 | 134 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 135 | activation="relu", normalize_before=False): 136 | super().__init__() 137 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 138 | # Implementation of Feedforward model 139 | self.linear1 = nn.Linear(d_model, dim_feedforward) 140 | self.dropout = nn.Dropout(dropout) 141 | self.linear2 = nn.Linear(dim_feedforward, d_model) 142 | 143 | self.norm1 = nn.LayerNorm(d_model) 144 | self.norm2 = nn.LayerNorm(d_model) 145 | self.dropout1 = nn.Dropout(dropout) 146 | self.dropout2 = nn.Dropout(dropout) 147 | 148 | self.activation = _get_activation_fn(activation) 149 | self.normalize_before = normalize_before 150 | 151 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 152 | return tensor if pos is None else tensor + pos 153 | 154 | def forward_post(self, 155 | src, 156 | src_mask: Optional[Tensor] = None, 157 | src_key_padding_mask: Optional[Tensor] = None, 158 | pos: Optional[Tensor] = None): 159 | q = k = self.with_pos_embed(src, pos) 160 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 161 | key_padding_mask=src_key_padding_mask)[0] 162 | src = src + self.dropout1(src2) 163 | src = self.norm1(src) 164 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 165 | src = src + self.dropout2(src2) 166 | src = self.norm2(src) 167 | return src 168 | 169 | def forward_pre(self, src, 170 | src_mask: Optional[Tensor] = None, 171 | src_key_padding_mask: Optional[Tensor] = None, 172 | pos: Optional[Tensor] = None): 173 | src2 = self.norm1(src) 174 | q = k = self.with_pos_embed(src2, pos) 175 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 176 | key_padding_mask=src_key_padding_mask)[0] 177 | src = src + self.dropout1(src2) 178 | src2 = self.norm2(src) 179 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 180 | src = src + self.dropout2(src2) 181 | return src 182 | 183 | def forward(self, src, 184 | src_mask: Optional[Tensor] = None, 185 | src_key_padding_mask: Optional[Tensor] = None, 186 | pos: Optional[Tensor] = None): 187 | if self.normalize_before: 188 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 189 | return 
self.forward_post(src, src_mask, src_key_padding_mask, pos) 190 | 191 | 192 | class TransformerDecoderLayer(nn.Module): 193 | 194 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 195 | activation="relu", normalize_before=False): 196 | super().__init__() 197 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 198 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 199 | # Implementation of Feedforward model 200 | self.linear1 = nn.Linear(d_model, dim_feedforward) 201 | self.dropout = nn.Dropout(dropout) 202 | self.linear2 = nn.Linear(dim_feedforward, d_model) 203 | 204 | self.norm1 = nn.LayerNorm(d_model) 205 | self.norm2 = nn.LayerNorm(d_model) 206 | self.norm3 = nn.LayerNorm(d_model) 207 | self.dropout1 = nn.Dropout(dropout) 208 | self.dropout2 = nn.Dropout(dropout) 209 | self.dropout3 = nn.Dropout(dropout) 210 | 211 | self.activation = _get_activation_fn(activation) 212 | self.normalize_before = normalize_before 213 | 214 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 215 | return tensor if pos is None else tensor + pos 216 | 217 | def forward_post(self, tgt, memory, 218 | tgt_mask: Optional[Tensor] = None, 219 | memory_mask: Optional[Tensor] = None, 220 | tgt_key_padding_mask: Optional[Tensor] = None, 221 | memory_key_padding_mask: Optional[Tensor] = None, 222 | pos: Optional[Tensor] = None, 223 | query_pos: Optional[Tensor] = None): 224 | q = k = self.with_pos_embed(tgt, query_pos) 225 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 226 | key_padding_mask=tgt_key_padding_mask)[0] 227 | tgt = tgt + self.dropout1(tgt2) 228 | tgt = self.norm1(tgt) 229 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 230 | key=self.with_pos_embed(memory, pos), 231 | value=memory, attn_mask=memory_mask, 232 | key_padding_mask=memory_key_padding_mask)[0] 233 | tgt = tgt + self.dropout2(tgt2) 234 | tgt = self.norm2(tgt) 235 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 236 | tgt = tgt + self.dropout3(tgt2) 237 | tgt = self.norm3(tgt) 238 | return tgt 239 | 240 | def forward_pre(self, tgt, memory, 241 | tgt_mask: Optional[Tensor] = None, 242 | memory_mask: Optional[Tensor] = None, 243 | tgt_key_padding_mask: Optional[Tensor] = None, 244 | memory_key_padding_mask: Optional[Tensor] = None, 245 | pos: Optional[Tensor] = None, 246 | query_pos: Optional[Tensor] = None): 247 | tgt2 = self.norm1(tgt) 248 | q = k = self.with_pos_embed(tgt2, query_pos) 249 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 250 | key_padding_mask=tgt_key_padding_mask)[0] 251 | tgt = tgt + self.dropout1(tgt2) 252 | tgt2 = self.norm2(tgt) 253 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 254 | key=self.with_pos_embed(memory, pos), 255 | value=memory, attn_mask=memory_mask, 256 | key_padding_mask=memory_key_padding_mask)[0] 257 | tgt = tgt + self.dropout2(tgt2) 258 | tgt2 = self.norm3(tgt) 259 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 260 | tgt = tgt + self.dropout3(tgt2) 261 | return tgt 262 | 263 | def forward(self, tgt, memory, 264 | tgt_mask: Optional[Tensor] = None, 265 | memory_mask: Optional[Tensor] = None, 266 | tgt_key_padding_mask: Optional[Tensor] = None, 267 | memory_key_padding_mask: Optional[Tensor] = None, 268 | pos: Optional[Tensor] = None, 269 | query_pos: Optional[Tensor] = None): 270 | if self.normalize_before: 271 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 272 | tgt_key_padding_mask, 
memory_key_padding_mask, pos, query_pos) 273 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 274 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 275 | 276 | 277 | def _get_clones(module, N): 278 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 279 | 280 | 281 | def build_transformer(args): 282 | return Transformer( 283 | d_model=args.hidden_dim, 284 | dropout=args.dropout, 285 | nhead=args.nheads, 286 | dim_feedforward=args.dim_feedforward, 287 | num_encoder_layers=args.enc_layers, 288 | num_decoder_layers=args.dec_layers, 289 | normalize_before=args.pre_norm, 290 | return_intermediate_dec=True, 291 | ) 292 | 293 | 294 | def _get_activation_fn(activation): 295 | """Return an activation function given a string""" 296 | if activation == "relu": 297 | return F.relu 298 | if activation == "gelu": 299 | return F.gelu 300 | if activation == "glu": 301 | return F.glu 302 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") -------------------------------------------------------------------------------- /models/tri_cropper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from .backbone import build_backbone 5 | from .tri_sim_ot_b import GML 6 | import math 7 | import copy 8 | from typing import Optional, List, Dict, Tuple, Set, Union, Iterable, Any 9 | from torch import Tensor 10 | 11 | 12 | class Model(nn.Module): 13 | def __init__(self, stride=16, num_feature_levels=3, num_channels=[96, 192, 384], hidden_dim=256, freeze_backbone=False) -> None: 14 | super().__init__() 15 | self.backbone = build_backbone() 16 | if freeze_backbone: 17 | for name, parameter in self.backbone.named_parameters(): 18 | parameter.requires_grad_(False) 19 | self.stride = stride 20 | self.num_feature_levels = num_feature_levels 21 | self.num_channels = num_channels 22 | input_proj_list = [] 23 | 24 | self.ot_loss = GML() 25 | 26 | def forward(self, input): 27 | x = input["image_pair"] 28 | ref_points = input["ref_pts"][:, :, :input["ref_num"], :] 29 | ind_points0=input["independ_pts0"][:,:input["independ_num0"],...] 30 | ind_points1=input["independ_pts1"][:,:input["independ_num1"],...] 31 | x1 = x[:, 0:3, :, :] 32 | x2 = x[:, 3:6, :, :] 33 | ref_point1 = ref_points[:, 0, ...] 34 | ref_point2 = ref_points[:, 1, ...] 
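        # For every matched reference point, a local window is cropped around the point in
        # each frame, resized to 224x224, and embedded with the shared backbone.
        # Points visible in only one frame (the "independent" inflow/outflow points)
        # are embedded the same way and later act as negatives in the group-level
        # matching loss. All embeddings are L2-normalized before being returned.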
35 | 36 | z1_lists=[] 37 | z2_lists=[] 38 | y1_lists=[] 39 | y2_lists=[] 40 | for b in range(x1.shape[0]): 41 | z1_list=[] 42 | z2_list=[] 43 | y1_list=[] 44 | y2_list=[] 45 | for pt1,pt2 in zip(ref_point1[b],ref_point2[b]): 46 | z1=self.get_crops(x1[b].unsqueeze(0),pt1) 47 | z2=self.get_crops(x2[b].unsqueeze(0),pt2) 48 | z1=F.interpolate(z1,(224,224)) 49 | z2=F.interpolate(z2,(224,224)) 50 | z1_list.append(z1) 51 | z2_list.append(z2) 52 | z1=torch.cat(z1_list,dim=0) 53 | z2=torch.cat(z2_list,dim=0) 54 | z1=self.backbone(z1)[0].flatten(2).flatten(1) 55 | z2=self.backbone(z2)[0].flatten(2).flatten(1) 56 | z1_lists.append(z1) 57 | z2_lists.append(z2) 58 | for pt in ind_points0[b]: 59 | y1=self.get_crops(x1[b].unsqueeze(0),pt) 60 | y1=F.interpolate(y1,(224,224)) 61 | y1_list.append(y1) 62 | for pt in ind_points1[b]: 63 | y2=self.get_crops(x2[b].unsqueeze(0),pt) 64 | y2=F.interpolate(y2,(224,224)) 65 | y2_list.append(y2) 66 | y1=torch.cat(y1_list,dim=0) 67 | y2=torch.cat(y2_list,dim=0) 68 | y1=self.backbone(y1)[0].flatten(2).flatten(1) 69 | y2=self.backbone(y2)[0].flatten(2).flatten(1) 70 | y1_lists.append(y1) 71 | y2_lists.append(y2) 72 | y1=torch.stack(y1_lists,dim=0) 73 | y2=torch.stack(y2_lists,dim=0) 74 | z1=torch.stack(z1_lists,dim=0) 75 | z2=torch.stack(z2_lists,dim=0) 76 | z1=F.normalize(z1,dim=-1) 77 | z2=F.normalize(z2,dim=-1) 78 | y1=F.normalize(y1,dim=-1) 79 | y2=F.normalize(y2,dim=-1) 80 | return z1,z2,y1,y2 81 | def forward_single_image(self,x,pts,absolute=False): 82 | z_lists=[] 83 | max_batch=20 84 | for b in range(x.shape[0]): 85 | z_list=[] 86 | z_lists.append([]) 87 | for pt in pts[b]: 88 | z=self.get_crops(x[b].unsqueeze(0),pt,absolute=absolute) 89 | z=F.interpolate(z,(224,224)) 90 | z_list.append(z) 91 | z_list=[z_list[i:i+max_batch] for i in range(0,len(z_list),max_batch)] 92 | 93 | for z in z_list: 94 | 95 | z=torch.cat(z,dim=0) 96 | 97 | z=self.backbone(z)[0].flatten(2).flatten(1) 98 | z_lists[-1].append(z) 99 | z_lists[-1]=torch.cat(z_lists[-1],dim=0) 100 | z=torch.stack(z_lists,dim=0) 101 | z=F.normalize(z,dim=-1) 102 | return z 103 | def get_crops(self, z, pt, window_size=[32,32,32,64],absolute=False): 104 | h, w = z.shape[-2], z.shape[-1] 105 | if absolute: 106 | x_min = pt[0]-window_size[0] 107 | x_max = pt[0]+window_size[1] 108 | y_min = pt[1]-window_size[2] 109 | y_max = pt[1]+window_size[3] 110 | else: 111 | 112 | x_min = pt[0]*w-window_size[0] 113 | x_max = pt[0]*w+window_size[1] 114 | y_min = pt[1]*h-window_size[2] 115 | y_max = pt[1]*h+window_size[3] 116 | x_min, x_max, y_min, y_max = int(x_min), int( 117 | x_max), int(y_min), int(y_max) 118 | x_min = max(0, x_min) 119 | x_max = min(w, x_max) 120 | y_min = max(0, y_min) 121 | y_max = min(h, y_max) 122 | z = z[..., y_min:y_max, x_min:x_max] 123 | # pos_emb = self.pos2posemb2d(pt, num_pos_feats=z.shape[1]//2).unsqueeze(0).unsqueeze(2).unsqueeze(3) 124 | # z = z + pos_emb 125 | # z=F.adaptive_avg_pool2d(z,(1,1)) 126 | # z=z.squeeze(3).squeeze(2) 127 | return z 128 | 129 | def loss(self, z1,z2,y1,y2): 130 | loss_dict = self.ot_loss([z1,y1], [z2,y2]) 131 | # loss = self.ot_loss(z1, z2) 132 | 133 | loss_dict["all"] = loss_dict["scon_cost"]+loss_dict["hinge_cost"]*0.1 134 | return loss_dict 135 | 136 | 137 | def build_model(): 138 | model = Model() 139 | return model 140 | -------------------------------------------------------------------------------- /models/tri_sim_ot_b.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import 
torch.nn.functional as F 4 | from geomloss import SamplesLoss 5 | from scipy.optimize import linear_sum_assignment 6 | from sklearn.manifold import TSNE 7 | import scienceplots 8 | def similarity_cost(x1, x2, gamma=10): 9 | ''' 10 | :param x1: [B N D] 11 | :param x2: [B M D] 12 | :return: [B N M] 13 | ''' 14 | N = x1.shape[1] 15 | x1 = F.normalize(x1, dim=-1) 16 | x2 = F.normalize(x2, dim=-1) 17 | sim = torch.bmm(x1, x2.transpose(1, 2)) * gamma 18 | sim = sim.exp() 19 | sim_c = sim.sum(dim=1, keepdim=True) 20 | sim_r = sim.sum(dim=2, keepdim=True) 21 | sim = sim / (sim_c + sim_r - sim) 22 | return 1 - sim 23 | 24 | def get_group_ones(N, b, B, device): 25 | alpha = [] 26 | for i in range(B): 27 | if i == b: 28 | alpha.append(torch.ones(1, N[i], device=device).float()) 29 | else: 30 | alpha.append(torch.zeros(1, N[i], device=device).float()) 31 | 32 | alpha = torch.cat(alpha, dim=1) 33 | return alpha 34 | 35 | class GML(nn.Module): 36 | def __init__(self): 37 | super().__init__() 38 | self.ot = SamplesLoss(backend="tensorized", cost=similarity_cost, debias=False,diameter=3) 39 | self.margin=0.1 40 | def forward(self, x_pre, x_cur): 41 | ''' 42 | :param x1: B tensor of size [N_b D] 43 | :param x2: B tensor of size [N_b D] 44 | :param x3: B tensor of size [N1_b D] negative exapmles 45 | :param x4: B tensor of size [M1_b D] negative exapmles 46 | :return: ot loss 47 | ''' 48 | x1, x3 = x_pre 49 | x2, x4 = x_cur 50 | B = len(x1) # assert B == 1 51 | N = [x.shape[0] for x in x1] 52 | M = [x.shape[0] for x in x2] 53 | N1 = [x.shape[0] for x in x3] 54 | M1 = [x.shape[0] for x in x4] 55 | 56 | device = x1[0].device 57 | ot_loss = [] 58 | 59 | for b in range(B): 60 | assert N[b] == M[b] 61 | alpha = torch.cat((torch.ones(1, N[b], device=device), torch.zeros(1, N1[b], device=device)), dim=1) 62 | beta = torch.cat((torch.ones(1, N[b], device=device), torch.zeros(1, M1[b], device=device)), dim=1) 63 | f1 = torch.cat((x1[b], x3[b]), dim=0).unsqueeze(0) 64 | f2 = torch.cat((x2[b], x4[b]), dim=0).unsqueeze(0) 65 | loss = self.ot(alpha, f1, beta, f2) 66 | ot_loss.append(loss) 67 | loss=0 68 | loss+=torch.relu(self.margin-torch.bmm(x3,x4.transpose(1,2))).sum() 69 | loss_dict={ 70 | "hinge_cost":loss.unsqueeze(0), 71 | "scon_cost":sum(ot_loss) / len(ot_loss) 72 | } 73 | return loss_dict 74 | 75 | 76 | if __name__ == "__main__": 77 | # x1 = torch.ones(2, 2) 78 | # x2 = torch.ones(2, 2) 79 | from matplotlib import pyplot as plt 80 | Dim = 512 81 | 82 | x1 = torch.randn(1,25, Dim) 83 | x2 = torch.randn(1,25, Dim) 84 | x1.requires_grad = True 85 | x2.requires_grad = True 86 | 87 | x3 = torch.randn(1,6, Dim) 88 | x4 = torch.randn(1,9, Dim) 89 | x3.requires_grad = True 90 | x4.requires_grad = True 91 | 92 | # x1 = [torch.eye(10)] 93 | # x2 = [torch.eye(10)] 94 | 95 | lr = 0.5 96 | loss = GML() 97 | for i in range(50): 98 | x1=F.normalize(x1,dim=-1) 99 | x2=F.normalize(x2,dim=-1) 100 | x3=F.normalize(x3,dim=-1) 101 | x4=F.normalize(x4,dim=-1) 102 | 103 | l = loss([x1, x3], [x2, x4]) 104 | t=l["scon_cost"]+0.2*l["hinge_cost"] 105 | if (i+1) % 1 == 0 or i == 0: 106 | 107 | s1=torch.bmm(x1,x2.transpose(1,2)) 108 | s2= torch.bmm(x1,x4.transpose(1,2)) 109 | s3 = torch.bmm(x3,x2.transpose(1,2)) 110 | s4 = torch.bmm(x3,x4.transpose(1,2)) 111 | s1_s2=torch.cat((s1,s2),dim=2) 112 | s3_s4=torch.cat((s3,s4),dim=2) 113 | img=torch.cat((s1_s2,s3_s4),dim=1) 114 | 115 | [g1, g2, g3, g4] = torch.autograd.grad(t, [x1, x2, x3, x4]) 116 | x1.data -= lr * g1 117 | x1.data = F.relu(x1.data) 118 | x2.data -= lr * g2 119 | x2.data = 
F.relu(x2.data) 120 | x3.data -= lr * g3 121 | x3.data = F.relu(x3.data) 122 | x4.data -= lr * g4 123 | x4.data = F.relu(x4.data) 124 | 125 | pos_sim_cost=1- similarity_cost(x1, x2) 126 | neg_sim_cost=1- similarity_cost(x1, x4) 127 | all_cost=1- similarity_cost(torch.cat((x1,x3),dim=1), torch.cat((x2,x4),dim=1)) 128 | print("pos_sim_cost") 129 | print(pos_sim_cost.cpu().detach().numpy()[0]) 130 | print("neg_sim_cost") 131 | print(neg_sim_cost.cpu().detach().numpy()[0]) 132 | print("all_cost") 133 | print(all_cost.cpu().detach().numpy()[0]) 134 | x1=F.normalize(x1,dim=-1) 135 | x2=F.normalize(x2,dim=-1) 136 | x3=F.normalize(x3,dim=-1) 137 | x4=F.normalize(x4,dim=-1) 138 | match_matrix=torch.bmm(torch.cat((x1,x3),dim=1),torch.cat((x2,x4),dim=1).transpose(1,2)) 139 | 140 | print("match_matrix") 141 | print(match_matrix.cpu().detach().numpy()[0]) 142 | match=linear_sum_assignment(1-match_matrix.cpu().detach().numpy()[0]) 143 | print("match") 144 | print(match) 145 | print("x1") 146 | print(x1.cpu().detach().numpy()) 147 | print("x2") 148 | print(x2.cpu().detach().numpy()) 149 | print("x3") 150 | print(x3.cpu().detach().numpy()) 151 | print("x4") 152 | print(x4.cpu().detach().numpy()) -------------------------------------------------------------------------------- /optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .optimizer_builder import optimizer_builder, optimizer_finetune_builder 2 | from .scheduler_builder import scheduler_builder -------------------------------------------------------------------------------- /optimizer/optimizer_builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def optimizer_builder(args,model_without_ddp): 4 | params = [ 5 | { 6 | "params": 7 | [p for n, p in model_without_ddp.named_parameters() 8 | if p.requires_grad], 9 | "lr": args.lr, 10 | }, 11 | ] 12 | if args.type.lower()=="sgd": 13 | return torch.optim.SGD(params, args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 14 | elif args.type.lower()=="adam": 15 | return torch.optim.Adam(params, args.lr, weight_decay=args.weight_decay) 16 | elif args.type.lower()=="adamw": 17 | return torch.optim.AdamW(params, args.lr, weight_decay=args.weight_decay) 18 | 19 | def optimizer_finetune_builder(args,model_without_ddp): 20 | param_backbone = [p for p in model_without_ddp.backbone.parameters() if p.requires_grad] + \ 21 | [p for p in model_without_ddp.decoder_layers.decoder.parameters() if p.requires_grad] 22 | param_predict = [p for p in model_without_ddp.decoder_layers.output_layer.parameters() if p.requires_grad] 23 | params = [ 24 | { 25 | "params": param_backbone, 26 | "lr": args.lr * 0.1 , 27 | }, 28 | { 29 | "params": param_predict, 30 | "lr": args.lr, 31 | } 32 | ] 33 | if args.type.lower()=="sgd": 34 | return torch.optim.SGD(params, args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 35 | elif args.type.lower()=="adam": 36 | return torch.optim.Adam(params, args.lr, weight_decay=args.weight_decay) 37 | elif args.type.lower()=="adamw": 38 | return torch.optim.AdamW(params, args.lr, weight_decay=args.weight_decay) -------------------------------------------------------------------------------- /optimizer/scheduler_builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def scheduler_builder(args,optimizer): 4 | print(optimizer) 5 | if args.type=="step": 6 | return 
torch.optim.lr_scheduler.MultiStepLR(optimizer, args.milestones, args.gamma) 7 | elif args.type=="cosine": 8 | return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.T_max, args.eta_min) 9 | elif args.type=="multi_step": 10 | return torch.optim.lr_scheduler.MultiStepLR(optimizer, args.milestones, args.gamma) 11 | elif args.type=="consine_warm_restart": 12 | return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, args.T_max, args.eta_min) -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Weakly Supervised Video Individual Counting 2 | 3 | Official PyTorch implementation of "Weakly Supervised Video Individual Counting" as presented at CVPR 2024. 4 | 5 | 📄 [Read the Paper](https://openaccess.thecvf.com/content/CVPR2024/html/Liu_Weakly_Supervised_Video_Individual_Counting_CVPR_2024_paper.html) 6 | 7 | **Authors:** Xinyan Liu, Guorong Li, Yuankai Qi, Ziheng Yan, Zhenjun Han, Anton van den Hengel, Ming-Hsuan Yang, Qingming Huang 8 | 9 | ## Overview 10 | 11 | The Video Individual Counting (VIC) task focuses on predicting the count of unique individuals in videos. Traditional methods, which rely on costly individual trajectory annotations, are impractical for large-scale applications. This work introduces a novel approach to VIC under a weakly supervised framework, utilizing less restrictive inflow and outflow annotations. We propose a baseline method employing weakly supervised contrastive learning for group-level matching, enhanced by a custom soft contrastive loss, facilitating the distinction between different crowd dynamics. We also contribute two augmented datasets, SenseCrowd and CroHD, and introduce a new dataset, UAVVIC, to foster further research in this area. Our method demonstrates superior performance compared to fully supervised counterparts, making a strong case for its practical applicability. 12 | 13 | ## Inference Pipeline 14 | 15 | ![Inference Pipeline](./statics/imgs/pipeline.png) 16 | 17 | The CGNet architecture includes: 18 | - **Frame-level Crowd Locator**: Detects pedestrian coordinates. 19 | - **Encoder**: Generates unique representations for each detected individual. 20 | - **Memory-based Individual Count Predictor (MCP)**: Estimates inflow counts and maintains a memory of individual templates. 21 | 22 | The Weakly Supervised Representation Learning (WSRL) method utilizes both inflow and outflow labels to refine the encoder through a novel Group-Level Matching Loss (GML), integrating soft contrastive and hinge losses to optimize performance. 23 | 24 | ## Demo 25 | 26 | Our model processes video inputs to predict individual counts, operating over 3-second intervals. 27 | 28 |
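Concretely, the video-level count is accumulated pair by pair: the count of the first sampled frame plus the inflow estimated at each subsequent 3-second step. A minimal schematic of this accumulation, where the two counting helpers are placeholders rather than functions from this repo:

```python
def count_first_frame(frame):
    # placeholder: a crowd locator would run here
    return 0

def count_inflow(prev_frame, frame):
    # placeholder: locate people in `frame`, match them against `prev_frame`,
    # and count the unmatched (newly appearing) ones
    return 0

frames = [None] * 150   # stand-in for ~30 s of video sampled at 5 fps
interval = 15           # 15 frames is roughly 3 s at 5 fps
total = count_first_frame(frames[0])
for t in range(interval, len(frames), interval):
    total += count_inflow(frames[t - interval], frames[t])
print(total)            # video individual count (VIC)
```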

29 | ![GIF 1](./statics/imgs/c3.gif) 30 | ![GIF 2](./statics/imgs/c4.gif) 31 | ![GIF 3](./statics/imgs/p2.gif) 32 | ![GIF 4](./statics/imgs/p3.gif) 33 |

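Matching across each frame pair is driven by embedding similarity. The sketch below is a simplified adaptation of `similarity_cost` in `models/tri_sim_ot_b.py`, showing how a soft pairwise similarity can be turned into one-to-one matches; the random features stand in for real encoder outputs:

```python
import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment

def soft_similarity(x1, x2, gamma=10):
    # Exponentiated cosine similarity, normalized so every pair competes with its
    # whole row and column (the repo's similarity_cost returns 1 minus this value).
    x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1)
    sim = (torch.bmm(x1, x2.transpose(1, 2)) * gamma).exp()
    row = sim.sum(dim=2, keepdim=True)
    col = sim.sum(dim=1, keepdim=True)
    return sim / (row + col - sim)

# Toy usage: embeddings of 5 individuals seen in frame t and frame t + interval.
feat_t = torch.randn(1, 5, 128)
feat_next = torch.randn(1, 5, 128)
sim = soft_similarity(feat_t, feat_next)[0]
rows, cols = linear_sum_assignment(1 - sim.detach().numpy())  # one-to-one assignment
print(list(zip(rows.tolist(), cols.tolist())))
```

In the full model, embeddings of the unmatched (inflow/outflow) points serve as negatives in the GML loss, which combines this soft contrastive term with a hinge loss.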
34 | 35 | 36 | ## Setup 37 | 38 | ### Installation 39 | 40 | Clone and set up the CGNet repository: 41 | 42 | ```bash 43 | git clone https://github.com/streamer-AP/CGNet 44 | cd CGNet 45 | conda create -n CGNet python=3.10 46 | conda activate CGNet 47 | pip install -r requirements.txt 48 | ``` 49 | 50 | ### Data Preparation 51 | - **CroHD**: Download the CroHD dataset from this [link](https://motchallenge.net/data/Head_Tracking_21/). Unzip ```HT21.zip``` and place ```HT21``` into the folder (```Root/dataset/```). 52 | - **SenseCrowd**: Download the dataset from [Baidu disk](https://pan.baidu.com/s/1OYBSPxgwvRMrr6UTStq7ZQ?pwd=64xm) or from the original dataset [link](https://github.com/HopLee6/VSCrowd-Dataset). 53 | 54 | ### Usage 55 | 56 | 1. We provide a toy example of the GML loss for a quick understanding of how it works; it can also be used in other tasks. 57 | 58 | ``` bash 59 | cd models/ 60 | python tri_sim_ot_b.py 61 | ``` 62 | You can watch the similarity matrix converge like this: 63 | 64 | ![sim](./statics/imgs/tinywow_sim_58930123.gif) 65 | 66 | 67 | 2. Inference. 68 | * Before inference, you need crowd localization results from a pre-trained crowd localization model. You can use [FIDTM](https://github.com/dk-liang/FIDTM.git), [STEERER](https://github.com/taohan10200/STEERER.git), or any other crowd localization model that outputs coordinate results. 69 | * We also provide crowd localization results produced by FIDTM-HRNet-W48. You can download them from [Baidu disk](https://pan.baidu.com/s/1i9BXHab5pVYhZFCESD6F7Q?pwd=08zg) or [Google drive](https://drive.google.com/file/d/12cMTFTf_xEiE_AYOvs1CgG91EAvbD_ew/view?usp=drive_link). The data format follows: 70 | ``` 71 | x y 72 | x y 73 | ``` 74 | * Pretrained models can be downloaded from [Baidu disk](https://pan.baidu.com/s/1TQvU-5K8twDXGF_NIhRrAg) (pwd: jhux) or [Google drive](https://drive.google.com/file/d/1EcEy11HVMDxPUMztC-zSIQ4jAMUStbDG/view?usp=drive_link). Unzip them in the weight folder and run the following command. 75 | ``` bash 76 | python inference.py 77 | ``` 78 | * Inference uses less than 2 GB of GPU memory, and with interval = 15 (3 s) it takes less than 10 minutes for all datasets. 79 | When finished, it generates a JSON file in the result folder. The data format follows: 80 | ```json 81 | { 82 | "video_name": { 83 | "video_num": the predicted VIC, 84 | "first_frame_num": the predicted count in the first frame, 85 | "cnt_list": [the count of inflow in each frame], 86 | "pos_lists": [the position of each individual in each frame], 87 | "frame_num": the total frame number, 88 | "inflow_lists": [the inflow of each individual in each frame], 89 | }, 90 | ... 91 | } 92 | ``` 93 | * For the SenseCrowd dataset, the reproduced results of this repo are provided in the result dir. 94 | Run the following command to evaluate the results. 95 | ``` bash 96 | python eval.py 97 | ``` 98 | * MAE and WRAE are slightly better than in the paper. The metrics are as follows: 99 | 100 | | Method | MAE | RMSE | WRAE | 101 | | ------ | --- | --- | --- | 102 | | Paper | 8.86 | 17.69 | 12.6 | 103 | | Repo | 8.64 | 18.70 | 11.76 | 104 | 105 | 106 | 3. Training. 107 | * For training, you need to prepare the dataset and the crowd localization results. The data format follows: 108 | ``` 109 | x y 110 | x y 111 | ``` 112 | * The training script is as follows: 113 | ``` bash 114 | python train.py 115 | ``` 116 | * It also supports multi-GPU training.
You can set the number of GPUs by using the following command: 117 | ``` bash 118 | bash dist_train.sh 8 119 | ``` 120 | 121 | ## Citation 122 | 123 | If you find this repository helpful, please cite our paper: 124 | 125 | ``` bash 126 | @inproceedings{liu2024weakly, 127 | title={Weakly Supervised Video Individual Counting}, 128 | author={Liu, Xinyan and Li, Guorong and Qi, Yuankai and Yan, Ziheng and Han, Zhenjun and van den Hengel, Anton and Yang, Ming-Hsuan and Huang, Qingming}, 129 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 130 | pages={19228--19237}, 131 | year={2024} 132 | } 133 | ``` 134 | 135 | ## Acknowledgement 136 | We thank the authors of [FIDTM](https://github.com/dk-liang/FIDTM.git) and [DR.VIC](https://github.com/taohan10200/DRNet.git) for their excellent work. 137 | 138 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==1.3.0 2 | easydict==1.10 3 | geomloss==0.2.6 4 | head==1.0.0 5 | matplotlib==3.7.1 6 | mmcv==2.0.0 7 | numpy==1.24.3 8 | opencv_python==4.7.0.72 9 | opencv_python_headless==4.7.0.72 10 | Pillow==9.4.0 11 | Pillow==10.2.0 12 | SciencePlots==2.1.0 13 | scikit_learn==1.2.2 14 | scipy==1.12.0 15 | termcolor==2.4.0 16 | timm==0.9.2 17 | torch==2.0.1 18 | torchvision==0.15.2 19 | utils==1.0.2 20 | -------------------------------------------------------------------------------- /statics/imgs/c3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/streamer-AP/CGNet/a67e77de04df69f816f37dadaed83fe046967c4c/statics/imgs/c3.gif -------------------------------------------------------------------------------- /statics/imgs/c4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/streamer-AP/CGNet/a67e77de04df69f816f37dadaed83fe046967c4c/statics/imgs/c4.gif -------------------------------------------------------------------------------- /statics/imgs/p2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/streamer-AP/CGNet/a67e77de04df69f816f37dadaed83fe046967c4c/statics/imgs/p2.gif -------------------------------------------------------------------------------- /statics/imgs/p3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/streamer-AP/CGNet/a67e77de04df69f816f37dadaed83fe046967c4c/statics/imgs/p3.gif -------------------------------------------------------------------------------- /statics/imgs/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/streamer-AP/CGNet/a67e77de04df69f816f37dadaed83fe046967c4c/statics/imgs/pipeline.png -------------------------------------------------------------------------------- /statics/imgs/sim.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/streamer-AP/CGNet/a67e77de04df69f816f37dadaed83fe046967c4c/statics/imgs/sim.mp4 -------------------------------------------------------------------------------- /statics/imgs/tinywow_sim_58930123.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/streamer-AP/CGNet/a67e77de04df69f816f37dadaed83fe046967c4c/statics/imgs/tinywow_sim_58930123.gif -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import shutil 5 | import time 6 | from asyncio.log import logger 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from easydict import EasyDict as edict 11 | from termcolor import cprint 12 | from torch.utils.data import DataLoader 13 | from torch.utils.data.distributed import DistributedSampler 14 | from torch.utils.tensorboard import SummaryWriter 15 | 16 | from tri_dataset import build_dataset 17 | # from eingine.densemap_trainer import evaluate_counting, train_one_epoch 18 | from tri_eingine import evaluate_similarity, train_one_epoch 19 | from misc import tools 20 | from misc.saver_builder import Saver 21 | from misc.tools import MetricLogger, is_main_process 22 | from models.tri_cropper import build_model 23 | from optimizer import optimizer_builder, scheduler_builder 24 | from torch.nn import SyncBatchNorm 25 | from torch.cuda.amp import GradScaler 26 | torch.backends.cudnn.enabled = True 27 | torch.backends.cudnn.benchmark = True 28 | 29 | 30 | def main(args): 31 | tools.init_distributed_mode(args) 32 | tools.set_randomseed(42 + tools.get_rank()) 33 | 34 | # initilize the model 35 | model = model_without_ddp = build_model() 36 | model.cuda() 37 | if args.distributed: 38 | sync_model=SyncBatchNorm.convert_sync_batchnorm(model) 39 | model = torch.nn.parallel.DistributedDataParallel( 40 | sync_model, device_ids=[args.gpu], find_unused_parameters=False) 41 | model_without_ddp = model.module 42 | 43 | # build the dataset and dataloader 44 | dataset_train = build_dataset(args.Dataset.train.root,args.Dataset.val.ann_dir, args.Dataset.val.max_len,train=True,step=15) 45 | dataset_val = build_dataset(args.Dataset.val.root,args.Dataset.val.ann_dir, args.Dataset.val.max_len,train=True,step=15) 46 | # dataset_train = build_dataset(args.Dataset.train.root,args.Dataset.val.ann_dir, args.Dataset.val.max_len) 47 | # dataset_val = build_dataset(args.Dataset.val.root,args.Dataset.val.ann_dir, args.Dataset.val.max_len) 48 | sampler_train = DistributedSampler(dataset_train) if args.distributed else None 49 | sampler_val = DistributedSampler(dataset_val, shuffle=False) if args.distributed else None 50 | loader_train = DataLoader(dataset_train, 51 | batch_size=args.Dataset.train.batch_size, 52 | sampler=sampler_train, 53 | shuffle=False, 54 | num_workers=args.Dataset.train.num_workers, 55 | pin_memory=True) 56 | loader_val = DataLoader(dataset_val, 57 | batch_size=args.Dataset.val.batch_size, 58 | sampler=sampler_val, 59 | shuffle=False, 60 | num_workers=args.Dataset.val.num_workers, 61 | pin_memory=True) 62 | 63 | optimizer = optimizer_builder(args.Optimizer, model_without_ddp) 64 | 65 | scheduler = scheduler_builder(args.Scheduler, optimizer) 66 | saver = Saver(args.Saver) 67 | logger = MetricLogger(args.Logger) 68 | if is_main_process(): 69 | if args.Misc.use_tensorboard: 70 | tensorboard_writer = SummaryWriter(args.Misc.tensorboard_dir) 71 | scaler = GradScaler() 72 | for epoch in range(args.Misc.epochs): 73 | if args.distributed: 74 | sampler_train.set_epoch(epoch) 75 | stats = edict() 76 | stats.train_stats = train_one_epoch(model, loader_train, 77 | optimizer, logger,scaler, epoch, 78 | args) 79 | train_log_stats = { 80 | 
**{f'train_{k}': v 81 | for k, v in stats.train_stats.items()}, 'epoch': epoch 82 | } 83 | scheduler.step() 84 | if is_main_process(): 85 | for key, value in train_log_stats.items(): 86 | cprint(f'{key}:{value}', 'green') 87 | if args.Misc.use_tensorboard: 88 | tensorboard_writer.add_scalar(key, value, epoch) 89 | if epoch % args.Misc.val_freq == 0: 90 | stats.test_stats = evaluate_similarity(model, loader_val, logger, epoch, args) 91 | 92 | saver.save_on_master(model, optimizer, scheduler, epoch, stats) 93 | test_log_stats = { 94 | **{f'val_{k}': v 95 | for k, v in stats.test_stats.items()}, 'epoch': epoch 96 | } 97 | if is_main_process(): 98 | for key,value in test_log_stats.items(): 99 | cprint(f'{key}:{value}', 'green') 100 | if args.Misc.use_tensorboard: 101 | tensorboard_writer.add_scalar(key, value, epoch) 102 | else: 103 | saver.save_inter(model, optimizer, scheduler, f"checkpoint{epoch:04}.pth", epoch, stats) 104 | 105 | 106 | if __name__ == "__main__": 107 | parser = argparse.ArgumentParser("DenseMap Head ") 108 | parser.add_argument("--config", default="configs/crowd_sense.json") 109 | parser.add_argument("--local_rank", type=int) 110 | args = parser.parse_args() 111 | 112 | if os.path.exists(args.config): 113 | with open(args.config, "r") as f: 114 | configs = json.load(f) 115 | cfg = edict(configs) 116 | print(cfg) 117 | 118 | strtime = time.strftime('%Y%m%d%H%M') + "_" + os.path.basename( 119 | args.config)[:-5] 120 | 121 | output_path = os.path.join(cfg.Misc.tensorboard_dir, strtime) 122 | if is_main_process(): 123 | if not os.path.exists(cfg.Misc.tensorboard_dir): 124 | os.mkdir(cfg.Misc.tensorboard_dir) 125 | if not os.path.exists(output_path): 126 | os.mkdir(output_path) 127 | # copy the config file to the output directory 128 | shutil.copy(args.config, output_path) 129 | 130 | cfg.Saver.save_dir = os.path.join(output_path, "checkpoints") 131 | cfg.Misc.tensorboard_dir = output_path 132 | 133 | main(cfg) 134 | -------------------------------------------------------------------------------- /tri_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional 2 | from torchvision.datasets import VisionDataset 3 | from PIL import Image 4 | import os 5 | from typing import Any, Callable, cast, Dict, List, Optional, Tuple 6 | import numpy as np 7 | import albumentations as A 8 | from albumentations.pytorch import ToTensorV2 9 | import torch 10 | import cv2 11 | def transform(): 12 | return A.Compose([ 13 | A.Resize(720,1280), 14 | A.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]), 15 | ToTensorV2(), 16 | ]) 17 | 18 | def video_transform(): 19 | return A.Compose([ 20 | A.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]), 21 | ToTensorV2(), 22 | ]) 23 | def inverse_normalize(img): 24 | img=img*torch.tensor([0.229, 0.224, 0.225]).view(3,1,1) 25 | img=img+torch.tensor([0.485, 0.456, 0.406]).view(3,1,1) 26 | img=img*255 27 | return img 28 | 29 | class VideoDataset(VisionDataset): 30 | def __init__(self, root: str, annotation_dir:str,transforms: Callable[..., Any] | None = None, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None) -> None: 31 | super().__init__(root, transforms, transform, target_transform) 32 | self.video_paths=os.listdir(root) 33 | self.video_paths=[os.path.join(root,video_path) for video_path in self.video_paths] 34 | self.annotations_paths=os.listdir(annotation_dir) 35 | 
self.annotations_paths=[os.path.join(annotation_dir,annotation) for annotation in self.annotations_paths] 36 | self.annotation={} 37 | self.videos=[] 38 | for annotation_path in self.annotations_paths: 39 | with open(annotation_path) as f: 40 | video_name=annotation_path.split('/')[-1].split('.')[0] 41 | 42 | self.annotation[video_name]=[] 43 | lines=f.readlines() 44 | for line in lines: 45 | line=line.split() 46 | file_name=line[0] 47 | width,height=int(line[1]),int(line[2]) 48 | 49 | data=[float(x) for x in line[3:] if x!=""] 50 | cnt=len(data)//7 51 | 52 | 53 | if len(data)>0: 54 | ids=-1*np.ones((cnt,1)) 55 | pts=-1*np.ones((cnt,2)) 56 | bboxes=-1*np.ones((cnt,4)) 57 | data=np.array(data) 58 | data=np.reshape(data,(-1,7)) 59 | ids=data[:,6].reshape(-1,1) 60 | pts=data[:,4:6] 61 | bboxes[:]=data[:,0:4] 62 | bboxes[:,2]=bboxes[:,2]-bboxes[:,0] 63 | bboxes[:,3]=bboxes[:,3]-bboxes[:,1] 64 | pts[:,0]=pts[:,0]/width 65 | pts[:,1]=pts[:,1]/height 66 | bboxes[:,0]=bboxes[:,0]/width 67 | bboxes[:,1]=bboxes[:,1]/height 68 | bboxes[:,2]=bboxes[:,2]/width 69 | bboxes[:,3]=bboxes[:,3]/height 70 | else: 71 | ids=-1*np.ones((1,1)) 72 | pts=-1*np.ones((1,2)) 73 | bboxes=-1*np.ones((1,4)) 74 | self.annotation[video_name].append({"file_name":file_name,"height":height,"width":width,"ids":ids,"pts":pts,"bboxes":bboxes,"cnt":cnt}) 75 | for video_path in self.video_paths: 76 | video_name=video_path.split('/')[-1] 77 | self.videos.append({ 78 | "video_name":video_name, 79 | "img_names":[], 80 | "height":self.annotation[video_name][0]["height"], 81 | "width":self.annotation[video_name][0]["width"], 82 | "ids":[], 83 | "pts":[], 84 | "bboxes":[], 85 | "cnt":[], 86 | }) 87 | for i in range(0,len(self.annotation[video_name])): 88 | self.videos[-1]["img_names"].append(self.annotation[video_name][i]["file_name"]) 89 | self.videos[-1]["ids"].append(self.annotation[video_name][i]["ids"]) 90 | self.videos[-1]["pts"].append(self.annotation[video_name][i]["pts"]) 91 | self.videos[-1]["bboxes"].append(self.annotation[video_name][i]["bboxes"]) 92 | self.videos[-1]["cnt"].append(self.annotation[video_name][i]["cnt"]) 93 | def __len__(self) -> int: 94 | return len(self.videos) 95 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 96 | video=self.videos[index] 97 | video_name=video["video_name"] 98 | img_names=video["img_names"] 99 | height=video["height"] 100 | width=video["width"] 101 | ids=video["ids"] 102 | pts=video["pts"] 103 | bboxes=video["bboxes"] 104 | cnt=video["cnt"] 105 | imgs=[] 106 | for img_name in img_names: 107 | img_path=os.path.join(self.root,video_name,img_name) 108 | img=self.transforms(image=np.array(Image.open(img_path).convert("RGB")))["image"] 109 | imgs.append(img) 110 | imgs=torch.stack(imgs,dim=0) 111 | labels={ 112 | "h":height, 113 | "w":width, 114 | "pts":pts, 115 | "bboxes":bboxes, 116 | "cnt":cnt, 117 | "video_name":video_name, 118 | "img_names":img_names, 119 | } 120 | return imgs,labels 121 | 122 | 123 | class PairDataset(VisionDataset): 124 | def __init__(self, root: str, annotation_dir:str, max_len:int ,transforms: Callable[..., Any] | None = None, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, train=True,step=20,interval=1,force_last=False) -> None: 125 | super().__init__(root, transforms, transform, target_transform) 126 | self.video_paths=os.listdir(root) 127 | self.video_paths=[os.path.join(root,video_path) for video_path in self.video_paths] 128 | self.annotations_paths=os.listdir(annotation_dir) 129 | 
self.annotations_paths=[os.path.join(annotation_dir,annotation) for annotation in self.annotations_paths] 130 | self.annotation={} 131 | self.pairs=[] 132 | self.max_len=max_len 133 | for annotation_path in self.annotations_paths: 134 | with open(annotation_path) as f: 135 | video_name=annotation_path.split('/')[-1].split('.')[0] 136 | 137 | self.annotation[video_name]=[] 138 | lines=f.readlines() 139 | for line in lines: 140 | line=line.split() 141 | file_name=line[0] 142 | width,height=int(line[1]),int(line[2]) 143 | 144 | data=[float(x) for x in line[3:] if x!=""] 145 | cnt=len(data)//7 146 | 147 | ids=-1*np.ones((max_len,1)) 148 | pts=-1*np.ones((max_len,2)) 149 | bboxes=-1*np.ones((max_len,4)) 150 | if len(data)>0: 151 | ids=-1*np.ones((cnt,1)) 152 | pts=-1*np.ones((cnt,2)) 153 | bboxes=-1*np.ones((cnt,4)) 154 | data=np.array(data) 155 | data=np.reshape(data,(-1,7)) 156 | ids=data[:,6].reshape(-1,1) 157 | pts=data[:,4:6] 158 | bboxes[:]=data[:,0:4] 159 | bboxes[:,2]=bboxes[:,2]-bboxes[:,0] 160 | bboxes[:,3]=bboxes[:,3]-bboxes[:,1] 161 | pts[:,0]=pts[:,0]/width 162 | pts[:,1]=pts[:,1]/height 163 | bboxes[:,0]=bboxes[:,0]/width 164 | bboxes[:,1]=bboxes[:,1]/height 165 | bboxes[:,2]=bboxes[:,2]/width 166 | bboxes[:,3]=bboxes[:,3]/height 167 | self.annotation[video_name].append({"file_name":file_name,"height":height,"width":width,"ids":ids,"pts":pts,"bboxes":bboxes,"cnt":cnt}) 168 | for video_path in self.video_paths: 169 | video_name=video_path.split('/')[-1] 170 | last_step=0 171 | for i in range(1,len(self.annotation[video_name])-step,interval): 172 | self.pairs.append({ 173 | "0":self.annotation[video_name][i], 174 | "1":self.annotation[video_name][i+step], 175 | "video_name":video_name, 176 | }) 177 | last_step=i+step 178 | if force_last and last_step int: 187 | return len(self.pairs) 188 | def add_noise(self,pts): 189 | noise=np.random.normal(scale=0.001,size=pts.shape) 190 | pts=pts+noise 191 | pts[pts>1]=1 192 | pts[pts<0]=0 193 | return pts 194 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 195 | pair=self.pairs[index] 196 | img0_path=os.path.join(self.root,pair["video_name"],pair["0"]["file_name"]) 197 | img1_path=os.path.join(self.root,pair["video_name"],pair["1"]["file_name"]) 198 | video_name=pair["video_name"] 199 | img_name1=pair["0"]["file_name"] 200 | img_name2=pair["1"]["file_name"] 201 | cnt_0=pair["0"]["cnt"] 202 | cnt_1=pair["1"]["cnt"] 203 | pt_0=pair["0"]["pts"] 204 | pt_1=pair["1"]["pts"] 205 | 206 | 207 | if self.train: 208 | pt_0=self.add_noise(pt_0) 209 | pt_1=self.add_noise(pt_1) 210 | bbox_0=pair["0"]["bboxes"] 211 | bbox_1=pair["1"]["bboxes"] 212 | id_0=pair["0"]["ids"] 213 | id_1=pair["1"]["ids"] 214 | if self.train and (pair["0"]["height"]!=pair["1"]["height"] or pair["0"]["width"]!=pair["1"]["width"] or cnt_0==0 or cnt_1==0): 215 | print("error") 216 | print(pair["0"]["height"],pair["1"]["height"],pair["0"]["width"],pair["1"]["width"],cnt_0,cnt_1) 217 | return self.__getitem__((index+1)%len(self)) 218 | img0=self.transforms(image=np.array(Image.open(img0_path).convert("RGB")))["image"] 219 | img1=self.transforms(image=np.array(Image.open(img1_path).convert("RGB")))["image"] 220 | fused_pts_list0=[] 221 | fused_pts_list1=[] 222 | id_list0=[] 223 | id_list1=[] 224 | fused_num=0 225 | fused_num1=0 226 | id0_pt0_dict={id[0]:pt for id,pt in zip(id_0,pt_0)} 227 | id1_pt1_dict={id[0]:pt for id,pt in zip(id_1,pt_1)} 228 | independ0_list=[] 229 | independ1_list=[] 230 | 231 | for pt,id in zip(pt_0,id_0): 232 | if id in id_1: 233 | 
fused_pts_list0.append(pt) 234 | id_list0.append(id) 235 | fused_num+=1 236 | pt1=id1_pt1_dict[id[0]] 237 | fused_pts_list1.append(pt1) 238 | id_list1.append(id) 239 | else: 240 | independ0_list.append(pt) 241 | for pt,id in zip(pt_1,id_1): 242 | if id not in id_0: 243 | independ1_list.append(pt) 244 | if len(independ0_list)==0: 245 | independ0_list.append((0.25,0.25)) 246 | if len(independ1_list)==0: 247 | independ1_list.append((0.75,0.75)) 248 | if self.train and (fused_num<2 or fused_num>30 or len(independ0_list)==0 or len(independ1_list)==0): 249 | return self.__getitem__((index+1)%len(self)) 250 | independ_pts0=-1*np.ones((self.max_len,2)) 251 | independ_pts0[:len(independ0_list)]=np.array(independ0_list) 252 | independ_pts1=-1*np.ones((self.max_len,2)) 253 | independ_pts1[:len(independ1_list)]=np.array(independ1_list) 254 | x=torch.cat([img0,img1],dim=0) 255 | fused_pts0=-1*np.ones((self.max_len,2)) 256 | fused_pts1=-1*np.ones((self.max_len,2)) 257 | 258 | if fused_num>0: 259 | fused_pts0[:fused_num]=np.array(fused_pts_list0) 260 | fused_pts1[:fused_num]=np.array(fused_pts_list1) 261 | 262 | fused_pts1=torch.from_numpy(fused_pts1).float() 263 | fused_pts0=torch.from_numpy(fused_pts0).float() 264 | # cv2_img0=inverse_normalize(img0).permute(1,2,0).detach().cpu().numpy().astype(np.uint8) 265 | # cv2_img1=inverse_normalize(img1).permute(1,2,0).detach().cpu().numpy().astype(np.uint8) 266 | # img_pair=np.concatenate([cv2_img0,cv2_img1],axis=1) 267 | # img_pair=cv2.cvtColor(img_pair,cv2.COLOR_RGB2BGR) 268 | # for pt0,pt1 in zip(fused_pts0,fused_pts1): 269 | # cv2.circle(img_pair,(int(pt0[0]*1280),int(pt0[1]*720)),5,(0,0,255),-1) 270 | # cv2.circle(img_pair,(int(pt1[0]*1280)+1280,int(pt1[1]*720)),5,(0,0,255),-1) 271 | # cv2.line(img_pair,(int(pt0[0]*1280),int(pt0[1]*720)),(int(pt1[0]*1280)+1280,int(pt1[1]*720)),(0,0,255),2) 272 | # cv2.imwrite(f"outputs/vision/{video_name}_{img_name1}_{img_name2}.jpg",img_pair) 273 | ref_pts=torch.stack([fused_pts0,fused_pts1],dim=0) 274 | labels={ 275 | "h":pair["0"]["height"], 276 | "w":pair["0"]["width"], 277 | "gt_default_pts":pt_0, 278 | "gt_duplicate_pts":pt_1, 279 | "gt_fuse_pts0":fused_pts0, 280 | "gt_fuse_pts1":fused_pts1, 281 | 282 | "gt_default_num":pair["0"]["cnt"], 283 | "gt_duplicate_num":pair["1"]["cnt"], 284 | "gt_fuse_num":fused_num, 285 | "video_name":video_name, 286 | "img_name1":img_name1, 287 | "img_name2":img_name2, 288 | } 289 | 290 | inputs={ 291 | "image_pair":x, 292 | "ref_pts":ref_pts, 293 | "ref_num":fused_num, 294 | "independ_pts0":torch.from_numpy(independ_pts0).float(), 295 | "independ_pts1":torch.from_numpy(independ_pts1).float(), 296 | "independ_num0":len(independ0_list), 297 | "independ_num1":len(independ1_list), 298 | } 299 | return inputs,labels 300 | 301 | def build_dataset(root,annotation_dir,max_len,train=False,step=20,interval=1,force_last=False): 302 | transforms=transform() 303 | dataset=PairDataset(root,annotation_dir,max_len,transforms=transforms,train=train,step=step,interval=interval,force_last=force_last) 304 | return dataset 305 | 306 | def build_video_dataset(root,annotation_dir): 307 | dataset=VideoDataset(root,annotation_dir,transforms=video_transform()) 308 | return dataset 309 | -------------------------------------------------------------------------------- /tri_eingine.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | from typing import Iterable 3 | 4 | import numpy as np 5 | import torch 6 | from scipy import spatial as ss 7 | 8 | from 
utils import SmoothedValue, get_total_grad_norm, reduce_dict 9 | from models.tri_sim_ot_b import similarity_cost 10 | import torch.nn.functional as F 11 | from scipy.optimize import linear_sum_assignment 12 | 13 | def hungarian(matrixTF): 14 | # boolean match matrix to adjacency list 15 | edges = np.argwhere(matrixTF) 16 | lnum, rnum = matrixTF.shape 17 | graph = [[] for _ in range(lnum)] 18 | for edge in edges: 19 | graph[edge[0]].append(edge[1]) 20 | 21 | # depth-first search for augmenting paths 22 | match = [-1 for _ in range(rnum)] 23 | vis = [-1 for _ in range(rnum)] 24 | 25 | def dfs(u): 26 | for v in graph[u]: 27 | if vis[v]: 28 | continue 29 | vis[v] = True 30 | if match[v] == -1 or dfs(match[v]): 31 | match[v] = u 32 | return True 33 | return False 34 | 35 | # try to match every left-hand vertex 36 | ans = 0 37 | for a in range(lnum): 38 | for i in range(rnum): 39 | vis[i] = False 40 | if dfs(a): 41 | ans += 1 42 | 43 | # assignment matrix 44 | assign = np.zeros((lnum, rnum), dtype=bool) 45 | for i, m in enumerate(match): 46 | if m >= 0: 47 | assign[m, i] = True 48 | 49 | return ans, assign 50 | 51 | 52 | def compute_metrics(pred_pts, pred_num, gt_pts, gt_num, sigma): 53 | if gt_num == 0: 54 | fp = len(pred_pts) 55 | fn = 0 56 | tp = 0 57 | elif len(pred_pts) == 0: 58 | fn = gt_num 59 | fp = 0 60 | tp = 0 61 | else: 62 | 63 | pred_pts = pred_pts.cpu().detach().numpy() 64 | gt_pts = gt_pts.cpu().detach().numpy() 65 | # print(pred_pts.shape, gt_pts.shape) 66 | dist_matrix = ss.distance_matrix(pred_pts, gt_pts, p=2) 67 | match_matrix = np.zeros(dist_matrix.shape, dtype=bool) 68 | for i_pred_p in range(pred_num): 69 | pred_dist = dist_matrix[i_pred_p, :] 70 | match_matrix[i_pred_p, :] = pred_dist <= sigma 71 | 72 | tp, assign = hungarian(match_matrix) 73 | fn_gt_index = np.array(np.where(assign.sum(0) == 0))[0] 74 | tp_pred_index = np.array(np.where(assign.sum(1) == 1))[0] 75 | tp_gt_index = np.array(np.where(assign.sum(0) == 1))[0] 76 | fp_pred_index = np.array(np.where(assign.sum(1) == 0))[0] 77 | 78 | tp = tp_pred_index.shape[0] 79 | fp = fp_pred_index.shape[0] 80 | fn = fn_gt_index.shape[0] 81 | return tp, fp, fn 82 | 83 | 84 | def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, metric_logger: object, scaler: torch.cuda.amp.GradScaler, epoch, args): 85 | model.train() 86 | 87 | metric_logger.meters.clear() 88 | 89 | header = 'Epoch: [{}]'.format(epoch) 90 | metric_logger.set_header(header) 91 | # for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 92 | for inputs, labels in metric_logger.log_every(data_loader): 93 | optimizer.zero_grad() 94 | for key in inputs.keys(): 95 | inputs[key] = inputs[key].to(args.gpu) 96 | 97 | # y1,y2 = model(inputs) 98 | # if args.distributed: 99 | # loss_dict = model.module.loss(y1,y2) 100 | # else: 101 | # loss_dict = model.loss(y1,y2) 102 | z1,z2,y1,y2 = model(inputs) 103 | if args.distributed: 104 | loss_dict = model.module.loss(z1,z2,y1,y2) 105 | else: 106 | loss_dict = model.loss(z1,z2,y1,y2) 107 | 108 | all_loss = loss_dict["all"] 109 | loss_dict_reduced = reduce_dict(loss_dict) 110 | all_loss_reduced = loss_dict_reduced["all"] 111 | loss_value = all_loss_reduced.item() 112 | 113 | scaler.scale(all_loss).backward() 114 | scaler.unscale_(optimizer) 115 | 116 | if args.Misc.clip_max_norm > 0: 117 | grad_total_norm = torch.nn.utils.clip_grad_norm_( 118 | model.parameters(), args.Misc.clip_max_norm) 119 | else: 120 | grad_total_norm = 
get_total_grad_norm(model.parameters(), 121 | args.Misc.clip_max_norm) 122 | scaler.step(optimizer) 123 | scaler.update() 124 | 125 | for k in loss_dict_reduced.keys(): 126 | metric_logger.update(**{k: loss_dict_reduced[k]}) 127 | metric_logger.update(loss=loss_value) 128 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 129 | metric_logger.update(grad_norm=grad_total_norm) 130 | # gather the stats from all processes 131 | metric_logger.synchronize_between_processes() 132 | print("Averaged stats:", metric_logger) 133 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 134 | 135 | @torch.no_grad() 136 | def evaluate_similarity(model,data_loader,metric_logger,epoch,args): 137 | model.eval() 138 | metric_logger.meters.clear() 139 | header="Test" 140 | metric_logger.set_header(header) 141 | cnt=0 142 | for inputs,labels in metric_logger.log_every(data_loader): 143 | cnt+=1 144 | for key in inputs.keys(): 145 | inputs[key] = inputs[key].to(args.gpu) 146 | # z1,z2=model(inputs) 147 | z1,z2,y1,y2=model(inputs) 148 | z1,z2,y1,y2=F.normalize(z1,dim=-1),F.normalize(z2,dim=-1),F.normalize(y1,dim=-1),F.normalize(y2,dim=-1) 149 | match_matrix=torch.bmm(torch.cat((z1,y1),dim=1),torch.cat((z2,y2),dim=1).transpose(1,2)) 150 | all_match=linear_sum_assignment(1-match_matrix.cpu().detach().numpy()[0]) 151 | 152 | pos_sim=1-similarity_cost(z1,z2).detach().cpu().numpy()[0] 153 | pos_match=np.argmax(pos_sim,axis=1) 154 | if cnt%100==0: 155 | print(cnt) 156 | print("pos_match") 157 | print(pos_match) 158 | print("all_match") 159 | print(all_match) 160 | pos_match_acc=pos_match==np.arange(pos_match.shape[0]) 161 | all_match_acc=all_match[1][:pos_match.shape[0]]==np.arange(pos_match.shape[0]) 162 | all_match_acc=all_match_acc.sum()/all_match_acc.shape[0] 163 | pos_match_acc=pos_match_acc.sum()/pos_match_acc.shape[0] 164 | metric_logger.update(pos_match_acc=pos_match_acc) 165 | metric_logger.update(all_match_acc=all_match_acc) 166 | 167 | metric_logger.synchronize_between_processes() 168 | print("Averaged stats:", metric_logger) 169 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 170 | @torch.no_grad() 171 | def evaluate_counting_and_locating(model, data_loader, metric_logger, epoch, args): 172 | model.eval() 173 | # criterion.eval() 174 | metric_logger.meters.clear() 175 | 176 | for prefix in ["default", "duplicate", "fuse"]: 177 | metric_logger.add_meter(f'{prefix}_mse', 178 | SmoothedValue(window_size=1, fmt='{value:.5f}')) 179 | metric_logger.add_meter(f'{prefix}_mae', 180 | SmoothedValue(window_size=1, fmt='{value:.1f}')) 181 | metric_logger.add_meter(f'{prefix}_tp', 182 | SmoothedValue(window_size=1, fmt='{value:.1f}')) 183 | metric_logger.add_meter(f'{prefix}_fp', 184 | SmoothedValue(window_size=1, fmt='{value:.1f}')) 185 | metric_logger.add_meter(f'{prefix}_fn', 186 | SmoothedValue(window_size=1, fmt='{value:.1f}')) 187 | metric_logger.add_meter(f'{prefix}_cnt', 188 | SmoothedValue(window_size=1, fmt='{value:.1f}')) 189 | header = "Test" 190 | metric_logger.set_header(header) 191 | sigma = 8 192 | 193 | for inputs, labels in metric_logger.log_every(data_loader): 194 | inputs = inputs.to(args.gpu) 195 | assert inputs.shape[0] == 1 196 | if args.distributed: 197 | out_dict = model.module.forward_points(inputs, threshold=0.9) 198 | else: 199 | out_dict = model.forward_points(inputs, threshold=0.9) 200 | metric_dict = { 201 | "cnt": torch.as_tensor(1., device=args.gpu), 202 | } 203 | for key in ["default_pts", "duplicate_pts", "fuse_pts"]: 204 | 
prefix = key.split("_")[0] 205 | gt_nums = labels[f"gt_{prefix}_num"].to(args.gpu).float() 206 | gt_pts = labels[f"gt_{prefix}_pts"][0,:gt_nums.long(),...].to(args.gpu).float() 207 | pred_pts = out_dict[key] 208 | mae = torch.abs(len(pred_pts)-gt_nums).data.mean() 209 | mse = ((len(pred_pts)-gt_nums)**2).data.mean() 210 | tp, fp, fn = compute_metrics( 211 | pred_pts, len(pred_pts), gt_pts, gt_nums, sigma) 212 | 213 | tp = torch.as_tensor(tp, device=args.gpu) 214 | fp = torch.as_tensor(fp, device=args.gpu) 215 | fn = torch.as_tensor(fn, device=args.gpu) 216 | metric_dict[f'{prefix}_mae'] = mae 217 | metric_dict[f'{prefix}_mse'] = mse 218 | metric_dict[f'{prefix}_tp'] = tp 219 | metric_dict[f'{prefix}_fp'] = fp 220 | metric_dict[f'{prefix}_fn'] = fn 221 | ######################################## 222 | loss_dict_reduced = reduce_dict(metric_dict, average=True) 223 | for k in loss_dict_reduced.keys(): 224 | metric_logger.update(**{k: loss_dict_reduced[k]}) 225 | 226 | metric_logger.synchronize_between_processes() 227 | print("Averaged stats:", metric_logger) 228 | stats = {k: meter.total for k, meter in metric_logger.meters.items()} 229 | print(metric_logger.meters["cnt"].total, metric_logger.meters["cnt"].count) 230 | for prefix in ["default", "duplicate", "fuse"]: 231 | tp = stats[f'{prefix}_tp'] 232 | fp = stats[f'{prefix}_fp'] 233 | fn = stats[f'{prefix}_fn'] 234 | ap = tp / (tp + fp + 1e-7) 235 | ar = tp / (tp + fn + 1e-7) 236 | f1 = 2 * ap * ar / (ap + ar + 1e-7) 237 | stats[f'{prefix}_ap'] = ap 238 | stats[f'{prefix}_ar'] = ar 239 | stats[f'{prefix}_f1'] = f1 240 | stats[f"{prefix}_mae"] = stats[f"{prefix}_mae"]/stats["cnt"] 241 | stats[f"{prefix}_mse"] = sqrt(stats[f"{prefix}_mse"]/stats["cnt"]) 242 | return stats 243 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import pickle 4 | import random 5 | import subprocess 6 | import time 7 | from collections import defaultdict, deque 8 | from typing import List, Optional 9 | 10 | import numpy as np 11 | import torch 12 | import torch.distributed as dist 13 | import torch.nn as nn 14 | # needed due to empty tensor bug in pytorch and torchvision 0.5 15 | import torchvision 16 | from torch import Tensor 17 | from torch.nn import functional as F 18 | # if float(torchvision.__version__[:3]) < 0.5: 19 | # import math 20 | 21 | # from torchvision.ops.misc import _NewEmptyTensorOp 22 | 23 | # def _check_size_scale_factor(dim, size, scale_factor): 24 | # # type: (int, Optional[List[int]], Optional[float]) -> None 25 | # if size is None and scale_factor is None: 26 | # raise ValueError("either size or scale_factor should be defined") 27 | # if size is not None and scale_factor is not None: 28 | # raise ValueError( 29 | # "only one of size or scale_factor should be defined") 30 | # if not (scale_factor is not None and len(scale_factor) != dim): 31 | # raise ValueError( 32 | # "scale_factor shape must match input shape. 
" 33 | # "Input is {}D, scale_factor size is {}".format( 34 | # dim, len(scale_factor)) 35 | # ) 36 | 37 | # def _output_size(dim, input, size, scale_factor): 38 | # # type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int] 39 | # assert dim == 2 40 | # _check_size_scale_factor(dim, size, scale_factor) 41 | # if size is not None: 42 | # return size 43 | # # if dim is not 2 or scale_factor is iterable use _ntuple instead of concat 44 | # assert scale_factor is not None and isinstance( 45 | # scale_factor, (int, float)) 46 | # scale_factors = [scale_factor, scale_factor] 47 | # # math.floor might return float in py2.7 48 | # return [ 49 | # int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim) 50 | # ] 51 | # elif float(torchvision.__version__[:3]) < 0.7: 52 | # from torchvision.ops import _new_empty_tensor 53 | # from torchvision.ops.misc import _output_size 54 | 55 | 56 | def set_randomseed(seed): 57 | torch.manual_seed(seed) 58 | np.random.seed(seed) 59 | random.seed(seed) 60 | 61 | 62 | class SmoothedValue(object): 63 | """Track a series of values and provide access to smoothed values over a 64 | window or the global series average. 65 | """ 66 | 67 | def __init__(self, window_size=20, fmt=None): 68 | if fmt is None: 69 | fmt = "{median:.4f} ({global_avg:.4f})" 70 | self.deque = deque(maxlen=window_size) 71 | self.total = 0.0 72 | self.count = 0 73 | self.fmt = fmt 74 | 75 | def update(self, value, n=1): 76 | self.deque.append(value) 77 | self.count += n 78 | self.total += value * n 79 | 80 | def synchronize_between_processes(self): 81 | """ 82 | Warning: does not synchronize the deque! 83 | """ 84 | if not is_dist_avail_and_initialized(): 85 | return 86 | t = torch.tensor([self.count, self.total], 87 | dtype=torch.float64, device='cuda') 88 | dist.barrier() 89 | dist.all_reduce(t) 90 | t = t.tolist() 91 | self.count = int(t[0]) 92 | self.total = t[1] 93 | 94 | @property 95 | def median(self): 96 | d = torch.tensor(list(self.deque)) 97 | return d.median().item() 98 | 99 | @property 100 | def avg(self): 101 | d = torch.tensor(list(self.deque), dtype=torch.float32) 102 | return d.mean().item() 103 | 104 | @property 105 | def global_avg(self): 106 | return self.total / self.count 107 | @property 108 | def global_total(self): 109 | return self.total 110 | @property 111 | def max(self): 112 | return max(self.deque) 113 | 114 | @property 115 | def value(self): 116 | return self.deque[-1] 117 | 118 | def __str__(self): 119 | return self.fmt.format( 120 | median=self.median, 121 | avg=self.avg, 122 | global_avg=self.global_avg, 123 | max=self.max, 124 | value=self.value) 125 | 126 | 127 | def all_gather(data): 128 | """ 129 | Run all_gather on arbitrary picklable data (not necessarily tensors) 130 | Args: 131 | data: any picklable object 132 | Returns: 133 | list[data]: list of data gathered from each rank 134 | """ 135 | world_size = get_world_size() 136 | if world_size == 1: 137 | return [data] 138 | 139 | # serialized to a Tensor 140 | buffer = pickle.dumps(data) 141 | storage = torch.ByteStorage.from_buffer(buffer) 142 | tensor = torch.ByteTensor(storage).to("cuda") 143 | 144 | # obtain Tensor size of each rank 145 | local_size = torch.tensor([tensor.numel()], device="cuda") 146 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 147 | dist.all_gather(size_list, local_size) 148 | size_list = [int(size.item()) for size in size_list] 149 | max_size = max(size_list) 150 | 151 | # receiving Tensor from all ranks 152 | # we pad the tensor 
because torch all_gather does not support 153 | # gathering tensors of different shapes 154 | tensor_list = [] 155 | for _ in size_list: 156 | tensor_list.append(torch.empty( 157 | (max_size,), dtype=torch.uint8, device="cuda")) 158 | if local_size != max_size: 159 | padding = torch.empty(size=(max_size - local_size,), 160 | dtype=torch.uint8, device="cuda") 161 | tensor = torch.cat((tensor, padding), dim=0) 162 | dist.all_gather(tensor_list, tensor) 163 | 164 | data_list = [] 165 | for size, tensor in zip(size_list, tensor_list): 166 | buffer = tensor.cpu().numpy().tobytes()[:size] 167 | data_list.append(pickle.loads(buffer)) 168 | 169 | return data_list 170 | 171 | 172 | def reduce_dict(input_dict, average=True): 173 | """ 174 | Args: 175 | input_dict (dict): all the values will be reduced 176 | average (bool): whether to do average or sum 177 | Reduce the values in the dictionary from all processes so that all processes 178 | have the averaged results. Returns a dict with the same fields as 179 | input_dict, after reduction. 180 | """ 181 | world_size = get_world_size() 182 | if world_size < 2: 183 | return input_dict 184 | with torch.no_grad(): 185 | names = [] 186 | values = [] 187 | # sort the keys so that they are consistent across processes 188 | for k in sorted(input_dict.keys()): 189 | names.append(k) 190 | values.append(input_dict[k]) 191 | values = torch.stack(values, dim=0) 192 | dist.all_reduce(values) 193 | if average: 194 | values /= world_size 195 | reduced_dict = {k: v for k, v in zip(names, values)} 196 | return reduced_dict 197 | 198 | 199 | class MetricLogger(object): 200 | def __init__(self,args ): 201 | self.meters = defaultdict(SmoothedValue) 202 | self.delimiter = args.delimiter 203 | self.print_freq = args.print_freq 204 | self.header=args.header 205 | 206 | def update(self, **kwargs): 207 | for k, v in kwargs.items(): 208 | if isinstance(v, torch.Tensor): 209 | v = v.item() 210 | assert isinstance(v, (float, int)) 211 | self.meters[k].update(v) 212 | 213 | def __getattr__(self, attr): 214 | if attr in self.meters: 215 | return self.meters[attr] 216 | if attr in self.__dict__: 217 | return self.__dict__[attr] 218 | raise AttributeError("'{}' object has no attribute '{}'".format( 219 | type(self).__name__, attr)) 220 | 221 | def __str__(self): 222 | loss_str = [] 223 | for name, meter in self.meters.items(): 224 | loss_str.append( 225 | "{}: {}".format(name, str(meter)) 226 | ) 227 | return self.delimiter.join(loss_str) 228 | 229 | def synchronize_between_processes(self): 230 | for meter in self.meters.values(): 231 | meter.synchronize_between_processes() 232 | 233 | def add_meter(self, name, meter): 234 | self.meters[name] = meter 235 | def set_header(self,header): 236 | self.header=header 237 | def log_every(self, iterable): 238 | i = 0 239 | 240 | start_time = time.time() 241 | end = time.time() 242 | iter_time = SmoothedValue(fmt='{avg:.4f}') 243 | data_time = SmoothedValue(fmt='{avg:.4f}') 244 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 245 | if torch.cuda.is_available(): 246 | log_msg = self.delimiter.join([ 247 | self.header, 248 | '[{0' + space_fmt + '}/{1}]', 249 | 'eta: {eta}', 250 | '{meters}', 251 | 'time: {time}', 252 | 'data: {data}', 253 | 'max mem: {memory:.0f}' 254 | ]) 255 | else: 256 | log_msg = self.delimiter.join([ 257 | self.header, 258 | '[{0' + space_fmt + '}/{1}]', 259 | 'eta: {eta}', 260 | '{meters}', 261 | 'time: {time}', 262 | 'data: {data}' 263 | ]) 264 | MB = 1024.0 * 1024.0 265 | for obj in iterable: 266 | 
data_time.update(time.time() - end) 267 | yield obj 268 | iter_time.update(time.time() - end) 269 | if i % self.print_freq == 0 or i == len(iterable) - 1: 270 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 271 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 272 | if torch.cuda.is_available(): 273 | print(log_msg.format( 274 | i, len(iterable), eta=eta_string, 275 | meters=str(self), 276 | time=str(iter_time), data=str(data_time), 277 | memory=torch.cuda.max_memory_allocated() / MB)) 278 | else: 279 | print(log_msg.format( 280 | i, len(iterable), eta=eta_string, 281 | meters=str(self), 282 | time=str(iter_time), data=str(data_time))) 283 | i += 1 284 | end = time.time() 285 | total_time = time.time() - start_time 286 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 287 | print('{} Total time: {} ({:.4f} s / it)'.format( 288 | self.header, total_time_str, total_time / len(iterable))) 289 | 290 | 291 | def get_sha(): 292 | cwd = os.path.dirname(os.path.abspath(__file__)) 293 | 294 | def _run(command): 295 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 296 | sha = 'N/A' 297 | diff = "clean" 298 | branch = 'N/A' 299 | try: 300 | sha = _run(['git', 'rev-parse', 'HEAD']) 301 | subprocess.check_output(['git', 'diff'], cwd=cwd) 302 | diff = _run(['git', 'diff-index', 'HEAD']) 303 | diff = "has uncommited changes" if diff else "clean" 304 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 305 | except Exception: 306 | pass 307 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 308 | return message 309 | 310 | 311 | def collate_fn(batch): 312 | batch = list(zip(*batch)) 313 | batch[0] = nested_tensor_from_tensor_list(batch[0]) 314 | return tuple(batch) 315 | 316 | 317 | def _max_by_axis(the_list): 318 | # type: (List[List[int]]) -> List[int] 319 | maxes = the_list[0] 320 | for sublist in the_list[1:]: 321 | for index, item in enumerate(sublist): 322 | maxes[index] = max(maxes[index], item) 323 | return maxes 324 | 325 | 326 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 327 | # TODO make this more general 328 | if tensor_list[0].ndim == 3: 329 | # TODO make it support different-sized images 330 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 331 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 332 | batch_shape = [len(tensor_list)] + max_size 333 | b, c, h, w = batch_shape 334 | dtype = tensor_list[0].dtype 335 | device = tensor_list[0].device 336 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 337 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 338 | for img, pad_img, m in zip(tensor_list, tensor, mask): 339 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 340 | m[: img.shape[1], :img.shape[2]] = False 341 | else: 342 | raise ValueError('not supported') 343 | return NestedTensor(tensor, mask) 344 | 345 | 346 | class NestedTensor(object): 347 | def __init__(self, tensors, mask: Optional[Tensor]): 348 | self.tensors = tensors 349 | self.mask = mask 350 | 351 | def to(self, device, non_blocking=False): 352 | ## type: (Device) -> NestedTensor # noqa 353 | cast_tensor = self.tensors.to(device, non_blocking=non_blocking) 354 | mask = self.mask 355 | if mask is not None: 356 | assert mask is not None 357 | cast_mask = mask.to(device, non_blocking=non_blocking) 358 | else: 359 | cast_mask = None 360 | return NestedTensor(cast_tensor, cast_mask) 361 | 362 | def record_stream(self, *args, 
**kwargs): 363 | self.tensors.record_stream(*args, **kwargs) 364 | if self.mask is not None: 365 | self.mask.record_stream(*args, **kwargs) 366 | 367 | def decompose(self): 368 | return self.tensors, self.mask 369 | 370 | def __repr__(self): 371 | return str(self.tensors) 372 | 373 | 374 | def setup_for_distributed(is_master): 375 | """ 376 | This function disables printing when not in master process 377 | """ 378 | import builtins as __builtin__ 379 | builtin_print = __builtin__.print 380 | 381 | def print(*args, **kwargs): 382 | force = kwargs.pop('force', False) 383 | if is_master or force: 384 | builtin_print(*args, **kwargs) 385 | 386 | __builtin__.print = print 387 | 388 | 389 | def is_dist_avail_and_initialized(): 390 | if not dist.is_available(): 391 | return False 392 | if not dist.is_initialized(): 393 | return False 394 | return True 395 | 396 | 397 | def get_world_size(): 398 | if not is_dist_avail_and_initialized(): 399 | return 1 400 | return dist.get_world_size() 401 | 402 | 403 | def get_rank(): 404 | if not is_dist_avail_and_initialized(): 405 | return 0 406 | return dist.get_rank() 407 | 408 | 409 | def get_local_size(): 410 | if not is_dist_avail_and_initialized(): 411 | return 1 412 | return int(os.environ['LOCAL_SIZE']) 413 | 414 | 415 | def get_local_rank(): 416 | if not is_dist_avail_and_initialized(): 417 | return 0 418 | return int(os.environ['LOCAL_RANK']) 419 | 420 | 421 | def is_main_process(): 422 | return get_rank() == 0 423 | 424 | 425 | def save_on_master(*args, **kwargs): 426 | if is_main_process(): 427 | torch.save(*args, **kwargs) 428 | 429 | 430 | def init_distributedrun_mode(args): 431 | if "LOCAL_RANK" in os.environ: 432 | args.local_rank = int(os.environ["LOCAL_RANK"]) 433 | args.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) 434 | 435 | 436 | def init_distributed_mode(args): 437 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 438 | args.rank = int(os.environ["RANK"]) 439 | args.world_size = int(os.environ['WORLD_SIZE']) 440 | args.gpu = int(os.environ['LOCAL_RANK']) 441 | args.dist_url = 'env://' 442 | os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count()) 443 | elif 'SLURM_PROCID' in os.environ: 444 | proc_id = int(os.environ['SLURM_PROCID']) 445 | ntasks = int(os.environ['SLURM_NTASKS']) 446 | node_list = os.environ['SLURM_NODELIST'] 447 | num_gpus = torch.cuda.device_count() 448 | addr = subprocess.getoutput( 449 | 'scontrol show hostname {} | head -n1'.format(node_list)) 450 | os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500') 451 | os.environ['MASTER_ADDR'] = addr 452 | os.environ['WORLD_SIZE'] = str(ntasks) 453 | os.environ['RANK'] = str(proc_id) 454 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 455 | os.environ['LOCAL_SIZE'] = str(num_gpus) 456 | args.dist_url = 'env://' 457 | args.world_size = ntasks 458 | args.rank = proc_id 459 | args.gpu = proc_id % num_gpus 460 | else: 461 | print('Not using distributed mode') 462 | args.distributed = False 463 | args.gpu=0 464 | return 465 | 466 | args.distributed = True 467 | 468 | torch.cuda.set_device(args.gpu) 469 | args.dist_backend = 'nccl' 470 | print('| distributed init (rank {}): {}'.format( 471 | args.rank, args.dist_url), flush=True) 472 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 473 | world_size=args.world_size, rank=args.rank) 474 | torch.distributed.barrier() 475 | setup_for_distributed(args.rank == 0) 476 | 477 | 478 | @torch.no_grad() 479 | def accuracy(output, target, topk=(1,)): 480 | 
"""Computes the precision@k for the specified values of k""" 481 | if target.numel() == 0: 482 | return [torch.zeros([], device=output.device)] 483 | maxk = max(topk) 484 | batch_size = target.size(0) 485 | 486 | _, pred = output.topk(maxk, 1, True, True) 487 | pred = pred.t() 488 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 489 | 490 | res = [] 491 | for k in topk: 492 | correct_k = correct[:k].view(-1).float().sum(0) 493 | res.append(correct_k.mul_(100.0 / batch_size)) 494 | return res 495 | 496 | 497 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): 498 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor 499 | """ 500 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes. 501 | This will eventually be supported natively by PyTorch, and this 502 | class can go away. 503 | """ 504 | # if float(torchvision.__version__[:3]) < 0.7: 505 | # if input.numel() > 0: 506 | # return torch.nn.functional.interpolate( 507 | # input, size, scale_factor, mode, align_corners 508 | # ) 509 | 510 | # output_shape = _output_size(2, input, size, scale_factor) 511 | # output_shape = list(input.shape[:-2]) + list(output_shape) 512 | # if float(torchvision.__version__[:3]) < 0.5: 513 | # return _NewEmptyTensorOp.apply(input, output_shape) 514 | # return _new_empty_tensor(input, output_shape) 515 | # else: 516 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) 517 | 518 | 519 | def get_total_grad_norm(parameters, norm_type=2): 520 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 521 | norm_type = float(norm_type) 522 | device = parameters[0].grad.device 523 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), 524 | norm_type) 525 | return total_norm 526 | 527 | 528 | def inverse_sigmoid(x, eps=1e-5): 529 | x = x.clamp(min=0, max=1) 530 | x1 = x.clamp(min=eps) 531 | x2 = (1 - x).clamp(min=eps) 532 | return torch.log(x1/x2) 533 | 534 | 535 | def predict_map2coord(output, threshold=0.5): 536 | pred_dmap_low = F.relu(output) 537 | pred_dmap_low = F.avg_pool2d(pred_dmap_low, 2) * 4 538 | 539 | # pred_dmap = F.avg_pool2d(output, 2) * 4 540 | # draw_txt(pred_dmap, os.path.join(save_txt_dir, f"{img_id_list[idx]}_{h}_{w}_{h1}_{w1}.txt")) 541 | 542 | threshold = 0.5 543 | H, W = pred_dmap_low.shape[-2], pred_dmap_low.shape[-1] 544 | pred_map_unflod = F.unfold(pred_dmap_low, kernel_size=3, padding=1) 545 | pred_max = pred_map_unflod.max(dim=1, keepdim=True)[1] 546 | pred_max = (pred_max == 3 ** 2 // 2).reshape(1, 1, H, W) 547 | # pred_max = (pred_max == 3 ** 2 // 2).reshape(1, 1, H, W).long() 548 | # draw_map(pred_max.float(), os.path.join(save_dir, f"{idx}_pred_max.png")) 549 | 550 | kernel3x3 = torch.ones(1, 1, 3, 3).float().cuda() 551 | pred_cnt = F.conv2d(pred_dmap_low, weight=kernel3x3, padding=1) 552 | pred_filter = (pred_cnt > threshold) 553 | # pred_filter = (pred_cnt > threshold).long() 554 | # draw_map(pred_filter.float(), os.path.join(save_dir, f"{idx}_pred_filter.png")) 555 | 556 | pred_filter = pred_filter * pred_max 557 | # draw_map(pred_filter.float(), os.path.join(save_dir, f"{idx}_pred_filter_after.png")) 558 | # pred_filter = pred_filter.long().detach().cpu().squeeze(0).squeeze(0) 559 | pred_filter = pred_filter.detach().cpu().squeeze(0).squeeze(0) 560 | 561 | pred_coord_weight = F.normalize(pred_map_unflod, p=1, dim=1) 562 | coord_h = torch.arange(H).reshape(1, 1, H, 
1).repeat(1, 1, 1, W).float().cuda() 563 | coord_w = torch.arange(W).reshape(1, 1, 1, W).repeat(1, 1, H, 1).float().cuda() 564 | coord_h_unfold = F.unfold(coord_h, kernel_size=3, padding=1) 565 | coord_h_pred = (coord_h_unfold * pred_coord_weight).sum(dim=1, keepdim=True).reshape(H, W).detach().cpu() 566 | coord_w_unfold = F.unfold(coord_w, kernel_size=3, padding=1) 567 | coord_w_pred = (coord_w_unfold * pred_coord_weight).sum(dim=1, keepdim=True).reshape(H, W).detach().cpu() 568 | 569 | # print("filter", pred_filter.shape) 570 | coord_h = coord_h_pred[pred_filter].unsqueeze(1) 571 | coord_w = coord_w_pred[pred_filter].unsqueeze(1) 572 | coord = torch.cat((coord_h, coord_w), dim=1).numpy() 573 | return coord 574 | --------------------------------------------------------------------------------
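A minimal smoke test for the point-matching metrics can be useful when modifying hungarian or compute_metrics in tri_eingine.py. This is only a sketch: it assumes the repository root is on PYTHONPATH (so tri_eingine, utils, and models.tri_sim_ot_b import cleanly) and that torch and scipy are installed; the synthetic points and the expected counts in the comments are illustrative values, not outputs from the SenseCrowd data.

import torch

from tri_eingine import compute_metrics

# Three predicted head locations; two of them lie within sigma = 8 pixels of a ground-truth point.
pred_pts = torch.tensor([[10.0, 10.0], [50.0, 50.0], [200.0, 200.0]])
gt_pts = torch.tensor([[12.0, 11.0], [48.0, 52.0]])

tp, fp, fn = compute_metrics(pred_pts, len(pred_pts), gt_pts, len(gt_pts), sigma=8)

# Derive precision/recall/F1 with the same 1e-7 epsilon used in evaluate_counting_and_locating.
precision = tp / (tp + fp + 1e-7)
recall = tp / (tp + fn + 1e-7)
f1 = 2 * precision * recall / (precision + recall + 1e-7)
print(tp, fp, fn)             # expected: 2 1 0
print(precision, recall, f1)  # roughly 0.67 1.0 0.8

Because hungarian performs a maximum bipartite matching on the thresholded boolean matrix rather than on raw distances, each ground-truth point can absorb at most one prediction, so duplicate detections near the same person are counted as false positives instead of extra true positives.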