├── LICENSE ├── README.md ├── data ├── VCDB │ ├── label_file.json │ └── pair_file.csv └── VCSL │ ├── label_file.json │ └── pair_file.csv ├── evaluation.py ├── exps └── exp.py ├── fig.png ├── metric └── eval.py ├── requirements.txt ├── run.py ├── scripts ├── eval_TransVCL.sh └── test_TransVCL.sh └── transvcl ├── models ├── __init__.py ├── darknet.py ├── linear_attention.py ├── network_blocks.py ├── sim_gen.py ├── transvcl_model.py ├── yolo_head.py └── yolo_pafpn.py ├── utils ├── __init__.py ├── boxes.py └── dist.py └── weights └── pretrained_models.txt /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 transvcl 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TransVCL: Attention-enhanced Video Copy Localization Network with Flexible Supervision [AAAI2023 Oral] 2 | 3 | ## Introduction 4 | TransVCL is a novel network with joint optimization of multiple components for segment-level video copy detection. It achieves state-of-the-art performance on the video copy segment localization benchmark and can also be flexibly extended to semi-supervised settings. This paper was accepted by AAAI 2023. The details of TransVCL are described in the [arXiv paper](https://arxiv.org/abs/2211.13090). 5 | ![vcsl](./fig.png) 6 | 7 | ## Preparations 8 | * Download or extract frame-level video features and put them under the directory `data/${DATASET}/features/` (see the sanity-check sketch after this list). 9 | Features of the VCSL dataset are provided by the [VCSL benchmark](https://github.com/alipay/VCSL), and features of the VCDB dataset need 10 | to be extracted following the [ISC competition](https://github.com/lyakaap/ISC21-Descriptor-Track-1st). 11 | * Download our pretrained models listed in `transvcl/weights/pretrained_models.txt`. We provide two models (model_1.pth and model_2.pth): model_1.pth is trained in the fully supervised setting on the VCSL 12 | dataset and reproduces the results in Table 1; model_2.pth is trained in the weakly semi-supervised setting on 13 | VCSL plus FIVR & SVD and reproduces the results in Table 5. 14 | * Install the Python requirements in `requirements.txt`. 
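A minimal sanity-check sketch (not part of this repository) for the prepared data is shown below. It assumes the `query_id`/`reference_id` columns of `pair_file.csv` that `run.py` and `evaluation.py` read, and the 256-dimensional frame features configured in `exps/exp.py`; adjust the paths and dimension to your own setup.

```python
# check_features.py -- hypothetical helper, not shipped with TransVCL
import numpy as np
import pandas as pd
from pathlib import Path

DATASET = "VCSL"  # or "VCDB"
feat_dir = Path(f"data/{DATASET}/features")
pairs = pd.read_csv(f"data/{DATASET}/pair_file.csv")

# every query/reference id in the pair file should have a frame-level feature file
video_ids = {str(v).split(".")[0] for v in set(pairs.query_id) | set(pairs.reference_id)}
missing = []
for vid in sorted(video_ids):
    path = feat_dir / f"{vid}.npy"
    if not path.exists():
        missing.append(path)
        continue
    feat = np.load(path)
    # expected layout: (num_frames, feat_dim), feat_dim == 256 by default in exps/exp.py
    assert feat.ndim == 2 and feat.shape[1] == 256, f"{path}: unexpected shape {feat.shape}"

print(f"checked {len(video_ids)} videos, {len(missing)} feature files missing")
```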
15 | 16 | ## Run and Evaluation 17 | Run the TransVCL network on the given models and datasets as: 18 | ```bash 19 | bash scripts/test_TransVCL.sh 20 | ``` 21 | You will obtain a result JSON file with the copied segments' temporal boundaries and their confidence scores. 22 | 23 | Then run the evaluation script on the above prediction file as: 24 | ```bash 25 | bash scripts/eval_TransVCL.sh 26 | ``` 27 | You should see the output performance. In the case of VCSL and model_1, the result is: 28 | ``` 29 | - start loading... 30 | - result file: results/model/VCSL/result.json, data cnt: 55530, macro-Recall: 65.59%, macro-Precision: 67.46%, F1: 66.51% 31 | ``` 32 | 33 | ## Benchmark 34 | After executing the above steps, the overall segment-level precision/recall performance of TransVCL on the [VCSL benchmark](https://github.com/alipay/VCSL#benchmark) is shown below: 35 | 36 | | Performance | Recall | Precision | Fscore | 37 | |:-------------|:------:|:---------:|:---------:| 38 | | HV | 86.94 | 36.83 | 51.73 | 39 | | TN | 75.25 | 51.80 | 61.36 | 40 | | DP | 49.48 | 60.61 | 54.48 | 41 | | DTW | 45.10 | 56.67 | 50.23 | 42 | | SPD | 56.49 | 68.60 | 61.96 | 43 | | **TransVCL** | 65.59 | 67.46 | **66.51** | 44 | 45 | 46 | ## Acknowledgements 47 | We referenced the repositories below for the code: 48 | - [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX). 49 | - [LoFTR](https://github.com/zju3dv/LoFTR). 50 | - [VCSL](https://github.com/alipay/VCSL). 51 | 52 | Thanks for their wonderful work. 53 | 54 | ## Cite TransVCL 55 | If the code is helpful for your work, please cite our papers: 56 | ``` 57 | @inproceedings{he2023transvcl, 58 | title={TransVCL: Attention-enhanced Video Copy Localization Network with Flexible Supervision}, 59 | author={He, Sifeng and He, Yue and Lu, Minlong and others}, 60 | booktitle={37th AAAI Conference on Artificial Intelligence: AAAI 2023}, 61 | year={2023} 62 | } 63 | 64 | @inproceedings{he2022large, 65 | title={A Large-scale Comprehensive Dataset and Copy-overlap Aware Evaluation Protocol for Segment-level Video Copy Detection}, 66 | author={He, Sifeng and Yang, Xudong and Jiang, Chen and others}, 67 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 68 | pages={21086--21095}, 69 | year={2022} 70 | } 71 | ``` 72 | 73 | 74 | ## License 75 | The code is released under the MIT license. 76 | 77 | ```bash 78 | MIT License 79 | 80 | Copyright (c) 2023 Ant Group 81 | 82 | Permission is hereby granted, free of charge, to any person obtaining a copy 83 | of this software and associated documentation files (the "Software"), to deal 84 | in the Software without restriction, including without limitation the rights 85 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 86 | copies of the Software, and to permit persons to whom the Software is 87 | furnished to do so, subject to the following conditions: 88 | 89 | The above copyright notice and this permission notice shall be included in all 90 | copies or substantial portions of the Software. 91 | 92 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 93 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 94 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 95 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 96 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 97 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 98 | SOFTWARE. 
99 | ``` 100 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | import os 4 | import json 5 | import numpy as np 6 | 7 | from metric.eval import * 8 | from multiprocessing import Pool 9 | from loguru import logger 10 | 11 | 12 | def process_conf(eval): 13 | key, gt, pred_box = eval["name"], eval['gt'], np.array(eval['pred']) 14 | try: 15 | return {"name": key, "gt": gt, "pred": pred_box[pred_box[:, -1] > conf][:, :4].astype(int).tolist()} 16 | except: 17 | return {"name": key, "gt": gt, "pred": []} 18 | 19 | 20 | def eval(input_dict): 21 | gt_box = np.array(input_dict["gt"]) 22 | pred_box = np.array(input_dict["pred"]) 23 | result_dict = precision_recall(pred_box, gt_box) 24 | result_dict["name"] = input_dict["name"] 25 | return result_dict 26 | 27 | 28 | if __name__ == '__main__': 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--anno-file", type=str, default=None, help="gt label file") 31 | parser.add_argument("--pred-file", type=str, default=None, help="result dir of video segment prediction") 32 | parser.add_argument("--test-file", type=str, default=None, help="test pair list of query and reference videos") 33 | parser.add_argument("--pool-size", type=int, default=16, help="multiprocess pool size of evaluation") 34 | parser.add_argument("--conf", type=float, default=0.1, help="input with conf") 35 | 36 | args = parser.parse_args() 37 | 38 | logger.info(f"start loading...") 39 | 40 | df = pd.read_csv(args.test_file) 41 | split_pairs = set([f"{q}-{r}" for q, r in zip(df.query_id.values, df.reference_id.values)]) 42 | 43 | gt = json.load(open(args.anno_file)) 44 | key_list = [key for key in gt] 45 | 46 | process_pool = Pool(args.pool_size) 47 | 48 | pred_dict = json.load(open(args.pred_file)) 49 | eval_list = [] 50 | for key in split_pairs: 51 | if key in gt: 52 | if key in pred_dict: 53 | eval_list += [{"name": key, "gt": gt[key], "pred": pred_dict[key]}] 54 | else: 55 | eval_list += [{"name": key, "gt": gt[key], "pred": []}] 56 | else: 57 | if key in pred_dict: 58 | eval_list += [{"name": key, "gt": [], "pred": pred_dict[key]}] 59 | else: 60 | eval_list += [{"name": key, "gt": [], "pred": []}] 61 | 62 | if args.conf != 0: 63 | conf = args.conf 64 | process_pool = Pool(args.pool_size) 65 | eval_list = process_pool.map(process_conf, eval_list) 66 | 67 | process_pool = Pool(args.pool_size) 68 | result_list = process_pool.map(eval, eval_list) 69 | result_dict = {i['name']: i for i in result_list} 70 | r, p = evaluate_overall(result_dict) 71 | 72 | logger.info(f"result file: {args.pred_file}, " 73 | f"data cnt: {len(result_list)}, " 74 | f"macro-Recall: {r:.2%}, " 75 | f"macro-Precision: {p:.2%}, " 76 | f"F1: {2 * r * p / (r + p + 1e-6):.2%}") 77 | 78 | 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /exps/exp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | import torch.nn as nn 4 | import json 5 | from transvcl.utils import wait_for_the_master, get_local_rank 6 | from transvcl.models import TransVCL, YOLOPAFPN, YOLOXHead 7 | 8 | class Exp(object): 9 | def __init__(self): 10 | super(Exp, self).__init__() 11 | self.depth = 0.33 12 | self.width = 0.50 13 | self.act = "silu" 14 | self.num_classes = 1 15 | self.eval_interval = 1 16 | 17 | 
self.default_feat_length = 1200 18 | self.default_feat_dim = 256 19 | 20 | self.vta_config = { 21 | 'd_model': 256, 22 | 'nhead': 8, 23 | 'layer_names': ['self', 'cross'] * 1, 24 | 'attention': 'linear', 25 | 'match_type': 'dual_softmax', 26 | 'dsmax_temperature': 0.1, 27 | 'keep_ratio': False, 28 | 'unsupervised_weight': 0.5 29 | } 30 | 31 | def get_model(self): 32 | def init_transvcl(M): 33 | for m in M.modules(): 34 | if isinstance(m, nn.BatchNorm2d): 35 | m.eps = 1e-3 36 | m.momentum = 0.03 37 | 38 | if getattr(self, "model", None) is None: 39 | in_channels = [256, 512, 1024] 40 | backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels, act=self.act) 41 | head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels, act=self.act) 42 | self.model = TransVCL(self.vta_config, backbone, head) 43 | 44 | self.model.apply(init_transvcl) 45 | self.model.head.initialize_biases(1e-2) 46 | return self.model 47 | 48 | 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transvcl/TransVCL/4b57976da5ad1b38f999c96aba315e63a272acdf/fig.png -------------------------------------------------------------------------------- /metric/eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import numpy as np 5 | from typing import Dict, Any, Tuple 6 | 7 | 8 | def seg_len(segments: np.ndarray, type: str = 'union') -> float: 9 | """ 10 | get accumulated length of all line segments 11 | union: the intersection area is calculated only once 12 | sum: the intersection area is calculated several times 13 | Parameters 14 | ---------- 15 | segments : shape (N, 2) 16 | each row is a segment with (start, end) 17 | 18 | Returns 19 | ------- 20 | len : float 21 | total length of the union set of the segments 22 | """ 23 | 24 | if type != 'union': 25 | return np.sum(segments[:, 1] - segments[:, 0]).item() 26 | 27 | segments_to_sum = [] 28 | # sort by start coord 29 | segments = sorted(segments.tolist(), key=lambda x: x[0]) 30 | for segment in segments: 31 | if len(segments_to_sum) == 0: 32 | segments_to_sum.append(segment) 33 | continue 34 | 35 | last_segment = segments_to_sum[-1] 36 | # if no overlap, append then merge 37 | if last_segment[1] < segment[0]: 38 | segments_to_sum.append(segment) 39 | else: 40 | union_segment = [min(last_segment[0], segment[0]), max(last_segment[1], segment[1])] 41 | segments_to_sum[-1] = union_segment 42 | 43 | segments_to_sum = np.array(segments_to_sum, dtype=np.float32) 44 | return np.sum(segments_to_sum[:, 1] - segments_to_sum[:, 0]).item() 45 | 46 | 47 | def calc_inter(pred_boxes: np.ndarray, gt_boxes: np.ndarray) -> (np.ndarray, np.ndarray): 48 | """ 49 | Calculate intersection boxes and areas of each pred and gt box 50 | Parameters 51 | ---------- 52 | pred_boxes : shape (N, 4) 53 | gt_boxes : shape (M, 4) 54 | box format top-left and bottom-right coords (x1, y1, x2, y2) 55 | 56 | Returns 57 | ------- 58 | inter_boxes : numpy.ndarray, shape (N, M, 4) 59 | intersection boxes of each pred and gt box 60 | inter_areas : numpy.ndarray, shape (N, M) 61 | intersection areas of each pred and gt box 62 | """ 63 | lt = np.maximum(pred_boxes[:, None, :2], gt_boxes[:, :2]) 64 | rb = np.minimum(pred_boxes[:, None, 2:], gt_boxes[:, 2:]) 65 | wh = np.maximum(rb - lt, 0) 66 | inter_boxes = np.concatenate((lt, rb), axis=2) 67 | 
inter_areas = wh[:, :, 0] * wh[:, :, 1] 68 | return inter_boxes, inter_areas 69 | 70 | 71 | def precision_recall(pred_boxes: np.ndarray, gt_boxes: np.ndarray): 72 | """ 73 | Segment level Precision/Recall evaluation for one video pair vta result 74 | pred_boxes shape(N, 4) indicates N predicted copied segments 75 | gt_boxes shape(M, 4) indicates M ground-truth labelled copied segments 76 | Parameters 77 | ---------- 78 | pred_boxes : shape (N, 4) 79 | gt_boxes : shape (M, 4) 80 | 81 | Returns 82 | ------- 83 | precision : float 84 | recall : float 85 | """ 86 | 87 | # abnormal assigned values for denominator = 0 88 | if len(pred_boxes) > 0 and len(gt_boxes) == 0: 89 | return {"precision": 0, "recall": 1} 90 | 91 | if len(pred_boxes) == 0 and len(gt_boxes) > 0: 92 | return {"precision": 1, "recall": 0} 93 | 94 | if len(pred_boxes) == 0 and len(gt_boxes) == 0: 95 | return {"precision": 1, "recall": 1} 96 | 97 | # intersection area calculation 98 | inter_boxes, inter_areas = calc_inter(pred_boxes, gt_boxes) 99 | 100 | sum_tp_w, sum_p_w, sum_tp_h, sum_p_h = 0, 0, 0, 0 101 | for pred_ind, inter_per_pred in enumerate(inter_areas): 102 | # for each pred-box, find the gt-boxes whose iou is > 0 with it 103 | pos_gt_inds = np.where(inter_per_pred > 0) 104 | if len(pos_gt_inds[0]) > 0: 105 | 106 | # union of all pred box along each side 107 | # tp: true positive 108 | sum_tp_w += seg_len(np.squeeze(inter_boxes[pred_ind, pos_gt_inds, :][:, :, [0, 2]], axis=0)) 109 | 110 | sum_tp_h += seg_len(np.squeeze(inter_boxes[pred_ind, pos_gt_inds, :][:, :, [1, 3]], axis=0)) 111 | 112 | sum_p_w = seg_len(pred_boxes[:, [0, 2]], type='sum') 113 | sum_p_h = seg_len(pred_boxes[:, [1, 3]], type='sum') 114 | precision_w = sum_tp_w / (sum_p_w + 1e-6) 115 | precision_h = sum_tp_h / (sum_p_h + 1e-6) 116 | 117 | sum_tp_w, sum_p_w, sum_tp_h, sum_p_h = 0, 0, 0, 0 118 | for gt_ind, inter_per_gt in enumerate(inter_areas.T): 119 | # for each gt-box, find the pred-boxes whose iou is > 0 with it 120 | pos_pred_inds = np.where(inter_per_gt > 0) 121 | if len(pos_pred_inds[0]) > 0: 122 | 123 | # union of all pred box along each side 124 | # tp: true positive 125 | sum_tp_w += seg_len(np.squeeze(inter_boxes[pos_pred_inds, gt_ind, :][:, :, [0, 2]], axis=0)) 126 | 127 | sum_tp_h += seg_len(np.squeeze(inter_boxes[pos_pred_inds, gt_ind, :][:, :, [1, 3]], axis=0)) 128 | 129 | sum_p_w = seg_len(gt_boxes[:, [0, 2]], type='sum') 130 | sum_p_h = seg_len(gt_boxes[:, [1, 3]], type='sum') 131 | recall_w = sum_tp_w / (sum_p_w + 1e-6) 132 | recall_h = sum_tp_h / (sum_p_h + 1e-6) 133 | 134 | return {"precision": precision_h * precision_w, "recall": recall_h * recall_w} 135 | 136 | 137 | 138 | def evaluate_overall(result_dict: Dict[str, Dict[str, Any]]) -> Tuple[float, float]: 139 | """ 140 | This metric indicates the overall performance on all the video pairs. 141 | Parameters 142 | ---------- 143 | result_dict: segment level Precision/Recall result of all the video pairs 144 | ratio: nums of positive samples / nums of negative samples 145 | 146 | Returns 147 | ------- 148 | recall, precision 149 | """ 150 | # The following metric directly filter out the result with abnormal assigned values 151 | # if len(pred_boxes) == 0, precision is 1 but it will not contribute to final precision metric. 
152 | # max value of regular calculation is always != 1 since x / (x + 1e-6) < 1 153 | precision_list = [result_dict[i]['precision'] for i in result_dict if not (result_dict[i]['precision'] == 1)] 154 | # if len(gt_boxes) == 0, recall is 1 but it will not contribute to final recall metric. 155 | recall_list = [result_dict[i]['recall'] for i in result_dict if not (result_dict[i]['recall'] == 1)] 156 | r, p = sum(recall_list) / len(recall_list), sum(precision_list) / len(precision_list) 157 | 158 | return r, p 159 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | loguru 2 | numpy 3 | opencv_python 4 | pandas 5 | torch 6 | torchvision 7 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | 4 | import torch 5 | from torch.utils.data import Dataset, DataLoader 6 | from exps.exp import Exp 7 | from transvcl.utils import postprocess 8 | import numpy as np 9 | import os 10 | import json 11 | import pandas as pd 12 | from collections import defaultdict 13 | import argparse 14 | from loguru import logger 15 | 16 | def feat_paddding(feat: torch.Tensor, axis: int, new_size: int, fill_value: int = 0): 17 | pad_shape = list(feat.shape) 18 | pad_shape[axis] = max(0, new_size - pad_shape[axis]) 19 | feat_pad = torch.Tensor(*pad_shape).fill_(fill_value) 20 | return torch.cat([feat, feat_pad], dim=axis) 21 | 22 | def load_features_list(feat1, feat2, file_name): 23 | feat_length = 1200 24 | feat1_list, feat2_list = [], [] 25 | i, j = -1, -1 26 | for i in range(len(feat1) // feat_length): 27 | feat1_list.append(feat1[i * feat_length: (i + 1) * feat_length]) 28 | for j in range(len(feat2) // feat_length): 29 | feat2_list.append(feat2[j * feat_length: (j + 1) * feat_length]) 30 | if len(feat1) > (i + 1) * feat_length: 31 | feat1_list.append(feat1[(i + 1) * feat_length:]) 32 | if len(feat2) > (j + 1) * feat_length: 33 | feat2_list.append(feat2[(j + 1) * feat_length:]) 34 | batch_list = [] 35 | for i in range(len(feat1_list)): 36 | for j in range(len(feat2_list)): 37 | mask1, mask2 = np.zeros(feat_length, dtype=bool), np.zeros(feat_length, dtype=bool) 38 | mask1[:len(feat1_list[i])] = True 39 | mask2[:len(feat2_list[j])] = True 40 | 41 | feat1_padding = feat_paddding(torch.tensor(feat1_list[i]), 0, feat_length) 42 | feat2_padding = feat_paddding(torch.tensor(feat2_list[j]), 0, feat_length) 43 | 44 | img_info = [torch.tensor([len(feat1_list[i])]), torch.tensor([len(feat2_list[j])])] 45 | 46 | file_name_idx = file_name + "_" + str(i) + "_" + str(j) 47 | 48 | batch_list.append((feat1_padding, feat2_padding, torch.from_numpy(mask1), torch.from_numpy(mask2), img_info, file_name_idx)) 49 | 50 | return batch_list 51 | 52 | class SimFeatDataset(Dataset): 53 | def __init__(self, batch_list, **kwargs): 54 | self.batch_list = batch_list 55 | 56 | def __getitem__(self, item): 57 | return self.batch_list[item] 58 | 59 | def __len__(self): 60 | return len(self.batch_list) 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser() 65 | 66 | parser.add_argument("--model-file", type=str, default=None, help="TransVCL model file") 67 | parser.add_argument("--feat-dir", type=str, default=None, help="video feature dir") 68 | parser.add_argument("--feat-length", type=int, default=1200, help="feature 
length for TransVCL input") 69 | parser.add_argument("--test-file", type=str, default=None, help="test pair list of query and reference videos") 70 | parser.add_argument("--conf-thre", type=float, default=0.1, help="conf threshold of copied segments") 71 | parser.add_argument("--nms-thre", type=float, default=0.3, help="nms threshold of copied segments") 72 | parser.add_argument("--img-size", type=int, default=640, help="length for copied localization module") 73 | parser.add_argument("--load-batch", type=int, default=8192, help="batch size of loading features to CPU") 74 | parser.add_argument("--inference-batch", type=int, default=1024, help="batch size of TransVCL inference") 75 | parser.add_argument("--save-file", type=str, default=None, help="save json file of results") 76 | parser.add_argument("--device", type=int, default=None, help="GPU device") 77 | 78 | args = parser.parse_args() 79 | if args.device is not None: 80 | device = "cuda:" + str(args.device) 81 | else: 82 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 83 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7,8' 84 | 85 | feat_dir, test_file = args.feat_dir, args.test_file 86 | num_classes, confthre, nmsthre = 1, args.conf_thre, args.nms_thre 87 | img_size, feat_max_length = (args.img_size, args.img_size), args.feat_length 88 | 89 | df = pd.read_csv(test_file) 90 | process_list = [f"{q}-{r}" for q, r in zip(df.query_id.values, df.reference_id.values)] 91 | process_list = [file.split(".")[0] for file in process_list] 92 | 93 | result = defaultdict(list) 94 | 95 | exp = Exp() 96 | model = exp.get_model() 97 | 98 | model.eval() 99 | ckpt = torch.load(args.model_file, map_location="cpu") 100 | # load the model state dict 101 | model.load_state_dict(ckpt["model"]) 102 | # model.to(device) 103 | model = torch.nn.DataParallel(model.cuda()) 104 | 105 | batch_feat_list = [] 106 | for idx, process_img in enumerate(process_list): 107 | feat1_name, feat2_name = process_img.split("-")[0], process_img.split("-")[1] 108 | feat1_name, feat2_name = feat_dir + feat1_name + ".npy", feat_dir + feat2_name + ".npy" 109 | feat1, feat2 = np.load(feat1_name), np.load(feat2_name) 110 | batch_feat_list += load_features_list(feat1, feat2, process_img) 111 | loading_idx = args.load_batch 112 | if idx % loading_idx == loading_idx - 1 or idx == len(process_list) - 1: 113 | logger.info(f"finish {idx + 1} / {len(process_list)} of total feature loading") 114 | dataset = SimFeatDataset(batch_feat_list) 115 | bs = args.inference_batch 116 | dataloader_kwargs = {"batch_size": bs, "num_workers": 0} 117 | loader = DataLoader(dataset, **dataloader_kwargs) 118 | 119 | batch_feat_result, batch_global_result = {}, {} 120 | for idx, batch_data in enumerate(loader): 121 | if idx % 2 == 0: 122 | logger.info(f"starting {idx * bs} / {len(dataset)} of inference") 123 | feat1, feat2, mask1, mask2, img_info, file_name = batch_data 124 | feat1, feat2, mask1, mask2 = feat1.cuda(), feat2.cuda(), mask1.cuda(), mask2.cuda() 125 | with torch.no_grad(): 126 | model_outputs = model(feat1, feat2, mask1, mask2, file_name, img_info) 127 | outputs = postprocess( 128 | model_outputs[1], num_classes, confthre, 129 | nmsthre, class_agnostic=True 130 | ) 131 | 132 | for idx, output in enumerate(outputs): 133 | if output is not None: 134 | bboxes = output[:, :5].cpu() 135 | 136 | scale1, scale2 = img_info[0] / img_size[0], img_info[1] / img_size[1] 137 | bboxes[:, 0:4:2] *= scale2[idx] 138 | bboxes[:, 1:4:2] *= scale1[idx] 139 | batch_feat_result[file_name[idx]] = bboxes[:, (1, 0, 3, 
2, 4)].tolist() 140 | else: 141 | batch_feat_result[file_name[idx]] = [[]] 142 | 143 | 144 | for img_name in batch_feat_result: 145 | img_file = img_name.split("_")[0] 146 | i, j = int(img_name.split("_")[1]), int(img_name.split("_")[2]) 147 | if batch_feat_result[img_name] != [[]]: 148 | for r in batch_feat_result[img_name]: 149 | result[img_file].append( 150 | [r[0] + i * feat_max_length, r[1] + j * feat_max_length, r[2] + i * feat_max_length, 151 | r[3] + j * feat_max_length, r[4]]) 152 | 153 | batch_feat_list = [] 154 | 155 | json.dump(result, open(args.save_file, "w")) 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | -------------------------------------------------------------------------------- /scripts/eval_TransVCL.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | DATASET=VCSL 3 | MODEL=model 4 | 5 | # Before eval_TransVCL, please first execute test_TransVCL.sh and obtain pred_file. 6 | 7 | python evaluation.py \ 8 | --anno-file data/${DATASET}/label_file.json \ 9 | --test-file data/${DATASET}/pair_file.csv \ 10 | --pred-file results/${MODEL}/${DATASET}/result.json -------------------------------------------------------------------------------- /scripts/test_TransVCL.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | DATASET=VCSL 3 | MODEL=model 4 | FEATDIR=data/VCSL/features/ 5 | 6 | python run.py \ 7 | --model-file transvcl/weights/${MODEL}.pth \ 8 | --feat-dir data/${DATASET}/features/ \ 9 | --test-file data/${DATASET}/pair_file.csv \ 10 | --save-file results/${MODEL}/${DATASET}/result.json -------------------------------------------------------------------------------- /transvcl/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .transvcl_model import TransVCL, YOLOPAFPN, YOLOXHead -------------------------------------------------------------------------------- /transvcl/models/darknet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | from torch import nn 5 | 6 | from .network_blocks import BaseConv, CSPLayer, DWConv, Focus, ResLayer, SPPBottleneck 7 | 8 | 9 | class Darknet(nn.Module): 10 | # number of blocks from dark2 to dark5. 11 | depth2blocks = {21: [1, 2, 2, 1], 53: [2, 8, 8, 4]} 12 | 13 | def __init__( 14 | self, 15 | depth, 16 | in_channels=3, 17 | stem_out_channels=32, 18 | out_features=("dark3", "dark4", "dark5"), 19 | ): 20 | """ 21 | Args: 22 | depth (int): depth of darknet used in model, usually use [21, 53] for this param. 23 | in_channels (int): number of input channels, for example, use 3 for RGB image. 24 | stem_out_channels (int): number of output channels of darknet stem. 25 | It decides channels of darknet layer2 to layer5. 26 | out_features (Tuple[str]): desired output layer name. 27 | """ 28 | super().__init__() 29 | assert out_features, "please provide output features of Darknet" 30 | self.out_features = out_features 31 | self.stem = nn.Sequential( 32 | BaseConv(in_channels, stem_out_channels, ksize=3, stride=1, act="lrelu"), 33 | *self.make_group_layer(stem_out_channels, num_blocks=1, stride=2), 34 | ) 35 | in_channels = stem_out_channels * 2 # 64 36 | 37 | num_blocks = Darknet.depth2blocks[depth] 38 | # create darknet with `stem_out_channels` and `num_blocks` layers. 39 | # to make model structure more clear, we don't use `for` statement in python. 
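        # dark2-dark4 below each apply a stride-2 group layer that halves the spatial size
        # and doubles the channel width (64 -> 128 -> 256 -> 512); dark5 doubles to 1024
        # inside its group layer and its SPP block projects back to 512 channels.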
40 | self.dark2 = nn.Sequential( 41 | *self.make_group_layer(in_channels, num_blocks[0], stride=2) 42 | ) 43 | in_channels *= 2 # 128 44 | self.dark3 = nn.Sequential( 45 | *self.make_group_layer(in_channels, num_blocks[1], stride=2) 46 | ) 47 | in_channels *= 2 # 256 48 | self.dark4 = nn.Sequential( 49 | *self.make_group_layer(in_channels, num_blocks[2], stride=2) 50 | ) 51 | in_channels *= 2 # 512 52 | 53 | self.dark5 = nn.Sequential( 54 | *self.make_group_layer(in_channels, num_blocks[3], stride=2), 55 | *self.make_spp_block([in_channels, in_channels * 2], in_channels * 2), 56 | ) 57 | 58 | def make_group_layer(self, in_channels: int, num_blocks: int, stride: int = 1): 59 | "starts with conv layer then has `num_blocks` `ResLayer`" 60 | return [ 61 | BaseConv(in_channels, in_channels * 2, ksize=3, stride=stride, act="lrelu"), 62 | *[(ResLayer(in_channels * 2)) for _ in range(num_blocks)], 63 | ] 64 | 65 | def make_spp_block(self, filters_list, in_filters): 66 | m = nn.Sequential( 67 | *[ 68 | BaseConv(in_filters, filters_list[0], 1, stride=1, act="lrelu"), 69 | BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"), 70 | SPPBottleneck( 71 | in_channels=filters_list[1], 72 | out_channels=filters_list[0], 73 | activation="lrelu", 74 | ), 75 | BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"), 76 | BaseConv(filters_list[1], filters_list[0], 1, stride=1, act="lrelu"), 77 | ] 78 | ) 79 | return m 80 | 81 | def forward(self, x): 82 | outputs = {} 83 | x = self.stem(x) 84 | outputs["stem"] = x 85 | x = self.dark2(x) 86 | outputs["dark2"] = x 87 | x = self.dark3(x) 88 | outputs["dark3"] = x 89 | x = self.dark4(x) 90 | outputs["dark4"] = x 91 | x = self.dark5(x) 92 | outputs["dark5"] = x 93 | return {k: v for k, v in outputs.items() if k in self.out_features} 94 | 95 | 96 | class CSPDarknet(nn.Module): 97 | def __init__( 98 | self, 99 | dep_mul, 100 | wid_mul, 101 | out_features=("dark3", "dark4", "dark5"), 102 | depthwise=False, 103 | act="silu", 104 | ): 105 | super().__init__() 106 | assert out_features, "please provide output features of Darknet" 107 | self.out_features = out_features 108 | Conv = DWConv if depthwise else BaseConv 109 | 110 | base_channels = int(wid_mul * 64) # 64 111 | base_depth = max(round(dep_mul * 3), 1) # 3 112 | 113 | # stem 114 | self.stem = Focus(3, base_channels, ksize=3, act=act) 115 | 116 | # dark2 117 | self.dark2 = nn.Sequential( 118 | Conv(base_channels, base_channels * 2, 3, 2, act=act), 119 | CSPLayer( 120 | base_channels * 2, 121 | base_channels * 2, 122 | n=base_depth, 123 | depthwise=depthwise, 124 | act=act, 125 | ), 126 | ) 127 | 128 | # dark3 129 | self.dark3 = nn.Sequential( 130 | Conv(base_channels * 2, base_channels * 4, 3, 2, act=act), 131 | CSPLayer( 132 | base_channels * 4, 133 | base_channels * 4, 134 | n=base_depth * 3, 135 | depthwise=depthwise, 136 | act=act, 137 | ), 138 | ) 139 | 140 | # dark4 141 | self.dark4 = nn.Sequential( 142 | Conv(base_channels * 4, base_channels * 8, 3, 2, act=act), 143 | CSPLayer( 144 | base_channels * 8, 145 | base_channels * 8, 146 | n=base_depth * 3, 147 | depthwise=depthwise, 148 | act=act, 149 | ), 150 | ) 151 | 152 | # dark5 153 | self.dark5 = nn.Sequential( 154 | Conv(base_channels * 8, base_channels * 16, 3, 2, act=act), 155 | SPPBottleneck(base_channels * 16, base_channels * 16, activation=act), 156 | CSPLayer( 157 | base_channels * 16, 158 | base_channels * 16, 159 | n=base_depth, 160 | shortcut=False, 161 | depthwise=depthwise, 162 | act=act, 163 | ), 164 | ) 165 | 166 | 
def forward(self, x): 167 | outputs = {} 168 | x = self.stem(x) 169 | outputs["stem"] = x 170 | x = self.dark2(x) 171 | outputs["dark2"] = x 172 | x = self.dark3(x) 173 | outputs["dark3"] = x 174 | x = self.dark4(x) 175 | outputs["dark4"] = x 176 | x = self.dark5(x) 177 | outputs["dark5"] = x 178 | return {k: v for k, v in outputs.items() if k in self.out_features} 179 | -------------------------------------------------------------------------------- /transvcl/models/linear_attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | 4 | import torch 5 | from torch.nn import Module, Dropout 6 | 7 | 8 | def elu_feature_map(x): 9 | return torch.nn.functional.elu(x) + 1 10 | 11 | 12 | class LinearAttention(Module): 13 | def __init__(self, eps=1e-6): 14 | super().__init__() 15 | self.feature_map = elu_feature_map 16 | self.eps = eps 17 | 18 | def forward(self, queries, keys, values, q_mask=None, kv_mask=None): 19 | """ Multi-Head linear attention proposed in "Transformers are RNNs" 20 | Args: 21 | queries: [N, L, H, D] 22 | keys: [N, S, H, D] 23 | values: [N, S, H, D] 24 | q_mask: [N, L] 25 | kv_mask: [N, S] 26 | Returns: 27 | queried_values: (N, L, H, D) 28 | """ 29 | Q = self.feature_map(queries) 30 | K = self.feature_map(keys) 31 | 32 | # set padded position to zero 33 | if q_mask is not None: 34 | Q = Q * q_mask[:, :, None, None] 35 | if kv_mask is not None: 36 | K = K * kv_mask[:, :, None, None] 37 | values = values * kv_mask[:, :, None, None] 38 | 39 | v_length = values.size(1) 40 | values = values / v_length # prevent fp16 overflow 41 | KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V 42 | Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) 43 | queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length 44 | 45 | return queried_values.contiguous() 46 | 47 | 48 | class FullAttention(Module): 49 | def __init__(self, use_dropout=False, attention_dropout=0.1): 50 | super().__init__() 51 | self.use_dropout = use_dropout 52 | self.dropout = Dropout(attention_dropout) 53 | 54 | def forward(self, queries, keys, values, q_mask=None, kv_mask=None): 55 | """ Multi-head scaled dot-product attention, a.k.a full attention. 56 | Args: 57 | queries: [N, L, H, D] 58 | keys: [N, S, H, D] 59 | values: [N, S, H, D] 60 | q_mask: [N, L] 61 | kv_mask: [N, S] 62 | Returns: 63 | queried_values: (N, L, H, D) 64 | """ 65 | 66 | # Compute the unnormalized attention and apply the masks 67 | QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) 68 | if kv_mask is not None: 69 | QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf')) 70 | 71 | # Compute the attention and the weighted average 72 | softmax_temp = 1. 
/ queries.size(3)**.5 # sqrt(D) 73 | A = torch.softmax(softmax_temp * QK, dim=2) 74 | if self.use_dropout: 75 | A = self.dropout(A) 76 | 77 | queried_values = torch.einsum("nlsh,nshd->nlhd", A, values) 78 | 79 | return queried_values.contiguous() 80 | -------------------------------------------------------------------------------- /transvcl/models/network_blocks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class SiLU(nn.Module): 9 | """export-friendly version of nn.SiLU()""" 10 | 11 | @staticmethod 12 | def forward(x): 13 | return x * torch.sigmoid(x) 14 | 15 | 16 | def get_activation(name="silu", inplace=True): 17 | if name == "silu": 18 | module = nn.SiLU(inplace=inplace) 19 | elif name == "relu": 20 | module = nn.ReLU(inplace=inplace) 21 | elif name == "lrelu": 22 | module = nn.LeakyReLU(0.1, inplace=inplace) 23 | else: 24 | raise AttributeError("Unsupported act type: {}".format(name)) 25 | return module 26 | 27 | 28 | class BaseConv(nn.Module): 29 | """A Conv2d -> Batchnorm -> silu/leaky relu block""" 30 | 31 | def __init__( 32 | self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu" 33 | ): 34 | super().__init__() 35 | # same padding 36 | pad = (ksize - 1) // 2 37 | self.conv = nn.Conv2d( 38 | in_channels, 39 | out_channels, 40 | kernel_size=ksize, 41 | stride=stride, 42 | padding=pad, 43 | groups=groups, 44 | bias=bias, 45 | ) 46 | self.bn = nn.BatchNorm2d(out_channels) 47 | self.act = get_activation(act, inplace=True) 48 | 49 | def forward(self, x): 50 | return self.act(self.bn(self.conv(x))) 51 | 52 | def fuseforward(self, x): 53 | return self.act(self.conv(x)) 54 | 55 | 56 | class DWConv(nn.Module): 57 | """Depthwise Conv + Conv""" 58 | 59 | def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"): 60 | super().__init__() 61 | self.dconv = BaseConv( 62 | in_channels, 63 | in_channels, 64 | ksize=ksize, 65 | stride=stride, 66 | groups=in_channels, 67 | act=act, 68 | ) 69 | self.pconv = BaseConv( 70 | in_channels, out_channels, ksize=1, stride=1, groups=1, act=act 71 | ) 72 | 73 | def forward(self, x): 74 | x = self.dconv(x) 75 | return self.pconv(x) 76 | 77 | 78 | class Bottleneck(nn.Module): 79 | # Standard bottleneck 80 | def __init__( 81 | self, 82 | in_channels, 83 | out_channels, 84 | shortcut=True, 85 | expansion=0.5, 86 | depthwise=False, 87 | act="silu", 88 | ): 89 | super().__init__() 90 | hidden_channels = int(out_channels * expansion) 91 | Conv = DWConv if depthwise else BaseConv 92 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) 93 | self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act) 94 | self.use_add = shortcut and in_channels == out_channels 95 | 96 | def forward(self, x): 97 | y = self.conv2(self.conv1(x)) 98 | if self.use_add: 99 | y = y + x 100 | return y 101 | 102 | 103 | class ResLayer(nn.Module): 104 | "Residual layer with `in_channels` inputs." 
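    # i.e. y = x + conv3x3(conv1x1(x)): the 1x1 conv halves the channels and the 3x3 conv
    # restores them before the identity shortcut is added.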
105 | 106 | def __init__(self, in_channels: int): 107 | super().__init__() 108 | mid_channels = in_channels // 2 109 | self.layer1 = BaseConv( 110 | in_channels, mid_channels, ksize=1, stride=1, act="lrelu" 111 | ) 112 | self.layer2 = BaseConv( 113 | mid_channels, in_channels, ksize=3, stride=1, act="lrelu" 114 | ) 115 | 116 | def forward(self, x): 117 | out = self.layer2(self.layer1(x)) 118 | return x + out 119 | 120 | 121 | class SPPBottleneck(nn.Module): 122 | """Spatial pyramid pooling layer used in YOLOv3-SPP""" 123 | 124 | def __init__( 125 | self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu" 126 | ): 127 | super().__init__() 128 | hidden_channels = in_channels // 2 129 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation) 130 | self.m = nn.ModuleList( 131 | [ 132 | nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) 133 | for ks in kernel_sizes 134 | ] 135 | ) 136 | conv2_channels = hidden_channels * (len(kernel_sizes) + 1) 137 | self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation) 138 | 139 | def forward(self, x): 140 | x = self.conv1(x) 141 | x = torch.cat([x] + [m(x) for m in self.m], dim=1) 142 | x = self.conv2(x) 143 | return x 144 | 145 | 146 | class CSPLayer(nn.Module): 147 | """C3 in yolov5, CSP Bottleneck with 3 convolutions""" 148 | 149 | def __init__( 150 | self, 151 | in_channels, 152 | out_channels, 153 | n=1, 154 | shortcut=True, 155 | expansion=0.5, 156 | depthwise=False, 157 | act="silu", 158 | ): 159 | """ 160 | Args: 161 | in_channels (int): input channels. 162 | out_channels (int): output channels. 163 | n (int): number of Bottlenecks. Default value: 1. 164 | """ 165 | # ch_in, ch_out, number, shortcut, groups, expansion 166 | super().__init__() 167 | hidden_channels = int(out_channels * expansion) # hidden channels 168 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) 169 | self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) 170 | self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act) 171 | module_list = [ 172 | Bottleneck( 173 | hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act 174 | ) 175 | for _ in range(n) 176 | ] 177 | self.m = nn.Sequential(*module_list) 178 | 179 | def forward(self, x): 180 | x_1 = self.conv1(x) 181 | x_2 = self.conv2(x) 182 | x_1 = self.m(x_1) 183 | x = torch.cat((x_1, x_2), dim=1) 184 | return self.conv3(x) 185 | 186 | 187 | class Focus(nn.Module): 188 | """Focus width and height information into channel space.""" 189 | 190 | def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"): 191 | super().__init__() 192 | self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act) 193 | 194 | def forward(self, x): 195 | # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2) 196 | patch_top_left = x[..., ::2, ::2] 197 | patch_top_right = x[..., ::2, 1::2] 198 | patch_bot_left = x[..., 1::2, ::2] 199 | patch_bot_right = x[..., 1::2, 1::2] 200 | x = torch.cat( 201 | ( 202 | patch_top_left, 203 | patch_bot_left, 204 | patch_top_right, 205 | patch_bot_right, 206 | ), 207 | dim=1, 208 | ) 209 | return self.conv(x) 210 | -------------------------------------------------------------------------------- /transvcl/models/sim_gen.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import copy 7 | import math 8 | from 
torch.autograd import Variable 9 | from .linear_attention import LinearAttention, FullAttention 10 | import numpy as np 11 | import cv2 12 | 13 | 14 | INF = 1e9 15 | 16 | class EncoderLayer(nn.Module): 17 | def __init__(self, 18 | d_model, 19 | nhead, 20 | attention='linear'): 21 | super(EncoderLayer, self).__init__() 22 | 23 | self.dim = d_model // nhead 24 | self.nhead = nhead 25 | 26 | # multi-head attention 27 | self.q_proj = nn.Linear(d_model, d_model, bias=False) 28 | self.k_proj = nn.Linear(d_model, d_model, bias=False) 29 | self.v_proj = nn.Linear(d_model, d_model, bias=False) 30 | self.attention = LinearAttention() if attention == 'linear' else FullAttention() 31 | self.merge = nn.Linear(d_model, d_model, bias=False) 32 | 33 | # feed-forward network 34 | self.mlp = nn.Sequential( 35 | nn.Linear(d_model*2, d_model*2, bias=False), 36 | nn.ReLU(True), 37 | nn.Linear(d_model*2, d_model, bias=False), 38 | ) 39 | 40 | # norm and dropout 41 | self.norm1 = nn.LayerNorm(d_model) 42 | self.norm2 = nn.LayerNorm(d_model) 43 | 44 | def forward(self, x, source, x_mask=None, source_mask=None, file=None): 45 | """ 46 | Args: 47 | x (torch.Tensor): [N, L, C] 48 | source (torch.Tensor): [N, S, C] 49 | x_mask (torch.Tensor): [N, L] (optional) 50 | source_mask (torch.Tensor): [N, S] (optional) 51 | """ 52 | bs = x.size(0) 53 | query, key, value = x, source, source 54 | 55 | # multi-head attention 56 | query = self.q_proj(query) 57 | key = self.k_proj(key) 58 | query = query.view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] 59 | key = key.view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] 60 | value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) 61 | message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)] 62 | message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C] 63 | message = self.norm1(message) 64 | 65 | # feed-forward network 66 | message = self.mlp(torch.cat([x, message], dim=2)) 67 | message = self.norm2(message) 68 | 69 | return x + message 70 | 71 | class PositionalEncoding(nn.Module): 72 | "Implement the PE function." 73 | def __init__(self, d_model, dropout=0.5, max_len=1200): 74 | super(PositionalEncoding, self).__init__() 75 | self.dropout = nn.Dropout(p=dropout) 76 | 77 | # Compute the positional encodings once in log space. 
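        # pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        # pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
        # div_term below equals 10000^(-2i / d_model), written as exp(2i * -log(10000) / d_model);
        # forward() later scales the table by d_model ** -0.5 before adding it to the features.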
78 | pe = torch.zeros(max_len, d_model) 79 | position = torch.arange(0, max_len).unsqueeze(1) 80 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * 81 | (-math.log(10000.0) / d_model)) 82 | pe[:, 0::2] = torch.sin(position * div_term) 83 | pe[:, 1::2] = torch.cos(position * div_term) 84 | self.register_buffer('pe', pe.unsqueeze(0), persistent=False) 85 | 86 | def forward(self, x): 87 | """ 88 | Args: 89 | x (torch.Tensor): [N, L, C] 90 | """ 91 | _, _, d = x.shape 92 | return x + d ** -0.5 * self.pe[:, :x.size(1)] 93 | 94 | class FeatureTransformer(nn.Module): 95 | """Feature Transformer module.""" 96 | 97 | def __init__(self, config): 98 | super(FeatureTransformer, self).__init__() 99 | 100 | self.config = config 101 | self.d_model = config['d_model'] 102 | self.nhead = config['nhead'] 103 | self.layer_names = config['layer_names'] 104 | encoder_layer = EncoderLayer(config['d_model'], config['nhead'], config['attention']) 105 | self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) 106 | self.cls_token = nn.Parameter(torch.randn(1, 1, self.d_model)) 107 | self._reset_parameters() 108 | 109 | def _reset_parameters(self): 110 | for p in self.parameters(): 111 | if p.dim() > 1: 112 | nn.init.xavier_uniform_(p) 113 | 114 | def forward(self, feat0, feat1, mask0=None, mask1=None, cls_token=False, file=None): 115 | """ 116 | Args: 117 | feat0 (torch.Tensor): [N, L, C] 118 | feat1 (torch.Tensor): [N, S, C] 119 | mask0 (torch.Tensor): [N, L] (optional) 120 | mask1 (torch.Tensor): [N, S] (optional) 121 | """ 122 | if cls_token: 123 | N, L, _ = feat0.shape 124 | _, S, _ = feat1.shape 125 | cls_token_0 = self.cls_token.repeat(N, 1, 1) 126 | cls_token_1 = self.cls_token.repeat(N, 1, 1) 127 | feat0 = torch.cat((cls_token_0, feat0), dim=1) 128 | feat1 = torch.cat((cls_token_1, feat1), dim=1) 129 | mask0 = torch.cat((torch.tensor([1]).repeat(N, 1).to(mask0.device), mask0), dim=1) 130 | mask1 = torch.cat((torch.tensor([1]).repeat(N, 1).to(mask1.device), mask1), dim=1) 131 | for layer, name in zip(self.layers, self.layer_names): 132 | if name == 'self': 133 | feat0 = layer(feat0, feat0, mask0, mask0, file[0] + '_self_feat0') 134 | feat1 = layer(feat1, feat1, mask1, mask1, file[0] + '_self_feat1') 135 | elif name == 'cross': 136 | feat0 = layer(feat0, feat1, mask0, mask1, file[0] + '_cross_feat0') 137 | feat1 = layer(feat1, feat0, mask1, mask0, file[0] + '_cross_feat1') 138 | else: 139 | raise KeyError 140 | 141 | return feat0, feat1 142 | 143 | class SimMapGen(nn.Module): 144 | def __init__(self, config): 145 | super().__init__() 146 | self.config = config 147 | 148 | # we provide 2 options for differentiable matching 149 | self.match_type = config['match_type'] 150 | if self.match_type == 'dual_softmax': 151 | self.temperature = config['dsmax_temperature'] 152 | elif self.match_type == 'sinkhorn': 153 | try: 154 | from .superglue import log_optimal_transport 155 | except ImportError: 156 | raise ImportError("download superglue.py first!") 157 | self.log_optimal_transport = log_optimal_transport 158 | self.bin_score = nn.Parameter( 159 | torch.tensor(config['skh_init_bin_score'], requires_grad=True)) 160 | self.skh_iters = config['skh_iters'] 161 | self.skh_prefilter = config['skh_prefilter'] 162 | else: 163 | raise NotImplementedError() 164 | 165 | def forward(self, feat_c0, feat_c1, mask_c0=None, mask_c1=None): 166 | """ 167 | Args: 168 | feat0 (torch.Tensor): [N, L, C] 169 | feat1 (torch.Tensor): [N, S, C] 170 | mask_c0 (torch.Tensor): [N, L] 
(optional) 171 | mask_c1 (torch.Tensor): [N, S] (optional) 172 | Update: 173 | conf_matrix (torch.Tensor): [N, L, S] 174 | """ 175 | N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2) 176 | 177 | # normalize 178 | feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5, 179 | [feat_c0, feat_c1]) 180 | 181 | if self.match_type == 'dual_softmax': 182 | sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, 183 | feat_c1) / self.temperature 184 | if mask_c0 is not None: 185 | sim_matrix.masked_fill_( 186 | ~(mask_c0[..., None] * mask_c1[:, None]).bool(), 187 | 0) 188 | # if mask_c0 is not None: 189 | # sim_matrix = torch.where((mask_c0.long()[..., None] * mask_c1.long()[:, None]).bool(), sim_matrix, torch.tensor([0.])) 190 | conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2) 191 | 192 | else: 193 | assert self.match_type == 'sinkhorn', f'match type {self.match_type} is neither sinkhorn nor dual_softmax' 194 | # sinkhorn, dustbin included 195 | sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) 196 | if mask_c0 is not None: 197 | sim_matrix[:, :L, :S].masked_fill_( 198 | ~(mask_c0[..., None] * mask_c1[:, None]).bool(), 199 | -INF) 200 | 201 | # build uniform prior & use sinkhorn 202 | log_assign_matrix = self.log_optimal_transport( 203 | sim_matrix, self.bin_score, self.skh_iters) 204 | assign_matrix = log_assign_matrix.exp() 205 | conf_matrix = assign_matrix[:, :-1, :-1] 206 | 207 | # filter prediction with dustbin score (only in evaluation mode) 208 | if not self.training and self.skh_prefilter: 209 | filter0 = (assign_matrix.max(dim=2)[1] == S)[:, :-1] # [N, L] 210 | filter1 = (assign_matrix.max(dim=1)[1] == L)[:, :-1] # [N, S] 211 | conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0 212 | conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0 213 | 214 | return conf_matrix 215 | -------------------------------------------------------------------------------- /transvcl/models/transvcl_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch.nn as nn 5 | from .yolo_head import YOLOXHead 6 | from .yolo_pafpn import YOLOPAFPN 7 | from .sim_gen import FeatureTransformer, SimMapGen, PositionalEncoding 8 | import torch.nn.functional as F 9 | import torch 10 | 11 | class TransVCL(nn.Module): 12 | 13 | def __init__(self, feat_transformer_config, backbone=None, head=None): 14 | super().__init__() 15 | if backbone is None: 16 | backbone = YOLOPAFPN() 17 | if head is None: 18 | head = YOLOXHead(1) 19 | 20 | self.pos_encoding = PositionalEncoding(feat_transformer_config['d_model']) 21 | self.feat_encoding = FeatureTransformer(feat_transformer_config) 22 | self.sim_gen = SimMapGen(feat_transformer_config) 23 | self.mlp_head = nn.Sequential( 24 | nn.Linear(feat_transformer_config['d_model'] * 2, feat_transformer_config['d_model']), 25 | nn.ReLU(), 26 | nn.LayerNorm(feat_transformer_config['d_model']), 27 | nn.Linear(feat_transformer_config['d_model'], 1) 28 | ) 29 | 30 | self.backbone = backbone 31 | self.head = head 32 | 33 | self.keep_ratio = feat_transformer_config['keep_ratio'] 34 | self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none") 35 | self.unsupervised_weight = feat_transformer_config['unsupervised_weight'] 36 | 37 | 38 | def forward(self, feat1, feat2, mask1=None, mask2=None, targets=None, id_=None, input_size=(640, 640), global_loss=True, vis=False): 39 | feat1 = self.pos_encoding(feat1) 40 | feat2 = self.pos_encoding(feat2) 
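        # The rest of the forward pass:
        #   1. prepend a learnable [CLS] token to both frame sequences and run the
        #      self-/cross-attention FeatureTransformer (when global_loss is True);
        #   2. predict a video-level copy score from the two [CLS] tokens with mlp_head;
        #   3. build a frame-to-frame similarity map with SimMapGen (dual-softmax or
        #      sinkhorn matching), crop it by the valid masks, expand it to 3 channels
        #      and resize it to input_size;
        #   4. localize copied segments on that map with the YOLO-style backbone and head.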
41 | if global_loss: 42 | feat1, feat2 = self.feat_encoding(feat1, feat2, mask1, mask2, cls_token=True, file=targets) 43 | global_feat1 = feat1[:, 0] 44 | global_feat2 = feat2[:, 0] 45 | global_pred = self.mlp_head(torch.cat((global_feat1, global_feat2), dim=1)) 46 | feat1, feat2 = feat1[:, 1:], feat2[:, 1:] 47 | else: 48 | feat1, feat2 = self.feat_encoding(feat1, feat2, mask1, mask2) 49 | 50 | sim_matrix = self.sim_gen(feat1, feat2, mask1, mask2) 51 | x, global_label = [], [] 52 | for idx, m in enumerate(sim_matrix): 53 | m = m[:mask1[idx].sum().int(), :mask2[idx].sum().int()][..., None] 54 | if self.training: 55 | if targets[idx].max() < 1e-6: 56 | global_label.append(torch.tensor([0.], device=m.device)) 57 | else: 58 | global_label.append(torch.tensor([1.], device=m.device)) 59 | m = m.repeat(1, 1, 1, 3).permute(0, 3, 1, 2) 60 | m = F.interpolate(m, size=input_size) 61 | x.append(m) 62 | x = torch.cat(x, dim=0) 63 | 64 | if global_loss and self.training: 65 | weight = id_.clone() 66 | weight[id_ < self.sum_index] = 1. 67 | weight[id_ >= self.sum_index] = self.unsupervised_weight 68 | weight = weight.float() 69 | global_label = torch.cat(global_label, dim=0).view(-1, 1) 70 | # global_pred = torch.cat(global_pred, dim=0).view(-1, 1) 71 | global_loss = torch.squeeze(self.bcewithlog_loss(global_pred, global_label)) 72 | global_loss = torch.matmul(global_loss, weight) 73 | else: 74 | global_loss = 0 75 | 76 | fpn_outs = self.backbone(x) 77 | 78 | if self.training: 79 | assert targets is not None 80 | if sum(id_ < self.sum_index) > 0: 81 | fpn_outs_supervised = tuple([fpn_out[id_ < self.sum_index, ...] for fpn_out in fpn_outs]) 82 | targets_supervised = targets[id_ < self.sum_index, ...] 83 | x_supervised = x[id_ < self.sum_index, ...] 84 | loss_supervised, iou_loss_supervised, conf_loss_supervised, cls_loss_supervised, l1_loss_supervised, num_fg_supervised = self.head( 85 | fpn_outs_supervised, targets_supervised, x_supervised 86 | ) 87 | else: 88 | loss_supervised, iou_loss_supervised, conf_loss_supervised, cls_loss_supervised, l1_loss_supervised, num_fg_supervised = 0, 0, 0, 0, 0, 0 89 | if sum(id_ >= self.sum_index) > 0: 90 | fpn_outs_unsupervised = tuple([fpn_out[id_ >= self.sum_index, ...] for fpn_out in fpn_outs]) 91 | targets_unsupervised = targets[id_ >= self.sum_index, ...] 92 | x_unsupervised = x[id_ >= self.sum_index, ...] 
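                # the same YOLOX head scores the weakly-labelled subset; these losses are
                # down-weighted by unsupervised_weight when summed into total_loss below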
93 | loss_unsupervised, iou_loss_unsupervised, conf_loss_unsupervised, cls_loss_unsupervised, l1_loss_unsupervised, num_fg_unsupervised = self.head( 94 | fpn_outs_unsupervised, targets_unsupervised, x_unsupervised 95 | ) 96 | else: 97 | loss_unsupervised, iou_loss_unsupervised, conf_loss_unsupervised, cls_loss_unsupervised, l1_loss_unsupervised, num_fg_unsupervised = 0, 0, 0, 0, 0, 0 98 | outputs = { 99 | "total_loss": loss_supervised + self.unsupervised_weight * loss_unsupervised + 0.01 * global_loss, 100 | "iou_loss": loss_supervised + self.unsupervised_weight * loss_unsupervised, 101 | "l1_loss": l1_loss_supervised + self.unsupervised_weight * l1_loss_unsupervised, 102 | "conf_loss": conf_loss_supervised + self.unsupervised_weight * conf_loss_unsupervised, 103 | "cls_loss": cls_loss_supervised + self.unsupervised_weight * cls_loss_unsupervised, 104 | "global_loss": global_loss, 105 | "num_fg": num_fg_supervised + self.unsupervised_weight * num_fg_unsupervised, 106 | } 107 | else: 108 | outputs = self.head(fpn_outs) 109 | outputs = [global_pred, outputs] 110 | 111 | return outputs 112 | -------------------------------------------------------------------------------- /transvcl/models/yolo_head.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from transvcl.utils import bboxes_iou 10 | from .network_blocks import BaseConv, DWConv 11 | 12 | 13 | class YOLOXHead(nn.Module): 14 | def __init__( 15 | self, 16 | num_classes, 17 | width=1.0, 18 | strides=[8, 16, 32], 19 | in_channels=[256, 512, 1024], 20 | act="silu", 21 | depthwise=False, 22 | ): 23 | """ 24 | Args: 25 | act (str): activation type of conv. Defalut value: "silu". 26 | depthwise (bool): whether apply depthwise conv in conv branch. Defalut value: False. 
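            num_classes (int): number of classes to predict (1 for copied-segment detection).
            width (float): channel width multiplier shared with the backbone. Default value: 1.0.
            strides (List[int]): strides of the FPN feature maps fed to the head.
            in_channels (List[int]): channels of the FPN inputs before width scaling.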
27 | """ 28 | super().__init__() 29 | 30 | self.n_anchors = 1 31 | self.num_classes = num_classes 32 | self.decode_in_inference = True # for deploy, set to False 33 | 34 | self.cls_convs = nn.ModuleList() 35 | self.reg_convs = nn.ModuleList() 36 | self.cls_preds = nn.ModuleList() 37 | self.reg_preds = nn.ModuleList() 38 | self.obj_preds = nn.ModuleList() 39 | self.stems = nn.ModuleList() 40 | Conv = DWConv if depthwise else BaseConv 41 | 42 | for i in range(len(in_channels)): 43 | self.stems.append( 44 | BaseConv( 45 | in_channels=int(in_channels[i] * width), 46 | out_channels=int(256 * width), 47 | ksize=1, 48 | stride=1, 49 | act=act, 50 | ) 51 | ) 52 | self.cls_convs.append( 53 | nn.Sequential( 54 | *[ 55 | Conv( 56 | in_channels=int(256 * width), 57 | out_channels=int(256 * width), 58 | ksize=3, 59 | stride=1, 60 | act=act, 61 | ), 62 | Conv( 63 | in_channels=int(256 * width), 64 | out_channels=int(256 * width), 65 | ksize=3, 66 | stride=1, 67 | act=act, 68 | ), 69 | ] 70 | ) 71 | ) 72 | self.reg_convs.append( 73 | nn.Sequential( 74 | *[ 75 | Conv( 76 | in_channels=int(256 * width), 77 | out_channels=int(256 * width), 78 | ksize=3, 79 | stride=1, 80 | act=act, 81 | ), 82 | Conv( 83 | in_channels=int(256 * width), 84 | out_channels=int(256 * width), 85 | ksize=3, 86 | stride=1, 87 | act=act, 88 | ), 89 | ] 90 | ) 91 | ) 92 | self.cls_preds.append( 93 | nn.Conv2d( 94 | in_channels=int(256 * width), 95 | out_channels=self.n_anchors * self.num_classes, 96 | kernel_size=1, 97 | stride=1, 98 | padding=0, 99 | ) 100 | ) 101 | self.reg_preds.append( 102 | nn.Conv2d( 103 | in_channels=int(256 * width), 104 | out_channels=4, 105 | kernel_size=1, 106 | stride=1, 107 | padding=0, 108 | ) 109 | ) 110 | self.obj_preds.append( 111 | nn.Conv2d( 112 | in_channels=int(256 * width), 113 | out_channels=self.n_anchors * 1, 114 | kernel_size=1, 115 | stride=1, 116 | padding=0, 117 | ) 118 | ) 119 | 120 | self.use_l1 = False 121 | self.l1_loss = nn.L1Loss(reduction="none") 122 | self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none") 123 | self.iou_loss = IOUloss(reduction="none") 124 | self.strides = strides 125 | self.grids = [torch.zeros(1)] * len(in_channels) 126 | 127 | def initialize_biases(self, prior_prob): 128 | for conv in self.cls_preds: 129 | b = conv.bias.view(self.n_anchors, -1) 130 | b.data.fill_(-math.log((1 - prior_prob) / prior_prob)) 131 | conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) 132 | 133 | for conv in self.obj_preds: 134 | b = conv.bias.view(self.n_anchors, -1) 135 | b.data.fill_(-math.log((1 - prior_prob) / prior_prob)) 136 | conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) 137 | 138 | def forward(self, xin, labels=None, imgs=None): 139 | outputs = [] 140 | origin_preds = [] 141 | x_shifts = [] 142 | y_shifts = [] 143 | expanded_strides = [] 144 | 145 | for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate( 146 | zip(self.cls_convs, self.reg_convs, self.strides, xin) 147 | ): 148 | x = self.stems[k](x) 149 | cls_x = x 150 | reg_x = x 151 | 152 | cls_feat = cls_conv(cls_x) 153 | cls_output = self.cls_preds[k](cls_feat) 154 | 155 | reg_feat = reg_conv(reg_x) 156 | reg_output = self.reg_preds[k](reg_feat) 157 | obj_output = self.obj_preds[k](reg_feat) 158 | 159 | if self.training: 160 | output = torch.cat([reg_output, obj_output, cls_output], 1) 161 | output, grid = self.get_output_and_grid( 162 | output, k, stride_this_level, xin[0].type() 163 | ) 164 | x_shifts.append(grid[:, :, 0]) 165 | y_shifts.append(grid[:, :, 1]) 166 | 
expanded_strides.append( 167 | torch.zeros(1, grid.shape[1]) 168 | .fill_(stride_this_level) 169 | .type_as(xin[0]) 170 | ) 171 | if self.use_l1: 172 | batch_size = reg_output.shape[0] 173 | hsize, wsize = reg_output.shape[-2:] 174 | reg_output = reg_output.view( 175 | batch_size, self.n_anchors, 4, hsize, wsize 176 | ) 177 | reg_output = reg_output.permute(0, 1, 3, 4, 2).reshape( 178 | batch_size, -1, 4 179 | ) 180 | origin_preds.append(reg_output.clone()) 181 | 182 | else: 183 | output = torch.cat( 184 | [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1 185 | ) 186 | 187 | outputs.append(output) 188 | 189 | if self.training: 190 | return self.get_losses( 191 | imgs, 192 | x_shifts, 193 | y_shifts, 194 | expanded_strides, 195 | labels, 196 | torch.cat(outputs, 1), 197 | origin_preds, 198 | dtype=xin[0].dtype, 199 | ) 200 | else: 201 | self.hw = [x.shape[-2:] for x in outputs] 202 | # [batch, n_anchors_all, 85] 203 | outputs = torch.cat( 204 | [x.flatten(start_dim=2) for x in outputs], dim=2 205 | ).permute(0, 2, 1) 206 | if self.decode_in_inference: 207 | return self.decode_outputs(outputs, dtype=xin[0].type(), device=xin[0].device) 208 | else: 209 | return outputs 210 | 211 | def get_output_and_grid(self, output, k, stride, dtype): 212 | grid = self.grids[k] 213 | 214 | batch_size = output.shape[0] 215 | n_ch = 5 + self.num_classes 216 | hsize, wsize = output.shape[-2:] 217 | if grid.shape[2:4] != output.shape[2:4]: 218 | yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)]) 219 | grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype) 220 | self.grids[k] = grid 221 | 222 | output = output.view(batch_size, self.n_anchors, n_ch, hsize, wsize) 223 | output = output.permute(0, 1, 3, 4, 2).reshape( 224 | batch_size, self.n_anchors * hsize * wsize, -1 225 | ) 226 | grid = grid.view(1, -1, 2) 227 | output[..., :2] = (output[..., :2] + grid) * stride 228 | output[..., 2:4] = torch.exp(output[..., 2:4]) * stride 229 | return output, grid 230 | 231 | def decode_outputs(self, outputs, dtype, device): 232 | grids = [] 233 | strides = [] 234 | for (hsize, wsize), stride in zip(self.hw, self.strides): 235 | yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)]) 236 | grid = torch.stack((xv, yv), 2).view(1, -1, 2) 237 | grids.append(grid) 238 | shape = grid.shape[:2] 239 | strides.append(torch.full((*shape, 1), stride)) 240 | 241 | grids = torch.cat(grids, dim=1).type(dtype).to(device) 242 | strides = torch.cat(strides, dim=1).type(dtype).to(device) 243 | 244 | outputs[..., :2] = (outputs[..., :2] + grids) * strides 245 | outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides 246 | return outputs 247 | 248 | def get_losses( 249 | self, 250 | imgs, 251 | x_shifts, 252 | y_shifts, 253 | expanded_strides, 254 | labels, 255 | outputs, 256 | origin_preds, 257 | dtype, 258 | ): 259 | bbox_preds = outputs[:, :, :4] # [batch, n_anchors_all, 4] 260 | obj_preds = outputs[:, :, 4].unsqueeze(-1) # [batch, n_anchors_all, 1] 261 | cls_preds = outputs[:, :, 5:] # [batch, n_anchors_all, n_cls] 262 | 263 | # calculate targets 264 | nlabel = (labels.sum(dim=2) > 0).sum(dim=1) # number of objects 265 | 266 | total_num_anchors = outputs.shape[1] 267 | x_shifts = torch.cat(x_shifts, 1) # [1, n_anchors_all] 268 | y_shifts = torch.cat(y_shifts, 1) # [1, n_anchors_all] 269 | expanded_strides = torch.cat(expanded_strides, 1) 270 | if self.use_l1: 271 | origin_preds = torch.cat(origin_preds, 1) 272 | 273 | cls_targets = [] 274 | reg_targets = [] 275 | l1_targets = [] 276 
| obj_targets = [] 277 | fg_masks = [] 278 | 279 | num_fg = 0.0 280 | num_gts = 0.0 281 | 282 | for batch_idx in range(outputs.shape[0]): 283 | num_gt = int(nlabel[batch_idx]) 284 | num_gts += num_gt 285 | if num_gt == 0: 286 | cls_target = outputs.new_zeros((0, self.num_classes)) 287 | reg_target = outputs.new_zeros((0, 4)) 288 | l1_target = outputs.new_zeros((0, 4)) 289 | obj_target = outputs.new_zeros((total_num_anchors, 1)) 290 | fg_mask = outputs.new_zeros(total_num_anchors).bool() 291 | else: 292 | gt_bboxes_per_image = labels[batch_idx, :num_gt, 1:5] 293 | gt_classes = labels[batch_idx, :num_gt, 0] 294 | bboxes_preds_per_image = bbox_preds[batch_idx] 295 | 296 | 297 | ( 298 | gt_matched_classes, 299 | fg_mask, 300 | pred_ious_this_matching, 301 | matched_gt_inds, 302 | num_fg_img, 303 | ) = self.get_assignments( # noqa 304 | batch_idx, 305 | num_gt, 306 | total_num_anchors, 307 | gt_bboxes_per_image, 308 | gt_classes, 309 | bboxes_preds_per_image, 310 | expanded_strides, 311 | x_shifts, 312 | y_shifts, 313 | cls_preds, 314 | bbox_preds, 315 | obj_preds, 316 | labels, 317 | imgs, 318 | ) 319 | 320 | 321 | torch.cuda.empty_cache() 322 | num_fg += num_fg_img 323 | 324 | cls_target = F.one_hot( 325 | gt_matched_classes.to(torch.int64), self.num_classes 326 | ) * pred_ious_this_matching.unsqueeze(-1) 327 | obj_target = fg_mask.unsqueeze(-1) 328 | reg_target = gt_bboxes_per_image[matched_gt_inds] 329 | if self.use_l1: 330 | l1_target = self.get_l1_target( 331 | outputs.new_zeros((num_fg_img, 4)), 332 | gt_bboxes_per_image[matched_gt_inds], 333 | expanded_strides[0][fg_mask], 334 | x_shifts=x_shifts[0][fg_mask], 335 | y_shifts=y_shifts[0][fg_mask], 336 | ) 337 | 338 | cls_targets.append(cls_target) 339 | reg_targets.append(reg_target) 340 | obj_targets.append(obj_target.to(dtype)) 341 | fg_masks.append(fg_mask) 342 | if self.use_l1: 343 | l1_targets.append(l1_target) 344 | 345 | cls_targets = torch.cat(cls_targets, 0) 346 | reg_targets = torch.cat(reg_targets, 0) 347 | obj_targets = torch.cat(obj_targets, 0) 348 | fg_masks = torch.cat(fg_masks, 0) 349 | if self.use_l1: 350 | l1_targets = torch.cat(l1_targets, 0) 351 | 352 | num_fg = max(num_fg, 1) 353 | loss_iou = ( 354 | self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets) 355 | ).sum() / num_fg 356 | loss_obj = ( 357 | self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets) 358 | ).sum() / num_fg 359 | loss_cls = ( 360 | self.bcewithlog_loss( 361 | cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets 362 | ) 363 | ).sum() / num_fg 364 | if self.use_l1: 365 | loss_l1 = ( 366 | self.l1_loss(origin_preds.view(-1, 4)[fg_masks], l1_targets) 367 | ).sum() / num_fg 368 | else: 369 | loss_l1 = 0.0 370 | 371 | reg_weight = 5.0 372 | loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1 373 | 374 | return ( 375 | loss, 376 | reg_weight * loss_iou, 377 | loss_obj, 378 | loss_cls, 379 | loss_l1, 380 | num_fg / max(num_gts, 1), 381 | ) 382 | 383 | def get_l1_target(self, l1_target, gt, stride, x_shifts, y_shifts, eps=1e-8): 384 | l1_target[:, 0] = gt[:, 0] / stride - x_shifts 385 | l1_target[:, 1] = gt[:, 1] / stride - y_shifts 386 | l1_target[:, 2] = torch.log(gt[:, 2] / stride + eps) 387 | l1_target[:, 3] = torch.log(gt[:, 3] / stride + eps) 388 | return l1_target 389 | 390 | @torch.no_grad() 391 | def get_assignments( 392 | self, 393 | batch_idx, 394 | num_gt, 395 | total_num_anchors, 396 | gt_bboxes_per_image, 397 | gt_classes, 398 | bboxes_preds_per_image, 399 | expanded_strides, 400 | x_shifts, 401 | y_shifts, 402 | 
cls_preds, 403 | bbox_preds, 404 | obj_preds, 405 | labels, 406 | imgs, 407 | mode="gpu", 408 | ): 409 | 410 | if mode == "cpu": 411 | print("------------CPU Mode for This Batch-------------") 412 | gt_bboxes_per_image = gt_bboxes_per_image.cpu().float() 413 | bboxes_preds_per_image = bboxes_preds_per_image.cpu().float() 414 | gt_classes = gt_classes.cpu().float() 415 | expanded_strides = expanded_strides.cpu().float() 416 | x_shifts = x_shifts.cpu() 417 | y_shifts = y_shifts.cpu() 418 | 419 | fg_mask, is_in_boxes_and_center = self.get_in_boxes_info( 420 | gt_bboxes_per_image, 421 | expanded_strides, 422 | x_shifts, 423 | y_shifts, 424 | total_num_anchors, 425 | num_gt, 426 | ) 427 | 428 | bboxes_preds_per_image = bboxes_preds_per_image[fg_mask] 429 | cls_preds_ = cls_preds[batch_idx][fg_mask] 430 | obj_preds_ = obj_preds[batch_idx][fg_mask] 431 | num_in_boxes_anchor = bboxes_preds_per_image.shape[0] 432 | 433 | if mode == "cpu": 434 | gt_bboxes_per_image = gt_bboxes_per_image.cpu() 435 | bboxes_preds_per_image = bboxes_preds_per_image.cpu() 436 | 437 | pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False) 438 | 439 | gt_cls_per_image = ( 440 | F.one_hot(gt_classes.to(torch.int64), self.num_classes) 441 | .float() 442 | .unsqueeze(1) 443 | .repeat(1, num_in_boxes_anchor, 1) 444 | ) 445 | pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) 446 | 447 | if mode == "cpu": 448 | cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.cpu() 449 | 450 | with torch.cuda.amp.autocast(enabled=False): 451 | cls_preds_ = ( 452 | cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() 453 | * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() 454 | ) 455 | pair_wise_cls_loss = F.binary_cross_entropy( 456 | cls_preds_.sqrt_(), gt_cls_per_image, reduction="none" 457 | ).sum(-1) 458 | del cls_preds_ 459 | 460 | cost = ( 461 | pair_wise_cls_loss 462 | + 3.0 * pair_wise_ious_loss 463 | + 100000.0 * (~is_in_boxes_and_center) 464 | ) 465 | 466 | ( 467 | num_fg, 468 | gt_matched_classes, 469 | pred_ious_this_matching, 470 | matched_gt_inds, 471 | ) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask) 472 | del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss 473 | 474 | if mode == "cpu": 475 | gt_matched_classes = gt_matched_classes.cuda() 476 | fg_mask = fg_mask.cuda() 477 | pred_ious_this_matching = pred_ious_this_matching.cuda() 478 | matched_gt_inds = matched_gt_inds.cuda() 479 | 480 | return ( 481 | gt_matched_classes, 482 | fg_mask, 483 | pred_ious_this_matching, 484 | matched_gt_inds, 485 | num_fg, 486 | ) 487 | 488 | def get_in_boxes_info( 489 | self, 490 | gt_bboxes_per_image, 491 | expanded_strides, 492 | x_shifts, 493 | y_shifts, 494 | total_num_anchors, 495 | num_gt, 496 | ): 497 | expanded_strides_per_image = expanded_strides[0] 498 | x_shifts_per_image = x_shifts[0] * expanded_strides_per_image 499 | y_shifts_per_image = y_shifts[0] * expanded_strides_per_image 500 | x_centers_per_image = ( 501 | (x_shifts_per_image + 0.5 * expanded_strides_per_image) 502 | .unsqueeze(0) 503 | .repeat(num_gt, 1) 504 | ) # [n_anchor] -> [n_gt, n_anchor] 505 | y_centers_per_image = ( 506 | (y_shifts_per_image + 0.5 * expanded_strides_per_image) 507 | .unsqueeze(0) 508 | .repeat(num_gt, 1) 509 | ) 510 | 511 | gt_bboxes_per_image_l = ( 512 | (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2]) 513 | .unsqueeze(1) 514 | .repeat(1, total_num_anchors) 515 | ) 516 | gt_bboxes_per_image_r = ( 517 | (gt_bboxes_per_image[:, 0] + 
0.5 * gt_bboxes_per_image[:, 2]) 518 | .unsqueeze(1) 519 | .repeat(1, total_num_anchors) 520 | ) 521 | gt_bboxes_per_image_t = ( 522 | (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3]) 523 | .unsqueeze(1) 524 | .repeat(1, total_num_anchors) 525 | ) 526 | gt_bboxes_per_image_b = ( 527 | (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3]) 528 | .unsqueeze(1) 529 | .repeat(1, total_num_anchors) 530 | ) 531 | 532 | b_l = x_centers_per_image - gt_bboxes_per_image_l 533 | b_r = gt_bboxes_per_image_r - x_centers_per_image 534 | b_t = y_centers_per_image - gt_bboxes_per_image_t 535 | b_b = gt_bboxes_per_image_b - y_centers_per_image 536 | bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2) 537 | 538 | is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0 539 | is_in_boxes_all = is_in_boxes.sum(dim=0) > 0 540 | # in fixed center 541 | 542 | center_radius = 5 543 | 544 | gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat( 545 | 1, total_num_anchors 546 | ) - center_radius * expanded_strides_per_image.unsqueeze(0) 547 | gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat( 548 | 1, total_num_anchors 549 | ) + center_radius * expanded_strides_per_image.unsqueeze(0) 550 | gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat( 551 | 1, total_num_anchors 552 | ) - center_radius * expanded_strides_per_image.unsqueeze(0) 553 | gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat( 554 | 1, total_num_anchors 555 | ) + center_radius * expanded_strides_per_image.unsqueeze(0) 556 | 557 | c_l = x_centers_per_image - gt_bboxes_per_image_l 558 | c_r = gt_bboxes_per_image_r - x_centers_per_image 559 | c_t = y_centers_per_image - gt_bboxes_per_image_t 560 | c_b = gt_bboxes_per_image_b - y_centers_per_image 561 | center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2) 562 | is_in_centers = center_deltas.min(dim=-1).values > 0.0 563 | is_in_centers_all = is_in_centers.sum(dim=0) > 0 564 | 565 | # in boxes and in centers 566 | is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all 567 | 568 | is_in_boxes_and_center = ( 569 | is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor] 570 | ) 571 | return is_in_boxes_anchor, is_in_boxes_and_center 572 | 573 | def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask): 574 | # Dynamic K 575 | # --------------------------------------------------------------- 576 | matching_matrix = torch.zeros_like(cost, dtype=torch.uint8) 577 | 578 | ious_in_boxes_matrix = pair_wise_ious 579 | n_candidate_k = min(10, ious_in_boxes_matrix.size(1)) 580 | topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1) 581 | dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1) 582 | dynamic_ks = dynamic_ks.tolist() 583 | for gt_idx in range(num_gt): 584 | _, pos_idx = torch.topk( 585 | cost[gt_idx], k=dynamic_ks[gt_idx], largest=False 586 | ) 587 | matching_matrix[gt_idx][pos_idx] = 1 588 | 589 | del topk_ious, dynamic_ks, pos_idx 590 | 591 | anchor_matching_gt = matching_matrix.sum(0) 592 | if (anchor_matching_gt > 1).sum() > 0: 593 | _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0) 594 | matching_matrix[:, anchor_matching_gt > 1] *= 0 595 | matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1 596 | fg_mask_inboxes = matching_matrix.sum(0) > 0 597 | num_fg = fg_mask_inboxes.sum().item() 598 | 599 | fg_mask[fg_mask.clone()] = fg_mask_inboxes 600 | 601 | matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0) 602 | 
gt_matched_classes = gt_classes[matched_gt_inds] 603 | 604 | pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[ 605 | fg_mask_inboxes 606 | ] 607 | return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds 608 | 609 | class IOUloss(nn.Module): 610 | def __init__(self, reduction="none", loss_type="iou"): 611 | super(IOUloss, self).__init__() 612 | self.reduction = reduction 613 | self.loss_type = loss_type 614 | 615 | def forward(self, pred, target): 616 | assert pred.shape[0] == target.shape[0] 617 | 618 | pred = pred.view(-1, 4) 619 | target = target.view(-1, 4) 620 | tl = torch.max( 621 | (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2) 622 | ) 623 | br = torch.min( 624 | (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2) 625 | ) 626 | 627 | area_p = torch.prod(pred[:, 2:], 1) 628 | area_g = torch.prod(target[:, 2:], 1) 629 | 630 | en = (tl < br).type(tl.type()).prod(dim=1) 631 | area_i = torch.prod(br - tl, 1) * en 632 | area_u = area_p + area_g - area_i 633 | iou = (area_i) / (area_u + 1e-16) 634 | 635 | if self.loss_type == "iou": 636 | loss = 1 - iou ** 2 637 | elif self.loss_type == "giou": 638 | c_tl = torch.min( 639 | (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2) 640 | ) 641 | c_br = torch.max( 642 | (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2) 643 | ) 644 | area_c = torch.prod(c_br - c_tl, 1) 645 | giou = iou - (area_c - area_u) / area_c.clamp(1e-16) 646 | loss = 1 - giou.clamp(min=-1.0, max=1.0) 647 | 648 | if self.reduction == "mean": 649 | loss = loss.mean() 650 | elif self.reduction == "sum": 651 | loss = loss.sum() 652 | 653 | return loss -------------------------------------------------------------------------------- /transvcl/models/yolo_pafpn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .darknet import CSPDarknet 8 | from .network_blocks import BaseConv, CSPLayer, DWConv 9 | 10 | 11 | class YOLOPAFPN(nn.Module): 12 | """ 13 | YOLOv3 model. Darknet 53 is the default backbone of this model. 
14 | """ 15 | 16 | def __init__( 17 | self, 18 | depth=1.0, 19 | width=1.0, 20 | in_features=("dark3", "dark4", "dark5"), 21 | in_channels=[256, 512, 1024], 22 | depthwise=False, 23 | act="silu", 24 | ): 25 | super().__init__() 26 | self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act) 27 | self.in_features = in_features 28 | self.in_channels = in_channels 29 | Conv = DWConv if depthwise else BaseConv 30 | 31 | self.upsample = nn.Upsample(scale_factor=2, mode="nearest") 32 | self.lateral_conv0 = BaseConv( 33 | int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act 34 | ) 35 | self.C3_p4 = CSPLayer( 36 | int(2 * in_channels[1] * width), 37 | int(in_channels[1] * width), 38 | round(3 * depth), 39 | False, 40 | depthwise=depthwise, 41 | act=act, 42 | ) # cat 43 | 44 | self.reduce_conv1 = BaseConv( 45 | int(in_channels[1] * width), int(in_channels[0] * width), 1, 1, act=act 46 | ) 47 | self.C3_p3 = CSPLayer( 48 | int(2 * in_channels[0] * width), 49 | int(in_channels[0] * width), 50 | round(3 * depth), 51 | False, 52 | depthwise=depthwise, 53 | act=act, 54 | ) 55 | 56 | # bottom-up conv 57 | self.bu_conv2 = Conv( 58 | int(in_channels[0] * width), int(in_channels[0] * width), 3, 2, act=act 59 | ) 60 | self.C3_n3 = CSPLayer( 61 | int(2 * in_channels[0] * width), 62 | int(in_channels[1] * width), 63 | round(3 * depth), 64 | False, 65 | depthwise=depthwise, 66 | act=act, 67 | ) 68 | 69 | # bottom-up conv 70 | self.bu_conv1 = Conv( 71 | int(in_channels[1] * width), int(in_channels[1] * width), 3, 2, act=act 72 | ) 73 | self.C3_n4 = CSPLayer( 74 | int(2 * in_channels[1] * width), 75 | int(in_channels[2] * width), 76 | round(3 * depth), 77 | False, 78 | depthwise=depthwise, 79 | act=act, 80 | ) 81 | 82 | def forward(self, input): 83 | """ 84 | Args: 85 | inputs: input images. 86 | 87 | Returns: 88 | Tuple[Tensor]: FPN feature. 
89 | """ 90 | 91 | # backbone 92 | out_features = self.backbone(input) 93 | features = [out_features[f] for f in self.in_features] 94 | [x2, x1, x0] = features 95 | 96 | fpn_out0 = self.lateral_conv0(x0) # 1024->512/32 97 | f_out0 = self.upsample(fpn_out0) # 512/16 98 | f_out0 = torch.cat([f_out0, x1], 1) # 512->1024/16 99 | f_out0 = self.C3_p4(f_out0) # 1024->512/16 100 | 101 | fpn_out1 = self.reduce_conv1(f_out0) # 512->256/16 102 | f_out1 = self.upsample(fpn_out1) # 256/8 103 | f_out1 = torch.cat([f_out1, x2], 1) # 256->512/8 104 | pan_out2 = self.C3_p3(f_out1) # 512->256/8 105 | 106 | p_out1 = self.bu_conv2(pan_out2) # 256->256/16 107 | p_out1 = torch.cat([p_out1, fpn_out1], 1) # 256->512/16 108 | pan_out1 = self.C3_n3(p_out1) # 512->512/16 109 | 110 | p_out0 = self.bu_conv1(pan_out1) # 512->512/32 111 | p_out0 = torch.cat([p_out0, fpn_out0], 1) # 512->1024/32 112 | pan_out0 = self.C3_n4(p_out0) # 1024->1024/32 113 | 114 | outputs = (pan_out2, pan_out1, pan_out0) 115 | return outputs 116 | -------------------------------------------------------------------------------- /transvcl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .boxes import * 2 | from .dist import * -------------------------------------------------------------------------------- /transvcl/utils/boxes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False): 5 | box_corner = prediction.new(prediction.shape) 6 | box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 7 | box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 8 | box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 9 | box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 10 | prediction[:, :, :4] = box_corner[:, :, :4] 11 | 12 | output = [None for _ in range(len(prediction))] 13 | for i, image_pred in enumerate(prediction): 14 | 15 | # If none are remaining => process next image 16 | if not image_pred.size(0): 17 | continue 18 | # Get score and class with highest confidence 19 | class_conf, class_pred = torch.max(image_pred[:, 5: 5 + num_classes], 1, keepdim=True) 20 | 21 | conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= conf_thre).squeeze() 22 | # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred) 23 | detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1) 24 | detections = detections[conf_mask] 25 | if not detections.size(0): 26 | continue 27 | 28 | if class_agnostic: 29 | nms_out_index = torchvision.ops.nms( 30 | detections[:, :4], 31 | detections[:, 4] * detections[:, 5], 32 | nms_thre, 33 | ) 34 | else: 35 | nms_out_index = torchvision.ops.batched_nms( 36 | detections[:, :4], 37 | detections[:, 4] * detections[:, 5], 38 | detections[:, 6], 39 | nms_thre, 40 | ) 41 | 42 | detections = detections[nms_out_index] 43 | if output[i] is None: 44 | output[i] = detections 45 | else: 46 | output[i] = torch.cat((output[i], detections)) 47 | 48 | return output 49 | 50 | def bboxes_iou(bboxes_a, bboxes_b, xyxy=True): 51 | if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4: 52 | raise IndexError 53 | 54 | if xyxy: 55 | tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2]) 56 | br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:]) 57 | area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1) 58 | area_b = torch.prod(bboxes_b[:, 
2:] - bboxes_b[:, :2], 1) 59 | else: 60 | tl = torch.max( 61 | (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2), 62 | (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2), 63 | ) 64 | br = torch.min( 65 | (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2), 66 | (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2), 67 | ) 68 | 69 | area_a = torch.prod(bboxes_a[:, 2:], 1) 70 | area_b = torch.prod(bboxes_b[:, 2:], 1) 71 | en = (tl < br).type(tl.type()).prod(dim=2) 72 | area_i = torch.prod(br - tl, 2) * en # * ((tl < br).all()) 73 | return area_i / (area_a[:, None] + area_b - area_i) -------------------------------------------------------------------------------- /transvcl/utils/dist.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | 4 | import functools 5 | import os 6 | import pickle 7 | import time 8 | from contextlib import contextmanager 9 | from loguru import logger 10 | 11 | import numpy as np 12 | 13 | import torch 14 | from torch import distributed as dist 15 | 16 | __all__ = [ 17 | "get_num_devices", 18 | "wait_for_the_master", 19 | "is_main_process", 20 | "synchronize", 21 | "get_world_size", 22 | "get_rank", 23 | "get_local_rank", 24 | "get_local_size", 25 | "time_synchronized", 26 | "gather", 27 | "all_gather", 28 | ] 29 | 30 | _LOCAL_PROCESS_GROUP = None 31 | 32 | 33 | def get_num_devices(): 34 | gpu_list = os.getenv('CUDA_VISIBLE_DEVICES', None) 35 | if gpu_list is not None: 36 | return len(gpu_list.split(',')) 37 | else: 38 | devices_list_info = os.popen("nvidia-smi -L") 39 | devices_list_info = devices_list_info.read().strip().split("\n") 40 | return len(devices_list_info) 41 | 42 | 43 | @contextmanager 44 | def wait_for_the_master(local_rank: int): 45 | """ 46 | Make all processes waiting for the master to do some task. 47 | """ 48 | if local_rank > 0: 49 | dist.barrier() 50 | yield 51 | if local_rank == 0: 52 | if not dist.is_available(): 53 | return 54 | if not dist.is_initialized(): 55 | return 56 | else: 57 | dist.barrier() 58 | 59 | 60 | def synchronize(): 61 | """ 62 | Helper function to synchronize (barrier) among all processes when using distributed training 63 | """ 64 | if not dist.is_available(): 65 | return 66 | if not dist.is_initialized(): 67 | return 68 | world_size = dist.get_world_size() 69 | if world_size == 1: 70 | return 71 | dist.barrier() 72 | 73 | 74 | def get_world_size() -> int: 75 | if not dist.is_available(): 76 | return 1 77 | if not dist.is_initialized(): 78 | return 1 79 | return dist.get_world_size() 80 | 81 | 82 | def get_rank() -> int: 83 | if not dist.is_available(): 84 | return 0 85 | if not dist.is_initialized(): 86 | return 0 87 | return dist.get_rank() 88 | 89 | 90 | def get_local_rank() -> int: 91 | """ 92 | Returns: 93 | The rank of the current process within the local (per-machine) process group. 94 | """ 95 | if _LOCAL_PROCESS_GROUP is None: 96 | return get_rank() 97 | 98 | if not dist.is_available(): 99 | return 0 100 | if not dist.is_initialized(): 101 | return 0 102 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 103 | 104 | 105 | def get_local_size() -> int: 106 | """ 107 | Returns: 108 | The size of the per-machine process group, i.e. the number of processes per machine. 
109 | """ 110 | if not dist.is_available(): 111 | return 1 112 | if not dist.is_initialized(): 113 | return 1 114 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 115 | 116 | 117 | def is_main_process() -> bool: 118 | return get_rank() == 0 119 | 120 | 121 | @functools.lru_cache() 122 | def _get_global_gloo_group(): 123 | """ 124 | Return a process group based on gloo backend, containing all the ranks 125 | The result is cached. 126 | """ 127 | if dist.get_backend() == "nccl": 128 | return dist.new_group(backend="gloo") 129 | else: 130 | return dist.group.WORLD 131 | 132 | 133 | def _serialize_to_tensor(data, group): 134 | backend = dist.get_backend(group) 135 | assert backend in ["gloo", "nccl"] 136 | device = torch.device("cpu" if backend == "gloo" else "cuda") 137 | 138 | buffer = pickle.dumps(data) 139 | if len(buffer) > 1024 ** 3: 140 | logger.warning( 141 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 142 | get_rank(), len(buffer) / (1024 ** 3), device 143 | ) 144 | ) 145 | storage = torch.ByteStorage.from_buffer(buffer) 146 | tensor = torch.ByteTensor(storage).to(device=device) 147 | return tensor 148 | 149 | 150 | def _pad_to_largest_tensor(tensor, group): 151 | """ 152 | Returns: 153 | list[int]: size of the tensor, on each rank 154 | Tensor: padded tensor that has the max size 155 | """ 156 | world_size = dist.get_world_size(group=group) 157 | assert ( 158 | world_size >= 1 159 | ), "comm.gather/all_gather must be called from ranks within the given group!" 160 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 161 | size_list = [ 162 | torch.zeros([1], dtype=torch.int64, device=tensor.device) 163 | for _ in range(world_size) 164 | ] 165 | dist.all_gather(size_list, local_size, group=group) 166 | size_list = [int(size.item()) for size in size_list] 167 | 168 | max_size = max(size_list) 169 | 170 | # we pad the tensor because torch all_gather does not support 171 | # gathering tensors of different shapes 172 | if local_size != max_size: 173 | padding = torch.zeros( 174 | (max_size - local_size,), dtype=torch.uint8, device=tensor.device 175 | ) 176 | tensor = torch.cat((tensor, padding), dim=0) 177 | return size_list, tensor 178 | 179 | 180 | def all_gather(data, group=None): 181 | """ 182 | Run all_gather on arbitrary picklable data (not necessarily tensors). 183 | 184 | Args: 185 | data: any picklable object 186 | group: a torch process group. By default, will use a group which 187 | contains all ranks on gloo backend. 188 | Returns: 189 | list[data]: list of data gathered from each rank 190 | """ 191 | if get_world_size() == 1: 192 | return [data] 193 | if group is None: 194 | group = _get_global_gloo_group() 195 | if dist.get_world_size(group) == 1: 196 | return [data] 197 | 198 | tensor = _serialize_to_tensor(data, group) 199 | 200 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 201 | max_size = max(size_list) 202 | 203 | # receiving Tensor from all ranks 204 | tensor_list = [ 205 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 206 | for _ in size_list 207 | ] 208 | dist.all_gather(tensor_list, tensor, group=group) 209 | 210 | data_list = [] 211 | for size, tensor in zip(size_list, tensor_list): 212 | buffer = tensor.cpu().numpy().tobytes()[:size] 213 | data_list.append(pickle.loads(buffer)) 214 | 215 | return data_list 216 | 217 | 218 | def gather(data, dst=0, group=None): 219 | """ 220 | Run gather on arbitrary picklable data (not necessarily tensors). 
221 | 222 | Args: 223 | data: any picklable object 224 | dst (int): destination rank 225 | group: a torch process group. By default, will use a group which 226 | contains all ranks on gloo backend. 227 | 228 | Returns: 229 | list[data]: on dst, a list of data gathered from each rank. Otherwise, 230 | an empty list. 231 | """ 232 | if get_world_size() == 1: 233 | return [data] 234 | if group is None: 235 | group = _get_global_gloo_group() 236 | if dist.get_world_size(group=group) == 1: 237 | return [data] 238 | rank = dist.get_rank(group=group) 239 | 240 | tensor = _serialize_to_tensor(data, group) 241 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 242 | 243 | # receiving Tensor from all ranks 244 | if rank == dst: 245 | max_size = max(size_list) 246 | tensor_list = [ 247 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 248 | for _ in size_list 249 | ] 250 | dist.gather(tensor, tensor_list, dst=dst, group=group) 251 | 252 | data_list = [] 253 | for size, tensor in zip(size_list, tensor_list): 254 | buffer = tensor.cpu().numpy().tobytes()[:size] 255 | data_list.append(pickle.loads(buffer)) 256 | return data_list 257 | else: 258 | dist.gather(tensor, [], dst=dst, group=group) 259 | return [] 260 | 261 | 262 | def shared_random_seed(): 263 | """ 264 | Returns: 265 | int: a random number that is the same across all workers. 266 | If workers need a shared RNG, they can use this shared seed to 267 | create one. 268 | All workers must call this function, otherwise it will deadlock. 269 | """ 270 | ints = np.random.randint(2 ** 31) 271 | all_ints = all_gather(ints) 272 | return all_ints[0] 273 | 274 | 275 | def time_synchronized(): 276 | """pytorch-accurate time""" 277 | if torch.cuda.is_available(): 278 | torch.cuda.synchronize() 279 | return time.time() 280 | -------------------------------------------------------------------------------- /transvcl/weights/pretrained_models.txt: -------------------------------------------------------------------------------- 1 | model_1 (trained in fully supervised setting on VCSL dataset and used to reproduce results in Table 1): 2 | https://drive.google.com/file/d/19H45GO_mVpwVpcPQIOyseCBmjpdYtYv4/view?usp=sharing 3 | 4 | model_2 (trained with weakly semi-supervised setting on VCSL and FIVR&SVD and used to reproduce results in Table 5): 5 | https://drive.google.com/file/d/1fRfmalK2jD6KuyOITIoOu8qrXhm7iDz0/view?usp=sharing --------------------------------------------------------------------------------
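
The two checkpoints listed above are ordinary PyTorch checkpoint files. As a quick post-download sanity check before running `scripts/test_TransVCL.sh`, the snippet below loads one on the CPU and prints a few parameter names and shapes. This is a minimal sketch, not part of the repository: the local file name `transvcl/weights/model_1.pth` and the possible `"model"` wrapper key are assumptions about the checkpoint layout, not details confirmed here.

```python
import torch

# Sketch of a checkpoint sanity check. The file name and the optional
# "model" wrapper key are assumptions, not confirmed by this repository.
ckpt_path = "transvcl/weights/model_1.pth"
ckpt = torch.load(ckpt_path, map_location="cpu")

# Some training pipelines save {"model": state_dict, ...}; others save the
# state_dict directly. Handle both layouts defensively.
state_dict = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt

print(f"{ckpt_path}: {len(state_dict)} entries")
for name, tensor in list(state_dict.items())[:5]:
    print(f"  {name}: {tuple(tensor.shape)}")
```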