├── data └── .gitkeep ├── src ├── models │ ├── __init__.py │ ├── loss.py │ ├── blocks.py │ ├── model.py │ └── swin_transformer.py ├── utils │ ├── __init__.py │ └── basic_utils.py ├── datasets │ ├── __init__.py │ └── mad.py ├── main.py └── trainer.py ├── figs ├── framework.png └── visualization.png ├── requirements.txt ├── preprocess ├── proc_mad_anno.py └── encode_text_by_clip.py ├── conf └── soonet_mad.yaml └── README.md /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import SOONet -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic_utils import * -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .mad import MADDataset -------------------------------------------------------------------------------- /figs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/afcedf/SOONet/HEAD/figs/framework.png -------------------------------------------------------------------------------- /figs/visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/afcedf/SOONet/HEAD/figs/visualization.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | coloredlogs 3 | easydict 4 | transformers 5 | tqdm 6 | h5py 7 | fvcore 8 | terminaltables 9 | multimethod 10 | tensorboard 11 | numpy -------------------------------------------------------------------------------- /preprocess/proc_mad_anno.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import csv 5 | 6 | splits = ["train", "val", "test"] 7 | root_dir = "data/mad/annotations" 8 | 9 | for split in splits: 10 | with open(os.path.join(root_dir, "MAD_{}.json".format(split)), 'r') as f: 11 | raw_anns = json.load(f) 12 | 13 | 14 | annos = list() 15 | for qid, ann in raw_anns.items(): 16 | vid = ann["movie"] 17 | duration = ann["movie_duration"] 18 | spos, epos = ann["ext_timestamps"] 19 | query = re.sub("\n", "", ann["sentence"]) 20 | 21 | annos.append([str(qid), str(vid), str(duration), str(spos), str(epos), query]) 22 | 23 | with open("data/mad/annotations/{}.txt".format(split), 'w') as f: 24 | for anno in annos: 25 | f.writelines(" | ".join(anno) + "\n") -------------------------------------------------------------------------------- /conf/soonet_mad.yaml: -------------------------------------------------------------------------------- 1 | SEED: 2022 2 | 3 | CUDNN: 4 | DETERMINISTIC: False 5 | BENCHMARK: False 6 | 7 | DATASET: "mad" 8 | DATA: 9 | DATA_DIR: "data/mad" 10 | PRE_LOAD: False 11 | 12 | MODEL: 13 | HIDDEN_DIM: 512 14 | SNIPPET_LENGTH: 10 15 | SCALE_NUM: 4 16 | ENABLE_STAGE2: True 17 | STAGE2_POOL: "mean" 18 | STAGE2_TOPK: 100 19 | ENABLE_NMS: False 20 | 21 | LOSS: 22 | TEMPE: 0.01 23 | Q2V: 24 | CTX_WEIGHT: 1.0 25 | CTN_WEIGHT: 1.0 26 | 
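# The Q2V weights above and the V2Q weights below balance the two ranking
# directions (query-to-video and video-to-query); all four terms are combined
# in SOONet.loss() in src/models/model.py.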
V2Q: 27 | CTX_WEIGHT: 1.0 28 | CTN_WEIGHT: 1.0 29 | MIN_IOU: 0.3 30 | REGRESS: 31 | ENABLE: True 32 | WEIGHT: 20.0 33 | IOU_THRESH: 0 34 | REDUCE: "mean" 35 | 36 | 37 | OPTIMIZER: 38 | LR: 0.001 39 | WD: 0.0 40 | LR_DECAY: 0.1 41 | LR_DECAY_STEP: 40000 42 | 43 | 44 | TRAIN: 45 | BATCH_SIZE: 32 46 | WORKERS: 8 47 | NUM_EPOCH: 20 48 | LOG_STEP: 200 49 | EVAL_STEP: 2000 50 | 51 | TEST: 52 | BATCH_SIZE: 64 53 | WORKERS: 8 54 | TOPK: 100 55 | EVAL_TOPKS: [1, 5, 10, 50, 100] 56 | EVAL_TIOUS: [0.1, 0.3, 0.5] 57 | -------------------------------------------------------------------------------- /preprocess/encode_text_by_clip.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import math 3 | import numpy as np 4 | 5 | import torch 6 | from transformers import CLIPTokenizer, CLIPTextModel 7 | 8 | 9 | def extract_sentence_feat(model_name): 10 | tokenizer = CLIPTokenizer.from_pretrained(model_name) 11 | model = CLIPTextModel.from_pretrained(model_name) 12 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 13 | model = model.to(device) 14 | model.eval() 15 | 16 | with h5py.File("data/mad/features/CLIP_language_sentence_features.h5", 'w') as f: 17 | write_h5(f, "train", tokenizer, model, device) 18 | write_h5(f, "val", tokenizer, model, device) 19 | write_h5(f, "test", tokenizer, model, device) 20 | 21 | 22 | def write_h5(h5_handler, split, tokenizer, model, device): 23 | qids, texts = list(), list() 24 | with open(f"data/mad/annotations/{split}.txt") as f: 25 | for line in f.readlines(): 26 | texts.append(line.strip().split(" | ")[-1]) 27 | qids.append(line.strip().split(" | ")[0]) 28 | 29 | print(f"split: {split}, text num: {len(texts)}") 30 | batch_size = 10000.0 31 | batch_num = math.ceil(len(texts) / batch_size) 32 | 33 | sent_feats = list() 34 | for i in range(batch_num): 35 | batches = texts[int(i*batch_size):int((i+1)*batch_size)] 36 | with torch.no_grad(): 37 | inputs = tokenizer(batches, 38 | padding="max_length", 39 | truncation=True, 40 | max_length=77, 41 | return_tensors="pt" 42 | ) 43 | output = model(input_ids=inputs.input_ids.to(device), 44 | attention_mask=inputs.attention_mask.to(device) 45 | ) 46 | sent_feats.append(output.pooler_output.cpu().numpy()) 47 | 48 | sent_feats = np.concatenate(sent_feats, axis=0) 49 | for qid, feat in zip(qids, sent_feats): 50 | h5_handler.create_dataset(f"{qid}", data=feat) 51 | 52 | 53 | 54 | if __name__ == "__main__": 55 | extract_sentence_feat("openai/clip-vit-base-patch32") -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | from easydict import EasyDict as edict 4 | import torch 5 | import torch.utils.data as data 6 | 7 | from .datasets import * 8 | from .trainer import Trainer 9 | from .utils import set_seed 10 | 11 | 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser("Setting for training SOONet Models") 15 | 16 | parser.add_argument("--exp_path", type=str) 17 | parser.add_argument("--config_name", type=str) 18 | parser.add_argument("--device_id", type=int, default=0) 19 | parser.add_argument("--mode", type=str, default="train") 20 | 21 | opt = parser.parse_args() 22 | 23 | config_path = "conf/{}.yaml".format(opt.config_name) 24 | with open(config_path, 'r') as f: 25 | cfg = edict(yaml.load(f, Loader=yaml.FullLoader)) 26 | cfg.device_id = opt.device_id 27 | 
torch.cuda.set_device(opt.device_id)
28 |     set_seed(cfg.SEED)
29 |     torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
30 |     torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
31 | 
32 |     dset = cfg.DATASET
33 |     if dset.lower() == "mad":
34 |         trainset = MADDataset("train", cfg, pre_load=cfg.DATA.PRE_LOAD) if opt.mode == "train" else list()
35 |         testset = MADDataset("test", cfg, pre_load=cfg.DATA.PRE_LOAD)
36 |     else:
37 |         raise NotImplementedError
38 | 
39 |     print("Train batch num: {}, Test batch num: {}".format(len(trainset), len(testset)))
40 |     print(cfg)
41 | 
42 |     if opt.mode == "train":
43 |         train_loader = data.DataLoader(trainset, 
44 |                                        batch_size=1, 
45 |                                        num_workers=cfg.TRAIN.WORKERS, 
46 |                                        shuffle=False, 
47 |                                        collate_fn=trainset.collate_fn, 
48 |                                        drop_last=False
49 |                                        )
50 | 
51 |     test_loader = data.DataLoader(testset, 
52 |                                   batch_size=1, 
53 |                                   num_workers=cfg.TEST.WORKERS, 
54 |                                   shuffle=False, 
55 |                                   collate_fn=testset.collate_fn, 
56 |                                   drop_last=False
57 |                                   )
58 | 
59 |     trainer = Trainer(mode=opt.mode, save_or_load_path=opt.exp_path, cfg=cfg)
60 | 
61 |     if opt.mode == "train":
62 |         trainer.train(train_loader, test_loader)
63 |     elif opt.mode == "eval":
64 |         trainer.eval(test_loader)
65 |     elif opt.mode == "test":
66 |         trainer.test(test_loader)
67 |     else:
68 |         raise ValueError(f'The value of mode {opt.mode} is not in ["train", "eval", "test"]')
69 | 
70 | 
71 | 
72 | if __name__ == "__main__":
73 |     main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Scanning Only Once: An End-to-end Framework for Fast Temporal Grounding in Long Videos
2 | 
3 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2303.08345)
4 | 
5 | This repository is an official implementation of [SOONet](https://arxiv.org/abs/2303.08345). SOONet is an end-to-end framework for temporal grounding in long videos. It models an hours-long video with a single network execution, alleviating the inefficiency of the sliding-window pipeline.
6 | 
7 | ![Framework](figs/framework.png)
8 | 
9 | ## 📢 News
10 | - [2023.9.29] Code is released.
11 | - [2023.7.14] Our paper has been accepted to ICCV 2023!
12 | 
13 | ## 🚀 Preparation
14 | 
15 | ### 1. Install dependencies
16 | The code requires Python, and we recommend creating a new environment with conda.
17 | 
18 | ```bash
19 | conda create -n soonet python=3.8
20 | ```
21 | 
22 | Then install the dependencies with pip.
23 | 
24 | ```bash
25 | conda activate soonet
26 | pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
27 | pip install -r requirements.txt
28 | ```
29 | 
30 | ### 2. Download data
31 | - You should request access to the MAD dataset from the [official webpage](https://github.com/Soldelli/MAD). Note that all our experiments are conducted on MAD-v1.
32 | - Once the download is complete, extract the zip file and place the contents in the "data/mad" directory.
33 | 
34 | ### 3. Preprocess data
35 | 
36 | Use the following commands to convert the annotation format and extract the sentence features.
37 | 
38 | ```bash
39 | python preprocess/proc_mad_anno.py
40 | python preprocess/encode_text_by_clip.py
41 | ```
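As a quick sanity check of the preprocessing output, the short sketch below (not part of the repository; it assumes the two scripts above have already been run from the repository root) prints the first converted annotation and the shape of its sentence feature:

```python
import h5py

# Each line of {split}.txt is "qid | vid | duration | start | end | query",
# as written by preprocess/proc_mad_anno.py.
with open("data/mad/annotations/train.txt") as f:
    qid, vid, duration, start, end, query = f.readline().strip().split(" | ")
print(qid, vid, duration, start, end, query)

# Sentence features are stored per query id by preprocess/encode_text_by_clip.py.
with h5py.File("data/mad/features/CLIP_language_sentence_features.h5", "r") as f:
    print(f[qid].shape)  # expected: (512,), the pooled CLIP ViT-B/32 text embedding
```

This is the exact format consumed by `src/datasets/mad.py`, which splits each line on `" | "` and looks up sentence features by `qid`.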
42 | 
43 | The final data folder structure should look like:
44 | ```
45 | data
46 | └───mad/
47 | │   └───annotations/
48 | │   │   └───MAD_train.json
49 | │   │   └───MAD_val.json
50 | │   │   └───MAD_test.json
51 | │   │   └───train.txt
52 | │   │   └───val.txt
53 | │   │   └───test.txt
54 | │   └───features/
55 | │   │   └───CLIP_frames_features_5fps.h5
56 | │   │   └───CLIP_language_features_MAD_test.h5
57 | │   │   └───CLIP_language_sentence_features.h5
58 | │   │   └───CLIP_language_tokens_features.h5
59 | ```
60 | 
61 | ## 🔥 Experiments
62 | 
63 | ### Training
64 | 
65 | Run the following command to train the model on the MAD dataset:
66 | 
67 | ```bash
68 | python -m src.main --exp_path /path/to/output --config_name soonet_mad --device_id 0 --mode train
69 | ```
70 | 
71 | Please note that training with a batch size of 32 consumes approximately 70 GB of GPU memory.
72 | Decreasing the batch size can prevent out-of-memory errors, but it may also degrade accuracy.
73 | 
74 | ### Inference
75 | 
76 | Once training is finished, you can use the following command to run inference on the MAD test set.
77 | 
78 | ```bash
79 | python -m src.main --exp_path /path/to/training/output --config_name soonet_mad --device_id 0 --mode test
80 | ```
81 | 
82 | 
83 | ## 😊 Citation
84 | 
85 | If you find this work useful in your research, please cite our paper:
86 | 
87 | ```bibtex
88 | @InProceedings{Pan_2023_ICCV,
89 |     author    = {Pan, Yulin and He, Xiangteng and Gong, Biao and Lv, Yiliang and Shen, Yujun and Peng, Yuxin and Zhao, Deli},
90 |     title     = {Scanning Only Once: An End-to-end Framework for Fast Temporal Grounding in Long Videos},
91 |     booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
92 |     month     = {October},
93 |     year      = {2023},
94 |     pages     = {13767-13777}
95 | }
96 | ```
97 | 
98 | ## 🙏🏻 Acknowledgement
99 | 
100 | Our code references the following projects. Many thanks to the authors.
101 | 102 | * [Swin-Transformer-1D](https://github.com/meraks/Swin-Transformer-1D.git) 103 | * [Tensorflow-Ranking](https://github.com/tensorflow/ranking.git) 104 | -------------------------------------------------------------------------------- /src/utils/basic_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import logging, logging.handlers 4 | import coloredlogs 5 | import torch 6 | 7 | 8 | def get_logger(name, log_file_path=None, fmt="%(asctime)s %(name)s: %(message)s", 9 | print_lev=logging.DEBUG, write_lev=logging.INFO): 10 | logger = logging.getLogger(name) 11 | # Add file handler 12 | if log_file_path: 13 | formatter = logging.Formatter(fmt) 14 | file_handler = logging.handlers.RotatingFileHandler(log_file_path) 15 | file_handler.setLevel(write_lev) 16 | file_handler.setFormatter(formatter) 17 | logger.addHandler(file_handler) 18 | # Add stream handler 19 | coloredlogs.install(level=print_lev, logger=logger, 20 | fmt="%(asctime)s %(name)s %(message)s") 21 | return logger 22 | 23 | 24 | def count_parameters(model): 25 | train_params = 0 26 | for name, parameter in model.named_parameters(): 27 | if not parameter.requires_grad: continue 28 | train_params += parameter.numel() 29 | print(f"Total Trainable Params: {train_params}") 30 | 31 | 32 | 33 | def set_seed(seed, use_cuda=True): 34 | random.seed(seed) 35 | np.random.seed(seed) 36 | torch.manual_seed(seed) 37 | if use_cuda: 38 | torch.cuda.manual_seed_all(seed) 39 | 40 | 41 | def compute_tiou(pred, gt): 42 | intersection = max(0, min(pred[1], gt[1]) - max(pred[0], gt[0])) 43 | union = max(pred[1], gt[1]) - min(pred[0], gt[0]) 44 | return float(intersection) / (union + 1e-9) 45 | 46 | 47 | def compute_overlap(pred, gt): 48 | # check format 49 | assert isinstance(pred, list) and isinstance(gt, list) 50 | pred_is_list = isinstance(pred[0], list) 51 | gt_is_list = isinstance(gt[0], list) 52 | pred = pred if pred_is_list else [pred] 53 | gt = gt if gt_is_list else [gt] 54 | # compute overlap 55 | pred, gt = np.array(pred), np.array(gt) 56 | inter_left = np.maximum(pred[:, 0, None], gt[None, :, 0]) 57 | inter_right = np.minimum(pred[:, 1, None], gt[None, :, 1]) 58 | inter = np.maximum(0.0, inter_right - inter_left) 59 | union_left = np.minimum(pred[:, 0, None], gt[None, :, 0]) 60 | union_right = np.maximum(pred[:, 1, None], gt[None, :, 1]) 61 | union = np.maximum(1e-12, union_right - union_left) 62 | overlap = 1.0 * inter / union 63 | # reformat output 64 | overlap = overlap if gt_is_list else overlap[:, 0] 65 | overlap = overlap if pred_is_list else overlap[0] 66 | return overlap 67 | 68 | 69 | def time_to_index(start_time, end_time, num_units, duration): 70 | s_times = np.arange(0, num_units).astype(np.float32) / float(num_units) * duration 71 | e_times = np.arange(1, num_units + 1).astype(np.float32) / float(num_units) * duration 72 | candidates = np.stack([np.repeat(s_times[:, None], repeats=num_units, axis=1), 73 | np.repeat(e_times[None, :], repeats=num_units, axis=0)], axis=2).reshape((-1, 2)) 74 | overlaps = compute_overlap(candidates.tolist(), [start_time, end_time]).reshape(num_units, num_units) 75 | start_index = np.argmax(overlaps) // num_units 76 | end_index = np.argmax(overlaps) % num_units 77 | return start_index, end_index, overlaps 78 | 79 | 80 | def index_to_time(start_index, end_index, num_units, duration): 81 | s_times = np.arange(0, num_units).astype(np.float32) * duration / float(num_units) 82 | e_times = np.arange(1, num_units + 
1).astype(np.float32) * duration / float(num_units) 83 | start_time = s_times[start_index] 84 | end_time = e_times[end_index] 85 | return start_time, end_time 86 | 87 | 88 | def fetch_feats_by_index(ori_feats, indices): 89 | B, L = indices.shape 90 | filtered_feats = ori_feats[torch.arange(B)[:, None], indices] 91 | return filtered_feats 92 | 93 | 94 | 95 | class Evaluator(object): 96 | 97 | def __init__(self, tiou_threshold=[0.1, 0.3, 0.5], topks=[1, 5, 10, 50, 100]): 98 | self.tiou_threshold = tiou_threshold 99 | self.topks = topks 100 | 101 | def eval_instance(self, pred, gt, topk): 102 | """ Compute Recall@topk at predefined tiou threshold for instance 103 | Args: 104 | pred: predictions of starting/end position; list of [start,end] 105 | gt: ground-truth of starting/end position; [start,end] 106 | topk: rank of predictions; int 107 | Return: 108 | correct: flag of correct at predefined tiou threshold [0.3,0.5,0.7] 109 | """ 110 | correct = {str(tiou):0 for tiou in self.tiou_threshold} 111 | find = {str(tiou):False for tiou in self.tiou_threshold} 112 | if len(pred) == 0: 113 | return correct 114 | 115 | if len(pred) > topk: 116 | pred = pred[:topk] 117 | 118 | best_tiou = 0 119 | for loc in pred: 120 | cur_tiou = compute_tiou(loc, gt) 121 | 122 | if cur_tiou > best_tiou: 123 | best_tiou = cur_tiou 124 | 125 | for tiou in self.tiou_threshold: 126 | if (not find[str(tiou)]) and (cur_tiou >= tiou): 127 | correct[str(tiou)] = 1 128 | find[str(tiou)] = True 129 | 130 | return correct, best_tiou 131 | 132 | def eval(self, preds, gts): 133 | """ Compute R@1 and R@5 at predefined tiou threshold [0.3,0.5,0.7] 134 | Args: 135 | pred: predictions consisting of starting/end position; list 136 | gt: ground-truth of starting/end position; [start,end] 137 | Return: 138 | correct: flag of correct at predefined tiou threshold [0.3,0.5,0.7] 139 | """ 140 | num_instances = float(len(preds)) 141 | miou = 0 142 | all_rank = dict() 143 | for tiou in self.tiou_threshold: 144 | for topk in self.topks: 145 | all_rank["R{}-{}".format(topk, tiou)] = 0 146 | 147 | for pred,gt in zip(preds, gts): 148 | for topk in self.topks: 149 | correct, iou = self.eval_instance(pred, gt, topk=topk) 150 | for tiou in self.tiou_threshold: 151 | all_rank["R{}-{}".format(topk, tiou)] += correct[str(tiou)] 152 | 153 | # miou += iou 154 | 155 | for tiou in self.tiou_threshold: 156 | for topk in self.topks: 157 | all_rank["R{}-{}".format(topk, tiou)] /= num_instances 158 | 159 | # miou /= float(num_instances) 160 | 161 | return all_rank, miou -------------------------------------------------------------------------------- /src/models/loss.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/tensorflow/ranking/blob/master/tensorflow_ranking/python/losses_impl.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class ApproxNDCGLoss(nn.Module): 9 | 10 | def __init__(self, cfg): 11 | super().__init__() 12 | self.alpha = cfg.LOSS.TEMPE 13 | 14 | def forward(self, labels, logits, mask): 15 | if logits is None: 16 | return 0.0 17 | labels = labels / self.alpha 18 | logits = torch.where(mask.bool(), logits, torch.min(logits, dim=1, keepdim=True)[0] - 1e3 * torch.ones_like(logits)) 19 | logits = logits / self.alpha 20 | ranks = self.approx_ranks(logits) 21 | 22 | loss = 1.0 - self.ndcg(labels, ranks) 23 | return loss.sum() 24 | 25 | 26 | def approx_ranks(self, logits): 27 | r"""Computes approximate ranks given a list of logits. 
28 | Given a list of logits, the rank of an item in the list is one plus the total 29 | number of items with a larger logit. In other words, 30 | rank_i = 1 + \sum_{j \neq i} I_{s_j > s_i}, 31 | where "I" is the indicator function. The indicator function can be 32 | approximated by a generalized sigmoid: 33 | I_{s_j < s_i} \approx 1/(1 + exp(-(s_j - s_i)/temperature)). 34 | This function approximates the rank of an item using this sigmoid 35 | approximation to the indicator function. This technique is at the core 36 | of "A general approximation framework for direct optimization of 37 | information retrieval measures" by Qin et al. 38 | Args: 39 | logits: A `Tensor` with shape [batch_size, list_size]. Each value is the 40 | ranking score of the corresponding item. 41 | Returns: 42 | A `Tensor` of ranks with the same shape as logits. 43 | """ 44 | list_size = logits.size(1) 45 | x = logits.unsqueeze(2).repeat(1, 1, list_size) 46 | y = logits.unsqueeze(1).repeat(1, list_size, 1) 47 | pairs = torch.sigmoid(y - x) 48 | return pairs.sum(dim=-1) + .5 49 | 50 | 51 | def ndcg(self, labels, ranks): 52 | """Computes NDCG from labels and ranks. 53 | Args: 54 | labels: A `Tensor` with shape [batch_size, list_size], representing graded 55 | relevance. 56 | ranks: A `Tensor` of the same shape as labels, or [1, list_size], or None. 57 | If ranks=None, we assume the labels are sorted in their rank. 58 | perm_mat: A `Tensor` with shape [batch_size, list_size, list_size] or None. 59 | Permutation matrices with rows correpond to the ranks and columns 60 | correspond to the indices. An argmax over each row gives the index of the 61 | element at the corresponding rank. 62 | Returns: 63 | A `tensor` of NDCG, ApproxNDCG, or ExpectedNDCG of shape [batch_size, 1]. 64 | """ 65 | discounts = 1. / torch.log1p(ranks.float()) 66 | gains = torch.pow(2., labels) - 1. 
67 | 68 | dcg = (gains * discounts).sum(1, keepdim=True) 69 | normalized_dcg = dcg * self.inverse_max_dcg(labels) 70 | 71 | return normalized_dcg 72 | 73 | def inverse_max_dcg(self, labels, 74 | gain_fn=lambda labels: torch.pow(2.0, labels)-1., 75 | rank_discount_fn=lambda rank: 1./torch.log1p(rank), 76 | topn=None): 77 | ideal_sorted_labels = self.sort_by_scores(labels, topn=topn) 78 | rank = (torch.arange(ideal_sorted_labels.size(1)) + 1).to(labels.device) 79 | discounted_gain = gain_fn(ideal_sorted_labels) * rank_discount_fn(rank.float()) 80 | discounted_gain = discounted_gain.sum(1, keepdim=True) 81 | idcg = torch.where(torch.greater(discounted_gain, 0.0), 1./discounted_gain, torch.zeros_like(discounted_gain)) 82 | 83 | return idcg 84 | 85 | def sort_by_scores(self, scores, mask=None, topn=None): 86 | list_size = scores.size(1) 87 | if topn is None: 88 | topn = list_size 89 | topn = min(topn, list_size) 90 | if mask is not None: 91 | scores = torch.where(mask.bool(), scores, torch.min(scores)) 92 | sorted_scores, sorted_indices = torch.topk(scores, topn) 93 | return sorted_scores 94 | 95 | 96 | 97 | class IOULoss(nn.Module): 98 | 99 | def __init__(self, cfg): 100 | super().__init__() 101 | self.cfg = cfg 102 | self.reduce = cfg.LOSS.REGRESS.REDUCE 103 | 104 | def forward(self, pred_left, pred_right, target_left, target_right, mask): 105 | target_left = target_left.repeat(1, pred_left.size(1)) 106 | target_right = target_right.repeat(1, pred_right.size(1)) 107 | intersect = torch.clamp(torch.min(target_right, pred_right) - torch.max(target_left, pred_left), 0) 108 | union = torch.clamp(torch.max(target_right, pred_right) - torch.min(target_left, pred_left), 0) 109 | 110 | iou = (intersect + 1e-8) / (union + 1e-8) 111 | 112 | loss = -torch.log(iou) 113 | if self.reduce == "mean": 114 | loss = (loss * mask).sum() / (mask.sum() + 1e-8) 115 | elif self.reduce == "sum": 116 | loss = (loss * mask).sum() 117 | else: 118 | raise NotImplementedError 119 | 120 | return iou, loss 121 | 122 | 123 | class HighLightLoss(nn.Module): 124 | 125 | def __init__(self, cfg): 126 | super().__init__() 127 | 128 | self.cfg = cfg 129 | 130 | def forward(self, labels, logits, mask, epsilon=1e-12): 131 | labels = labels.type(torch.float32) 132 | weights = torch.where(labels == 0.0, 1.0, 100.0) 133 | loss_per_location = nn.BCELoss(reduction='none')(logits, labels) 134 | loss_per_location = loss_per_location * weights 135 | mask = mask.type(torch.float32) 136 | loss = torch.sum(loss_per_location * mask) / (torch.sum(mask) + epsilon) 137 | return loss 138 | 139 | 140 | class NCELoss(nn.Module): 141 | 142 | def __init__(self, cfg): 143 | super().__init__() 144 | 145 | self.cfg = cfg 146 | 147 | def forward(self, labels, logits, mask, alpha): 148 | logits = logits / alpha 149 | n, d = logits.size() 150 | _, pos_idx = torch.max(labels, dim=1) 151 | pos_mask = torch.zeros_like(logits, dtype=torch.int32).to(logits.device) 152 | for i in range(pos_idx.size(0)): 153 | pos_mask[i][pos_idx[i]] = 1 154 | 155 | # neg_mask = torch.where(labels==0, 1, 0).bool() 156 | pos_dist = torch.masked_select(logits, mask=pos_mask.bool()).reshape(n, 1) 157 | neg_dist = torch.masked_select(logits, mask=(1-pos_mask).bool()).reshape(n, d-1) 158 | 159 | logits = torch.cat([pos_dist, neg_dist], dim=1) 160 | target = torch.zeros([n], dtype=torch.long, requires_grad=False).cuda() 161 | loss = F.cross_entropy(logits, target, reduction='mean') 162 | return loss -------------------------------------------------------------------------------- 
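The docstring of `approx_ranks` above explains the sigmoid relaxation of ranks at the heart of ApproxNDCG. The toy example below is a minimal sketch, not part of the repository: it assumes it is run from the repository root (so `src.models.loss` is importable) with `torch` and `easydict` installed, and it fabricates a one-query batch with four anchors to show how the relaxed ranks and the resulting loss behave.

```python
import torch
from easydict import EasyDict as edict

from src.models.loss import ApproxNDCGLoss

# Hypothetical minimal config: LOSS.TEMPE is the only field ApproxNDCGLoss reads.
cfg = edict({"LOSS": {"TEMPE": 0.01}})
loss_fn = ApproxNDCGLoss(cfg)

# One query, four anchors: labels are IoU-style relevances, logits the predicted scores.
labels = torch.tensor([[0.9, 0.1, 0.0, 0.4]])
logits = torch.tensor([[0.8, 0.2, 0.1, 0.5]])
mask = torch.ones_like(logits)

# approx_ranks implements rank_i ≈ 1 + sum_{j != i} sigmoid((s_j - s_i) / temperature);
# forward() applies the temperature itself, so scale manually for a direct call.
ranks = loss_fn.approx_ranks(logits / cfg.LOSS.TEMPE)
print(ranks)  # ≈ tensor([[1., 3., 4., 2.]]): the largest logit gets rank ≈ 1

# Full loss: 1 - ApproxNDCG, summed over the batch. It is close to zero here
# because the predicted ordering already matches the label ordering.
print(loss_fn(labels, logits, mask))
```

With the sharp temperature used in `conf/soonet_mad.yaml` (`TEMPE: 0.01`), the sigmoid is nearly a step function, so the approximate ranks sit close to the exact integer ranks and minimizing the loss amounts to pushing the predicted ordering toward the ideal NDCG ordering.
--------------------------------------------------------------------------------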
/src/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | from tqdm import tqdm 5 | import torch 6 | from torch.optim import AdamW, lr_scheduler 7 | 8 | from .models import SOONet 9 | from .utils import Evaluator, get_logger 10 | 11 | 12 | class Trainer(object): 13 | 14 | def __init__(self, mode, save_or_load_path, cfg): 15 | self.device = torch.device(cfg.device_id) if torch.cuda.is_available() else torch.device("cpu") 16 | self.model = SOONet(cfg) 17 | 18 | self.evaluator = Evaluator(tiou_threshold=cfg.TEST.EVAL_TIOUS, topks=cfg.TEST.EVAL_TOPKS) 19 | 20 | self.save_or_load_path = save_or_load_path 21 | log_dir = os.path.join(save_or_load_path, "log") 22 | ckpt_dir = os.path.join(save_or_load_path, "ckpt") 23 | if not os.path.exists(log_dir): 24 | os.makedirs(log_dir) 25 | if not os.path.exists(ckpt_dir): 26 | os.makedirs(ckpt_dir) 27 | 28 | self.log_dir = log_dir 29 | self.ckpt_dir = ckpt_dir 30 | self.cfg = cfg 31 | 32 | if mode == "train": 33 | with open(os.path.join(log_dir, "config.json"), 'w') as f: 34 | js = json.dumps(cfg, indent=2) 35 | f.write(js) 36 | 37 | self.optimizer = self.build_optimizer(cfg) 38 | self.scheduler = lr_scheduler.StepLR(self.optimizer, cfg.OPTIMIZER.LR_DECAY_STEP, 39 | gamma=cfg.OPTIMIZER.LR_DECAY, last_epoch=-1, verbose=False) 40 | 41 | 42 | def build_optimizer(self, cfg): 43 | no_decay = ['bias', 'layer_norm', 'LayerNorm'] 44 | optimizer_grouped_parameters = [ 45 | {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': cfg.OPTIMIZER.WD}, 46 | {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}] 47 | optimizer = AdamW(optimizer_grouped_parameters, lr=cfg.OPTIMIZER.LR) 48 | 49 | return optimizer 50 | 51 | 52 | def train(self, train_loader, test_loader): 53 | logger = get_logger("TRAIN", log_file_path=os.path.join(self.log_dir, "train.log")) 54 | self.model.to(self.device) 55 | self.train_epoch(0, train_loader, test_loader, logger, self.cfg) 56 | 57 | 58 | def eval(self, test_loader): 59 | logger = get_logger("EVAL", log_file_path=os.path.join(self.log_dir, "eval.log")) 60 | 61 | resume_path = os.path.join(self.ckpt_dir, "best.pth") 62 | logger.info("Load trained model from: {}".format(resume_path)) 63 | 64 | saver_dict = torch.load(resume_path, map_location="cpu") 65 | state_dict = saver_dict["model"] 66 | self.model.load_state_dict(state_dict, strict=True) 67 | self.model.to(self.device) 68 | self.model.eval() 69 | logger.info("Load trained model succeed.") 70 | 71 | all_rank, miou = self.eval_epoch(test_loader) 72 | for k, v in all_rank.items(): 73 | logger.info("{}: {:.4f}".format(k, v)) 74 | 75 | 76 | def test(self, test_loader): 77 | logger = get_logger("TEST", log_file_path=os.path.join(self.log_dir, "test.log")) 78 | 79 | resume_path = os.path.join(self.ckpt_dir, "best.pth") 80 | logger.info("Load trained model from: {}".format(resume_path)) 81 | 82 | saver_dict = torch.load(resume_path, map_location="cpu") 83 | state_dict = saver_dict["model"] 84 | self.model.load_state_dict(state_dict, strict=True) 85 | self.model.to(self.device) 86 | self.model.eval() 87 | logger.info("Load trained model succeed.") 88 | 89 | start = time.time() 90 | test_instances = [] 91 | with torch.no_grad(): 92 | for batch in tqdm(test_loader, total=len(test_loader)): 93 | scores, bboxes = self.model( 94 | query_feats=batch["query_feats"].to(self.device), 95 | 
query_masks=batch["query_masks"].to(self.device), 96 | video_feats=batch["video_feats"].to(self.device), 97 | start_ts=batch["starts"].to(self.device), 98 | end_ts=batch["ends"].to(self.device), 99 | scale_boundaries=batch["scale_boundaries"].to(self.device), 100 | ) 101 | for i in range(len(bboxes)): 102 | instance = { 103 | "vid": batch["vid"], 104 | "duration": batch["duration"], 105 | "qid": batch["qids"][0][i], 106 | "text": batch["texts"][0][i], 107 | "timestamp": batch["timestamps"][i].numpy().tolist(), 108 | "pred_scores": scores[i], 109 | "pred_bboxes": bboxes[i] 110 | } 111 | test_instances.append(instance) 112 | 113 | logger.info("cost time: {}".format(time.time() - start)) 114 | result_path = os.path.join(self.log_dir, "infer_result.json") 115 | with open(result_path, 'w') as f: 116 | res = json.dumps(test_instances, indent=2) 117 | f.write(res) 118 | 119 | 120 | def train_epoch(self, epoch, train_loader, test_loader, logger, cfg): 121 | self.model.train() 122 | 123 | best_r1 = 0 124 | for i, batch in enumerate(train_loader): 125 | loss_dict = self.model( 126 | query_feats=batch["query_feats"].to(self.device), 127 | query_masks=batch["query_masks"].to(self.device), 128 | video_feats=batch["video_feats"].to(self.device), 129 | start_ts=batch["starts"].to(self.device), 130 | end_ts=batch["ends"].to(self.device), 131 | scale_boundaries=batch["scale_boundaries"].to(self.device), 132 | overlaps=batch["overlaps"].to(self.device), 133 | timestamps=batch["timestamps"].to(self.device), 134 | anchor_masks=batch["anchor_masks"].to(self.device) 135 | ) 136 | total_loss = loss_dict["total_loss"] 137 | self.optimizer.zero_grad() 138 | total_loss.backward() 139 | self.optimizer.step() 140 | self.scheduler.step() 141 | 142 | if i % cfg.TRAIN.LOG_STEP == 0: 143 | log_str = f"Step: {i}, " 144 | for k, v in loss_dict.items(): 145 | log_str += "{}: {:.3f}, ".format(k, v) 146 | logger.info(log_str[:-2]) 147 | 148 | if i > 0 and i % cfg.TRAIN.EVAL_STEP == 0: 149 | all_rank, miou = self.eval_epoch(test_loader) 150 | 151 | logger.info("step: {}".format(i)) 152 | for k, v in all_rank.items(): 153 | logger.info("{}: {:.4f}".format(k, v)) 154 | 155 | r1 = all_rank["R1-0.5"] 156 | if r1 > best_r1: 157 | best_r1 =r1 158 | saver_dict = { 159 | "step": i, 160 | "r1-0.5": r1, 161 | "model": self.model.state_dict(), 162 | "optimizer": self.optimizer.state_dict() 163 | } 164 | save_path = os.path.join(self.ckpt_dir, "best.pth") 165 | torch.save(saver_dict, save_path) 166 | 167 | self.model.train() 168 | 169 | logger.info("best R1-0.5: {:.4f}".format(best_r1)) 170 | 171 | 172 | 173 | def eval_epoch(self, test_loader): 174 | self.model.eval() 175 | 176 | preds, gts = list(), list() 177 | with torch.no_grad(): 178 | for batch in tqdm(test_loader, total=len(test_loader)): 179 | scores, bboxes = self.model( 180 | query_feats=batch["query_feats"].to(self.device), 181 | query_masks=batch["query_masks"].to(self.device), 182 | video_feats=batch["video_feats"].to(self.device), 183 | start_ts=batch["starts"].to(self.device), 184 | end_ts=batch["ends"].to(self.device), 185 | scale_boundaries=batch["scale_boundaries"].to(self.device), 186 | ) 187 | preds.extend(bboxes) 188 | gts.extend([i for i in batch["timestamps"].numpy()]) 189 | 190 | return self.evaluator.eval(preds, gts) -------------------------------------------------------------------------------- /src/datasets/mad.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import random 4 | import math 5 | 
import numpy as np 6 | from easydict import EasyDict as edict 7 | from collections import defaultdict 8 | import torch 9 | import torch.utils.data as data 10 | 11 | from ..utils import compute_overlap 12 | 13 | 14 | class MADDataset(data.Dataset): 15 | 16 | def __init__(self, split, cfg, pre_load=False): 17 | super().__init__() 18 | 19 | self.split = split 20 | self.data_dir = cfg.DATA.DATA_DIR 21 | self.snippet_length = cfg.MODEL.SNIPPET_LENGTH 22 | self.scale_num = cfg.MODEL.SCALE_NUM 23 | self.max_anchor_length = self.snippet_length * 2**(self.scale_num - 1) 24 | if split == "train": 25 | epochs = cfg.TRAIN.NUM_EPOCH 26 | batch_size = cfg.TRAIN.BATCH_SIZE 27 | else: 28 | epochs = 1 29 | batch_size = 1000000 30 | 31 | self.q2v = dict() 32 | self.v2q = defaultdict(list) 33 | self.v2dur = dict() 34 | with open(os.path.join(self.data_dir, f"annotations/{split}.txt"), 'r') as f: 35 | for i, line in enumerate(f.readlines()): 36 | qid, vid, duration, start, end, text = line.strip().split(" | ") 37 | qid = int(qid) 38 | 39 | assert float(start) < float(end), \ 40 | "Wrong timestamps for {}: start >= end".format(qid) 41 | 42 | if vid not in self.v2dur: 43 | self.v2dur[vid] = float(duration) 44 | self.q2v[qid] = { 45 | "vid": vid, 46 | "duration": float(duration), 47 | "timestamps": [float(start), float(end)], 48 | "text": text.lower() 49 | } 50 | self.v2q[vid].append(qid) 51 | 52 | # generate training batch 53 | self.samples = list() 54 | for i_epoch in range(epochs): 55 | batches = list() 56 | for vid, qids in self.v2q.items(): 57 | cqids = qids.copy() 58 | if self.split == "train": 59 | random.shuffle(cqids) 60 | if len(cqids) % batch_size != 0: 61 | pad_num = batch_size - len(cqids) % batch_size 62 | cqids = cqids + cqids[:pad_num] 63 | 64 | steps = np.math.ceil(len(cqids) / batch_size) 65 | for j in range(steps): 66 | batches.append({"vid": vid, "qids": cqids[j*batch_size:(j+1)*batch_size]}) 67 | 68 | if self.split == "train": 69 | random.shuffle(batches) 70 | self.samples.extend(batches) 71 | 72 | self.vfeat_path = os.path.join(self.data_dir, "features/CLIP_frames_features_5fps.h5") 73 | self.qfeat_path = os.path.join(self.data_dir, "features/CLIP_language_sentence_features.h5") 74 | if pre_load: 75 | with h5py.File(self.vfeat_path, 'r') as f: 76 | self.vfeats = {m: np.asarray(f[m]) for m in self.v2q.keys()} 77 | with h5py.File(self.qfeat_path, 'r') as f: 78 | self.qfeats = {str(m): np.asarray(f[str(m)]) for m in self.q2v.keys()} 79 | else: 80 | self.vfeats, self.qfeats = None, None 81 | self.fps = 5.0 82 | 83 | 84 | def __len__(self): 85 | return len(self.samples) 86 | 87 | 88 | def __getitem__(self, idx): 89 | vid = self.samples[idx]["vid"] 90 | qids = self.samples[idx]["qids"] 91 | duration = self.v2dur[vid] 92 | 93 | if not self.vfeats: 94 | self.vfeats = h5py.File(self.vfeat_path, 'r') 95 | ori_video_feat = np.asarray(self.vfeats[vid]) 96 | ori_video_length, feat_dim = ori_video_feat.shape 97 | pad_video_length = int(np.math.ceil(ori_video_length / self.max_anchor_length) * self.max_anchor_length) 98 | pad_video_feat = np.zeros((pad_video_length, feat_dim), dtype=float) 99 | pad_video_feat[:ori_video_length, :] = ori_video_feat 100 | 101 | querys = { 102 | "texts": list(), 103 | "query_feats": list(), 104 | "query_masks": list(), 105 | "anchor_masks": list(), 106 | "starts": list(), 107 | "ends": list(), 108 | "overlaps": list(), 109 | "timestamps": list(), 110 | } 111 | scale_boundaries = [0] 112 | for qid in qids: 113 | text = self.q2v[qid]["text"] 114 | timestamps = 
self.q2v[qid]["timestamps"] 115 | if not self.qfeats: 116 | self.qfeats = h5py.File(self.qfeat_path, 'r') 117 | query_feat = np.asarray(self.qfeats[str(qid)]) 118 | query_length = query_feat.shape[0] 119 | query_mask = np.ones((query_length, ), dtype=float) 120 | 121 | # generate multi-level groundtruth 122 | masks, starts, ends, overlaps = list(), list(), list(), list() 123 | for i in range(self.scale_num): 124 | anchor_length = self.snippet_length * 2**i 125 | nfeats = math.ceil(ori_video_length / anchor_length) 126 | s_times = np.arange(0, nfeats).astype(np.float32) * (anchor_length / self.fps) 127 | e_times = np.minimum(duration, np.arange(1, nfeats + 1).astype(np.float32) * (anchor_length / self.fps)) 128 | candidates = np.stack([s_times, e_times], axis=1) 129 | overlap = compute_overlap(candidates.tolist(), timestamps) 130 | mask = np.ones((nfeats, ), dtype=int) 131 | 132 | pad_nfeats = math.ceil(pad_video_length / anchor_length) 133 | starts.append(self.pad(s_times, pad_nfeats)) 134 | ends.append(self.pad(e_times, pad_nfeats)) 135 | overlaps.append(self.pad(overlap, pad_nfeats)) 136 | masks.append(self.pad(mask, pad_nfeats)) 137 | 138 | if len(scale_boundaries) != self.scale_num + 1: 139 | scale_boundaries.append(scale_boundaries[-1] + pad_nfeats) 140 | 141 | starts = np.concatenate(starts, axis=0) 142 | ends = np.concatenate(ends, axis=0) 143 | overlaps = np.concatenate(overlaps, axis=0) 144 | masks = np.concatenate(masks, axis=0) 145 | 146 | querys["texts"].append(text) 147 | querys["query_feats"].append(torch.from_numpy(query_feat)) 148 | querys["query_masks"].append(torch.from_numpy(query_mask)) 149 | querys["anchor_masks"].append(torch.from_numpy(masks)) 150 | querys["starts"].append(torch.from_numpy(starts)) 151 | querys["ends"].append(torch.from_numpy(ends)) 152 | querys["overlaps"].append(torch.from_numpy(overlaps)) 153 | querys["timestamps"].append(torch.FloatTensor(timestamps)) 154 | 155 | instance = { 156 | "vid": vid, 157 | "duration": float(duration), 158 | "video_feats": torch.from_numpy(pad_video_feat).unsqueeze(0).float(), 159 | "scale_boundaries": torch.LongTensor(scale_boundaries), 160 | "qids": qids, 161 | "texts":querys["texts"], 162 | "query_feats": torch.stack(querys["query_feats"], dim=0).float(), 163 | "query_masks": torch.stack(querys["query_masks"], dim=0).float(), 164 | "anchor_masks": torch.stack(querys["anchor_masks"], dim=0), 165 | "starts": torch.stack(querys["starts"], dim=0), 166 | "ends": torch.stack(querys["ends"], dim=0), 167 | "overlaps": torch.stack(querys["overlaps"], dim=0), 168 | "timestamps": torch.stack(querys["timestamps"], dim=0) 169 | } 170 | return instance 171 | 172 | 173 | def pad(self, arr, pad_len): 174 | new_arr = np.zeros((pad_len, ), dtype=float) 175 | new_arr[:len(arr)] = arr 176 | return new_arr 177 | 178 | 179 | @staticmethod 180 | def collate_fn(data): 181 | all_items = data[0].keys() 182 | no_tensor_items = ["vid", "duration", "qids", "texts"] 183 | 184 | batch = {k: [d[k] for d in data] for k in all_items} 185 | for k in all_items: 186 | if k not in no_tensor_items: 187 | batch[k] = torch.cat(batch[k], dim=0) 188 | 189 | return batch 190 | 191 | 192 | 193 | if __name__ == "__main__": 194 | import yaml 195 | with open("conf/soonet_mad.yaml", 'r') as f: 196 | cfg = edict(yaml.load(f, Loader=yaml.FullLoader)) 197 | print(cfg) 198 | 199 | mad_dataset = MADDataset("train", cfg) 200 | data_loader = data.DataLoader(mad_dataset, 201 | batch_size=1, 202 | num_workers=4, 203 | shuffle=False, 204 | collate_fn=mad_dataset.collate_fn, 
205 | drop_last=False 206 | ) 207 | 208 | for i, batch in enumerate(data_loader): 209 | for k, v in batch.items(): 210 | if isinstance(v, torch.Tensor): 211 | print("{}: {}".format(k, v.size())) 212 | else: 213 | print("{}: {}".format(k, v)) 214 | break -------------------------------------------------------------------------------- /src/models/blocks.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from ..utils import fetch_feats_by_index 8 | 9 | 10 | 11 | class Q2VRankerStage1(nn.Module): 12 | 13 | def __init__(self, nlevel, hidden_dim): 14 | super().__init__() 15 | self.fc = nn.Linear(hidden_dim, hidden_dim) 16 | self.nlevel = nlevel 17 | 18 | def forward(self, ctx_feats, qfeat): 19 | qfeat = self.fc(qfeat) 20 | qv_ctx_scores = list() 21 | for i in range(self.nlevel): 22 | score = torch.einsum("bld,bd->bl", 23 | F.normalize(ctx_feats[i], p=2, dim=2), F.normalize(qfeat, p=2, dim=1)) 24 | qv_ctx_scores.append(score) 25 | 26 | return qv_ctx_scores 27 | 28 | 29 | class V2QRankerStage1(nn.Module): 30 | 31 | def __init__(self, nlevel, hidden_dim): 32 | super().__init__() 33 | self.fc = nn.Linear(hidden_dim, hidden_dim) 34 | self.nlevel = nlevel 35 | 36 | def forward(self, ctx_feats, qfeat): 37 | vq_ctx_scores = list() 38 | for i in range(self.nlevel): 39 | score = torch.einsum("bld,bd->bl", 40 | F.normalize(self.fc(ctx_feats[i]), p=2, dim=2), F.normalize(qfeat, p=2, dim=1)) 41 | vq_ctx_scores.append(score) 42 | 43 | return vq_ctx_scores 44 | 45 | 46 | class Q2VRankerStage2(nn.Module): 47 | 48 | def __init__(self, nlevel, hidden_dim, snippet_length=10, pool='mean'): 49 | super().__init__() 50 | self.nlevel = nlevel 51 | self.base_snippet_length = snippet_length 52 | self.qfc = nn.Linear(hidden_dim, hidden_dim) 53 | self.encoder = V2VAttention(hidden_dim) 54 | self.pool = pool 55 | 56 | def forward(self, vfeats, qfeat, hit_indices, qv_ctx_scores): 57 | qfeat = self.qfc(qfeat) 58 | 59 | qv_ctn_scores = list() 60 | qv_merge_scores = list() 61 | 62 | _, L, D = vfeats.size() 63 | ctn_feats = list() 64 | for i in range(self.nlevel): 65 | snippet_length = self.base_snippet_length * 2**i 66 | assert L // snippet_length == qv_ctx_scores[i].size(1), \ 67 | "{}, {}, {}, {}".format(i, L, snippet_length, qv_ctx_scores[i].size()) 68 | 69 | ctn_feat = vfeats.view(L//snippet_length, snippet_length, D).detach() 70 | if self.training: 71 | qv_ctx_score = torch.index_select(qv_ctx_scores[i], 1, hit_indices[i]) 72 | ctn_feat = torch.index_select(ctn_feat, 0, hit_indices[i]) 73 | ctn_feat = self.encoder(ctn_feat, torch.ones(ctn_feat.size()[:2], device=ctn_feat.device)) 74 | ctn_feat = ctn_feat.unsqueeze(0) 75 | else: 76 | qv_ctx_score = fetch_feats_by_index(qv_ctx_scores[i], hit_indices[i]) 77 | B, K = hit_indices[i].shape 78 | ctn_feat = fetch_feats_by_index(ctn_feat.unsqueeze(0).repeat(B, 1, 1, 1), hit_indices[i]).view(B*K, snippet_length, D) 79 | ctn_feat = self.encoder(ctn_feat, torch.ones(ctn_feat.size()[:2], device=ctn_feat.device)) 80 | ctn_feat = ctn_feat.view(B, K, snippet_length, D) 81 | 82 | ctn_feats.append(ctn_feat) 83 | qv_ctn_score = torch.einsum("bkld,bd->bkl", 84 | F.normalize(ctn_feat, p=2, dim=3), F.normalize(qfeat, p=2, dim=1)) 85 | if self.pool == "mean": 86 | qv_ctn_score = torch.mean(qv_ctn_score, dim=2) 87 | elif self.pool == "max": 88 | qv_ctn_score, _ = torch.max(qv_ctn_score, dim=2) 89 | else: 90 | raise NotImplementedError 91 | 
qv_ctn_scores.append(qv_ctn_score) 92 | qv_merge_scores.append(qv_ctx_score + qv_ctn_score) 93 | 94 | return qv_merge_scores, qv_ctn_scores, ctn_feats 95 | 96 | 97 | class V2QRankerStage2(nn.Module): 98 | 99 | def __init__(self, nlevel, hidden_dim): 100 | super().__init__() 101 | self.fc = nn.Linear(hidden_dim, hidden_dim) 102 | self.nlevel = nlevel 103 | 104 | def forward(self, ctn_feats, qfeat): 105 | vq_ctn_scores = list() 106 | for i in range(self.nlevel): 107 | score = torch.einsum("bkld,bd->bkl", 108 | F.normalize(self.fc(ctn_feats[i]), p=2, dim=3), F.normalize(qfeat, p=2, dim=1)) 109 | score = torch.mean(score, dim=2) 110 | vq_ctn_scores.append(score) 111 | 112 | return vq_ctn_scores 113 | 114 | class V2VAttention(nn.Module): 115 | 116 | def __init__(self, hidden_dim): 117 | super().__init__() 118 | self.posemb = PositionEncoding(max_len=400, dim=hidden_dim, dropout=0.0) 119 | self.encoder = MultiHeadAttention(dim=hidden_dim, n_heads=8, dropout=0.1) 120 | self.dropout = nn.Dropout(0.0) 121 | 122 | def forward(self, video_feats, video_masks): 123 | mask = torch.einsum("bm,bn->bmn", video_masks, video_masks).unsqueeze(1) 124 | residual = video_feats 125 | video_feats = video_feats + self.posemb(video_feats) 126 | out = self.encoder(query=video_feats, key=video_feats, value=video_feats, mask=mask) 127 | video_feats = self.dropout(residual + out) * video_masks.unsqueeze(2).float() 128 | return video_feats 129 | 130 | 131 | class BboxRegressor(nn.Module): 132 | 133 | def __init__(self, hidden_dim, enable_stage2=False): 134 | super().__init__() 135 | self.fc_ctx = nn.Linear(hidden_dim, hidden_dim) 136 | self.fc_q = nn.Linear(hidden_dim, hidden_dim) 137 | 138 | if enable_stage2: 139 | self.fc_ctn = nn.Linear(hidden_dim, hidden_dim) 140 | self.attn = SelfAttention(hidden_dim) 141 | self.predictor = nn.Sequential( 142 | nn.Linear(2*hidden_dim, hidden_dim), 143 | nn.ReLU(), 144 | nn.Linear(hidden_dim, 2) 145 | ) 146 | else: 147 | self.predictor = nn.Sequential( 148 | nn.Linear(hidden_dim, hidden_dim), 149 | nn.ReLU(), 150 | nn.Linear(hidden_dim, 2) 151 | ) 152 | self.enable_stage2 = enable_stage2 153 | 154 | def forward(self, ctx_feats, ctn_feats, qfeat): 155 | qfeat = self.fc_q(qfeat) 156 | 157 | ctx_feats = torch.cat(ctx_feats, dim=1) 158 | ctx_fuse_feats = F.relu(self.fc_ctx(ctx_feats)) * F.relu(qfeat.unsqueeze(1)) 159 | 160 | if self.enable_stage2 and ctn_feats: 161 | ctn_fuse_feats = list() 162 | for i in range(len(ctn_feats)): 163 | out = F.relu(self.fc_ctn(ctn_feats[i])) * F.relu(qfeat.unsqueeze(1).unsqueeze(1)) 164 | out = self.attn(out) 165 | ctn_fuse_feats.append(out) 166 | ctn_fuse_feats = torch.cat(ctn_fuse_feats, dim=1) 167 | fuse_feats = torch.cat([ctx_fuse_feats, ctn_fuse_feats], dim=-1) 168 | else: 169 | fuse_feats = ctx_fuse_feats 170 | 171 | out = self.predictor(fuse_feats) 172 | return out 173 | 174 | 175 | class SelfAttention(nn.Module): 176 | 177 | def __init__(self, hidden_dim): 178 | super().__init__() 179 | self.fc1 = nn.Linear(hidden_dim, hidden_dim//2) 180 | self.relu = nn.ReLU() 181 | self.fc2 = nn.Linear(hidden_dim//2, 1) 182 | 183 | def forward(self, x): 184 | att = self.fc2(self.relu(self.fc1(x))).squeeze(3) 185 | att = F.softmax(att, dim=2).unsqueeze(3) 186 | out = torch.sum(x * att, dim=2) 187 | return out 188 | 189 | 190 | class PositionEncoding(nn.Module): 191 | 192 | def __init__(self, max_len, dim, dropout=0.0): 193 | super(PositionEncoding, self).__init__() 194 | 195 | self.embed = nn.Embedding(max_len, dim) 196 | self.relu = nn.ReLU() 197 | self.dropout = 
nn.Dropout(dropout) 198 | 199 | def forward(self, x): 200 | batch_size, seq_len = x.shape[:2] 201 | pos_ids = torch.arange(seq_len, dtype=torch.long, device=x.device) 202 | pos_ids = pos_ids.unsqueeze(0).repeat(batch_size, 1) 203 | pos_emb = self.dropout(self.relu(self.embed(pos_ids))) 204 | 205 | return pos_emb 206 | 207 | 208 | 209 | class MultiHeadAttention(nn.Module): 210 | 211 | def __init__(self, dim, n_heads, dropout=0.0): 212 | super(MultiHeadAttention, self).__init__() 213 | 214 | self.dim = dim 215 | self.n_heads = n_heads 216 | self.head_dim = dim // n_heads 217 | 218 | self.to_q = nn.Linear(dim, dim) 219 | self.to_k = nn.Linear(dim, dim) 220 | self.to_v = nn.Linear(dim, dim) 221 | 222 | self.dropout = nn.Dropout(dropout) 223 | self.softmax = nn.Softmax(dim=-1) 224 | 225 | def transpose_for_scores(self, x): 226 | new_x_shape = x.size()[:-1] + (self.n_heads, self.head_dim) 227 | x = x.view(*new_x_shape) 228 | return x.permute(0, 2, 1, 3) # (N, nh, L, dh) 229 | 230 | def forward(self, query, key, value, mask): 231 | q = self.to_q(query) 232 | k = self.to_k(key) 233 | v = self.to_v(value) 234 | 235 | q_trans = self.transpose_for_scores(q) 236 | k_trans = self.transpose_for_scores(k) 237 | v_trans = self.transpose_for_scores(v) 238 | 239 | att = torch.matmul(q_trans, k_trans.transpose(-1, -2)) # (N, nh, Lq, L) 240 | att = att / math.sqrt(self.head_dim) 241 | att = mask_logits(att, mask) 242 | att = self.softmax(att) 243 | att = self.dropout(att) 244 | 245 | ctx_v = torch.matmul(att, v_trans) # (N, nh, Lq, dh) 246 | ctx_v = ctx_v.permute(0, 2, 1, 3).contiguous() # (N, Lq, nh, dh) 247 | shape = ctx_v.size()[:-2] + (self.dim, ) 248 | ctx_v = ctx_v.view(*shape) # (N, Lq, D) 249 | return ctx_v 250 | 251 | 252 | def mask_logits(inputs, mask, mask_value=-1e30): 253 | mask = mask.type(torch.float32) 254 | return inputs + (1.0 - mask) * mask_value 255 | -------------------------------------------------------------------------------- /src/models/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from .blocks import * 6 | from .swin_transformer import SwinTransformerV2_1D 7 | from .loss import * 8 | from ..utils import fetch_feats_by_index, compute_tiou 9 | 10 | 11 | class SOONet(nn.Module): 12 | 13 | def __init__(self, cfg): 14 | 15 | super().__init__() 16 | nscales = cfg.MODEL.SCALE_NUM 17 | hidden_dim = cfg.MODEL.HIDDEN_DIM 18 | snippet_length = cfg.MODEL.SNIPPET_LENGTH 19 | enable_stage2 = cfg.MODEL.ENABLE_STAGE2 20 | stage2_pool = cfg.MODEL.STAGE2_POOL 21 | stage2_topk = cfg.MODEL.STAGE2_TOPK 22 | topk = cfg.TEST.TOPK 23 | 24 | self.video_encoder = SwinTransformerV2_1D( 25 | patch_size=snippet_length, 26 | in_chans=hidden_dim, 27 | embed_dim=hidden_dim, 28 | depths=[2]*nscales, 29 | num_heads=[8]*nscales, 30 | window_size=[64]*nscales, 31 | mlp_ratio=2., 32 | qkv_bias=True, 33 | drop_rate=0., 34 | attn_drop_rate=0., 35 | drop_path_rate=0.1, 36 | norm_layer=nn.LayerNorm, 37 | patch_norm=True, 38 | use_checkpoint=False, 39 | pretrained_window_sizes=[0]*nscales 40 | ) 41 | 42 | self.q2v_stage1 = Q2VRankerStage1(nscales, hidden_dim) 43 | self.v2q_stage1 = V2QRankerStage1(nscales, hidden_dim) 44 | if enable_stage2: 45 | self.q2v_stage2 = Q2VRankerStage2(nscales, hidden_dim, snippet_length, stage2_pool) 46 | self.v2q_stage2 = V2QRankerStage2(nscales, hidden_dim) 47 | self.regressor = BboxRegressor(hidden_dim, enable_stage2) 48 | self.rank_loss = ApproxNDCGLoss(cfg) 49 | 
self.reg_loss = IOULoss(cfg) 50 | 51 | self.nscales = nscales 52 | self.enable_stage2 = enable_stage2 53 | self.stage2_topk = stage2_topk 54 | self.cfg = cfg 55 | self.topk = topk 56 | self.enable_nms = cfg.MODEL.ENABLE_NMS 57 | 58 | 59 | def forward(self, **kwargs): 60 | if self.training: 61 | return self.forward_train(**kwargs) 62 | else: 63 | return self.forward_test(**kwargs) 64 | 65 | 66 | def forward_train(self, 67 | query_feats=None, 68 | query_masks=None, 69 | video_feats=None, 70 | start_ts=None, 71 | end_ts=None, 72 | scale_boundaries=None, 73 | overlaps=None, 74 | timestamps=None, 75 | anchor_masks=None, 76 | **kwargs): 77 | 78 | sent_feat = query_feats 79 | ctx_feats = self.video_encoder(video_feats.permute(0, 2, 1)) 80 | qv_ctx_scores = self.q2v_stage1(ctx_feats, sent_feat) 81 | vq_ctx_scores = self.v2q_stage1(ctx_feats, sent_feat) 82 | if self.enable_stage2: 83 | hit_indices = list() 84 | filtered_ctx_feats = list() 85 | starts = list() 86 | ends = list() 87 | stage2_overlaps = list() 88 | for i in range(self.nscales): 89 | scale_first = scale_boundaries[i] 90 | scale_last = scale_boundaries[i+1] 91 | 92 | gt = overlaps[:, scale_first:scale_last] 93 | indices = torch.nonzero(gt.sum(0) > 0, as_tuple=True)[0] 94 | hit_indices.append(indices) 95 | 96 | filtered_ctx_feats.append(torch.index_select(ctx_feats[i], 1, indices)) 97 | starts.append(torch.index_select(start_ts[:, scale_first:scale_last], 1, indices)) 98 | ends.append(torch.index_select(end_ts[:, scale_first:scale_last], 1, indices)) 99 | stage2_overlaps.append(torch.index_select(overlaps[:, scale_first:scale_last], 1, indices)) 100 | 101 | starts = torch.cat(starts, dim=1) 102 | ends = torch.cat(ends, dim=1) 103 | stage2_overlaps = torch.cat(stage2_overlaps, dim=1) 104 | 105 | qv_merge_scores, qv_ctn_scores, ctn_feats = self.q2v_stage2( 106 | video_feats, sent_feat, hit_indices, qv_ctx_scores) 107 | vq_ctn_scores = self.v2q_stage2(ctn_feats, sent_feat) 108 | ctx_feats = filtered_ctx_feats 109 | else: 110 | ctn_feats = None 111 | qv_merge_scores = qv_ctx_scores 112 | starts = start_ts 113 | ends = end_ts 114 | stage2_overlaps = None 115 | 116 | bbox_bias = self.regressor(ctx_feats, ctn_feats, sent_feat) 117 | 118 | qv_ctx_scores = torch.sigmoid(torch.cat(qv_ctx_scores, dim=1)) 119 | qv_ctn_scores = torch.sigmoid(torch.cat(qv_ctn_scores, dim=1)) 120 | vq_ctx_scores = torch.sigmoid(torch.cat(vq_ctx_scores, dim=1)) 121 | vq_ctn_scores = torch.sigmoid(torch.cat(vq_ctn_scores, dim=1)) 122 | final_scores = torch.sigmoid(torch.cat(qv_merge_scores, dim=1)) 123 | 124 | loss_dict = self.loss(qv_ctx_scores, qv_ctn_scores, vq_ctx_scores, vq_ctn_scores, bbox_bias, 125 | timestamps, overlaps, stage2_overlaps, starts, ends, anchor_masks) 126 | 127 | return loss_dict 128 | 129 | def forward_test(self, 130 | query_feats=None, 131 | query_masks=None, 132 | video_feats=None, 133 | start_ts=None, 134 | end_ts=None, 135 | scale_boundaries=None, 136 | **kwargs): 137 | 138 | ori_ctx_feats = self.video_encoder(video_feats.permute(0, 2, 1)) 139 | batch_size = self.cfg.TEST.BATCH_SIZE 140 | query_num = len(query_feats) 141 | num_batches = math.ceil(query_num / batch_size) 142 | 143 | merge_scores, merge_bboxes = list(), list() 144 | for bid in range(num_batches): 145 | sent_feat = query_feats[bid*int(batch_size):(bid+1)*int(batch_size)] 146 | qv_ctx_scores = self.q2v_stage1(ori_ctx_feats, sent_feat) 147 | if self.enable_stage2: 148 | hit_indices = list() 149 | starts = list() 150 | ends = list() 151 | filtered_ctx_feats = list() 152 | for i in 
range(self.nscales): 153 | scale_first = scale_boundaries[i] 154 | scale_last = scale_boundaries[i+1] 155 | 156 | _, indices = torch.sort(qv_ctx_scores[i], dim=1, descending=True) 157 | indices = indices[:, :self.stage2_topk] 158 | hit_indices.append(indices) 159 | 160 | filtered_ctx_feats.append(fetch_feats_by_index(ori_ctx_feats[i].repeat(indices.size(0), 1, 1), indices)) 161 | starts.append(fetch_feats_by_index(start_ts[bid*int(batch_size):(bid+1)*int(batch_size), scale_first:scale_last], indices)) 162 | ends.append(fetch_feats_by_index(end_ts[bid*int(batch_size):(bid+1)*int(batch_size), scale_first:scale_last], indices)) 163 | 164 | starts = torch.cat(starts, dim=1) 165 | ends = torch.cat(ends, dim=1) 166 | 167 | qv_merge_scores, qv_ctn_scores, ctn_feats = self.q2v_stage2( 168 | video_feats, sent_feat, hit_indices, qv_ctx_scores) 169 | ctx_feats = filtered_ctx_feats 170 | else: 171 | ctx_feats = ori_ctx_feats 172 | ctn_feats = None 173 | qv_merge_scores = qv_ctx_scores 174 | starts = start_ts[bid*int(batch_size):(bid+1)*int(batch_size)] 175 | ends = end_ts[bid*int(batch_size):(bid+1)*int(batch_size)] 176 | 177 | bbox_bias = self.regressor(ctx_feats, ctn_feats, sent_feat) 178 | final_scores = torch.sigmoid(torch.cat(qv_merge_scores, dim=1)) 179 | 180 | 181 | pred_scores, pred_bboxes = list(), list() 182 | 183 | final_scores = final_scores.cpu().numpy() 184 | starts = starts.cpu().numpy() 185 | ends = ends.cpu().numpy() 186 | bbox_bias = bbox_bias.cpu().numpy() 187 | 188 | rank_ids = np.argsort(final_scores, axis=1) 189 | rank_ids = rank_ids[:, ::-1] 190 | query_num = len(rank_ids) 191 | ori_start = starts[np.arange(query_num)[:, None], rank_ids] 192 | ori_end = ends[np.arange(query_num)[:, None], rank_ids] 193 | duration = ori_end - ori_start 194 | sebias = bbox_bias[np.arange(query_num)[:, None], rank_ids] 195 | sbias, ebias = sebias[:, :, 0], sebias[:, :, 1] 196 | pred_start = np.maximum(0, ori_start + sbias * duration) 197 | pred_end = ori_end + ebias * duration 198 | 199 | pred_scores = final_scores[np.arange(query_num)[:, None], rank_ids] 200 | pred_bboxes = np.stack([pred_start, pred_end], axis=2) 201 | if self.enable_nms: 202 | nms_res = list() 203 | for i in range(query_num): 204 | bbox_nms = self.nms(pred_bboxes[i], thresh=0.3, topk=self.topk) 205 | nms_res.append(bbox_nms) 206 | pred_bboxes = nms_res 207 | else: 208 | pred_scores = pred_scores[:, :self.topk].tolist() 209 | pred_bboxes = pred_bboxes[:, :self.topk, :].tolist() 210 | 211 | merge_scores.extend(pred_scores) 212 | merge_bboxes.extend(pred_bboxes) 213 | 214 | return merge_scores, merge_bboxes 215 | 216 | 217 | def loss(self, 218 | qv_ctx_scores, 219 | qv_ctn_scores, 220 | vq_ctx_scores, 221 | vq_ctn_scores, 222 | bbox_bias, 223 | timestamps, 224 | overlaps, 225 | stage2_overlaps, 226 | starts, 227 | ends, 228 | anchor_masks): 229 | qv_ctx_loss = self.rank_loss(overlaps, qv_ctx_scores, mask=anchor_masks) 230 | vq_overlaps, vq_ctx_scores = self.filter_anchor_by_iou(overlaps, vq_ctx_scores) 231 | vq_ctx_loss = self.rank_loss(vq_overlaps, vq_ctx_scores, mask=torch.ones_like(vq_ctx_scores)) 232 | 233 | qv_ctn_loss, vq_ctn_loss, iou_loss = 0.0, 0.0, 0.0 234 | if self.cfg.MODEL.ENABLE_STAGE2: 235 | qv_ctn_loss = self.rank_loss(stage2_overlaps, qv_ctn_scores, mask=torch.ones_like(qv_ctn_scores)) 236 | vq_overlaps_s2, vq_ctn_scores = self.filter_anchor_by_iou(stage2_overlaps, vq_ctn_scores) 237 | vq_ctn_loss = self.rank_loss(vq_overlaps_s2, vq_ctn_scores, mask=torch.ones_like(vq_ctn_scores)) 238 | 239 | if 
self.cfg.LOSS.REGRESS.ENABLE: 240 | sbias = bbox_bias[:, :, 0] 241 | ebias = bbox_bias[:, :, 1] 242 | duration = ends - starts 243 | pred_start = starts + sbias * duration 244 | pred_end = ends + ebias * duration 245 | 246 | if self.cfg.MODEL.ENABLE_STAGE2: 247 | iou_mask = stage2_overlaps > self.cfg.LOSS.REGRESS.IOU_THRESH 248 | else: 249 | iou_mask = overlaps > self.cfg.LOSS.REGRESS.IOU_THRESH 250 | _, iou_loss = self.reg_loss(pred_start, pred_end, timestamps[:, 0:1], timestamps[:, 1:2], iou_mask) 251 | 252 | total_loss = self.cfg.LOSS.Q2V.CTX_WEIGHT * qv_ctx_loss + \ 253 | self.cfg.LOSS.Q2V.CTN_WEIGHT * qv_ctn_loss + \ 254 | self.cfg.LOSS.V2Q.CTX_WEIGHT * vq_ctx_loss + \ 255 | self.cfg.LOSS.V2Q.CTN_WEIGHT * vq_ctn_loss + \ 256 | self.cfg.LOSS.REGRESS.WEIGHT * iou_loss 257 | 258 | loss_dict = { 259 | "qv_ctx_loss": qv_ctx_loss, 260 | "qv_ctn_loss": qv_ctn_loss, 261 | "vq_ctx_loss": vq_ctx_loss, 262 | "vq_ctn_loss": vq_ctn_loss, 263 | "reg_loss": iou_loss, 264 | "total_loss": total_loss 265 | } 266 | return loss_dict 267 | 268 | 269 | def filter_anchor_by_iou(self, gt, pred): 270 | indicator = (torch.sum((gt > self.cfg.LOSS.V2Q.MIN_IOU).float(), dim=0, keepdim=False) > 0).long() 271 | moment_num = torch.sum(indicator) 272 | _, index = torch.sort(indicator, descending=True) 273 | index = index[:moment_num] 274 | gt = torch.index_select(gt, 1, index).transpose(0, 1) 275 | pred = torch.index_select(pred, 1, index).transpose(0, 1) 276 | return gt, pred 277 | 278 | 279 | def nms(self, pred, thresh=0.3, topk=5): 280 | nms_res = list() 281 | mask = [False] * len(pred) 282 | for i in range(len(pred)): 283 | f = pred[i].copy() 284 | if not mask[i]: 285 | nms_res.append(f) 286 | if len(nms_res) >= topk: 287 | break 288 | for j in range(i, len(pred)): 289 | tiou = compute_tiou(pred[i], pred[j]) 290 | if tiou > thresh: 291 | mask[j] = True 292 | del mask 293 | return nms_res -------------------------------------------------------------------------------- /src/models/swin_transformer.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/meraks/Swin-Transformer-1D/blob/main/SwinTransformer.py 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.checkpoint as checkpoint 8 | from torch.nn.init import trunc_normal_ 9 | 10 | 11 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): 12 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 13 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 14 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 15 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 16 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 17 | 'survival rate' as the argument. 18 | """ 19 | if drop_prob == 0. or not training: 20 | return x 21 | keep_prob = 1 - drop_prob 22 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 23 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 24 | if keep_prob > 0.0 and scale_by_keep: 25 | random_tensor.div_(keep_prob) 26 | return x * random_tensor 27 | 28 | 29 | class DropPath(nn.Module): 30 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
31 | """ 32 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): 33 | super(DropPath, self).__init__() 34 | self.drop_prob = drop_prob 35 | self.scale_by_keep = scale_by_keep 36 | 37 | def forward(self, x): 38 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) 39 | 40 | def extra_repr(self): 41 | return f'drop_prob={round(self.drop_prob,3):0.3f}' 42 | 43 | 44 | class Mlp(nn.Module): 45 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 46 | super().__init__() 47 | out_features = out_features or in_features 48 | hidden_features = hidden_features or in_features 49 | self.fc1 = nn.Linear(in_features, hidden_features) 50 | self.act = act_layer() 51 | self.fc2 = nn.Linear(hidden_features, out_features) 52 | self.drop = nn.Dropout(drop) 53 | 54 | def forward(self, x): 55 | x = self.fc1(x) 56 | x = self.act(x) 57 | x = self.drop(x) 58 | x = self.fc2(x) 59 | x = self.drop(x) 60 | return x 61 | 62 | 63 | def window_partition(x, window_size): 64 | """ 65 | Args: 66 | x: (B, L, C) 67 | window_size (int): window size 68 | Returns: 69 | windows: (num_windows*B, window_size, C) 70 | """ 71 | B, L, C = x.shape 72 | x = x.view(B, L // window_size, window_size, C) 73 | windows = x.permute(0, 1, 2, 3).contiguous().view(-1, window_size, C) 74 | return windows 75 | 76 | 77 | def window_reverse(windows, window_size, L): 78 | """ 79 | Args: 80 | windows: (num_windows*B, window_size, window_size, C) 81 | window_size (int): Window size 82 | L (int): sequence length 83 | Returns: 84 | x: (B, L, C) 85 | """ 86 | B = int(windows.shape[0] / (L / window_size)) 87 | x = windows.view(B, L // window_size, window_size, -1) 88 | x = x.permute(0, 1, 2, 3).contiguous().view(B, L, -1) 89 | return x 90 | 91 | 92 | class WindowAttention_1D(nn.Module): 93 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 94 | It supports both of shifted and non-shifted window. 95 | Args: 96 | dim (int): Number of input channels. 97 | window_size (int): The height and width of the window. 98 | num_heads (int): Number of attention heads. 99 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 100 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 101 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 102 | pretrained_window_size (int): The height and width of the window in pre-training. 
103 | """ 104 | 105 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., 106 | pretrained_window_size=0): 107 | 108 | super().__init__() 109 | self.dim = dim 110 | self.window_size = window_size # Wl 111 | self.pretrained_window_size = pretrained_window_size 112 | self.num_heads = num_heads 113 | 114 | self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) 115 | 116 | # mlp to generate continuous relative position bias 117 | self.cpb_mlp = nn.Sequential(nn.Linear(1, 512, bias=True), 118 | nn.ReLU(inplace=True), 119 | nn.Linear(512, num_heads, bias=False)) 120 | 121 | # get relative_coords_table 122 | relative_coords_l = torch.arange(-(self.window_size - 1), self.window_size, dtype=torch.float32) 123 | relative_coords_table = torch.stack( 124 | torch.meshgrid([relative_coords_l], indexing='ij')).permute(1, 0).contiguous().unsqueeze(0) # 1, 2*Wl-1, 1 125 | if pretrained_window_size > 0: 126 | relative_coords_table[:, :, :] /= (pretrained_window_size - 1) 127 | else: 128 | relative_coords_table[:, :, :] /= (self.window_size - 1) 129 | relative_coords_table *= 8 # normalize to -8, 8 130 | relative_coords_table = torch.sign(relative_coords_table) * torch.log2( 131 | torch.abs(relative_coords_table) + 1.0) / np.log2(8) 132 | 133 | self.register_buffer("relative_coords_table", relative_coords_table) 134 | 135 | # get pair-wise relative position index for each token inside the window 136 | coords_l = torch.arange(self.window_size) 137 | coords = torch.stack(torch.meshgrid([coords_l], indexing='ij')) # 1, Wl 138 | coords_flatten = torch.flatten(coords, 1) # 1, Wl 139 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 1, Wl, Wl 140 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wl, Wl, 1 141 | relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0 142 | relative_position_index = relative_coords.sum(-1) # Wl, Wl 143 | self.register_buffer("relative_position_index", relative_position_index) 144 | 145 | self.qkv = nn.Linear(dim, dim * 3, bias=False) 146 | if qkv_bias: 147 | self.q_bias = nn.Parameter(torch.zeros(dim)) 148 | self.v_bias = nn.Parameter(torch.zeros(dim)) 149 | else: 150 | self.q_bias = None 151 | self.v_bias = None 152 | self.attn_drop = nn.Dropout(attn_drop) 153 | self.proj = nn.Linear(dim, dim) 154 | self.proj_drop = nn.Dropout(proj_drop) 155 | self.softmax = nn.Softmax(dim=-1) 156 | 157 | def forward(self, x, mask=None): 158 | """ 159 | Args: 160 | x: input features with shape of (num_windows*B, N, C) 161 | mask: (0/-inf) mask with shape of (num_windows, Wl, Wl) or None 162 | """ 163 | B_, N, C = x.shape 164 | qkv_bias = None 165 | if self.q_bias is not None: 166 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) 167 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 168 | qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 169 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 170 | 171 | # cosine attention 172 | attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) 173 | logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. 
/ 0.01, device=attn.device))).exp() 174 | attn = attn * logit_scale 175 | 176 | relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) 177 | relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( 178 | self.window_size, self.window_size, -1) # Wl,l,nH 179 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wl, Wl 180 | relative_position_bias = 16 * torch.sigmoid(relative_position_bias) 181 | attn = attn + relative_position_bias.unsqueeze(0) 182 | 183 | if mask is not None: 184 | nW = mask.shape[0] 185 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 186 | attn = attn.view(-1, self.num_heads, N, N) 187 | attn = self.softmax(attn) 188 | else: 189 | attn = self.softmax(attn) 190 | 191 | attn = self.attn_drop(attn) 192 | 193 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 194 | x = self.proj(x) 195 | x = self.proj_drop(x) 196 | return x 197 | 198 | 199 | def compute_mask(L, window_size, shift_size): 200 | Lp = int(np.ceil(L / window_size)) * window_size 201 | img_mask = torch.zeros((1, Lp, 1)) # 1 Lp 1 202 | pad_size = int(Lp - L) 203 | if (pad_size == 0) or (pad_size + shift_size == window_size): 204 | segs = (slice(-window_size), slice(-window_size, -shift_size), slice(-shift_size, None)) 205 | elif pad_size + shift_size > window_size: 206 | seg1 = int(window_size * 2 - L + shift_size) 207 | segs = ( 208 | slice(-seg1), slice(-seg1, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None)) 209 | elif pad_size + shift_size < window_size: 210 | seg1 = int(window_size * 2 - L + shift_size) 211 | segs = ( 212 | slice(-window_size), slice(-window_size, -seg1), slice(-seg1, -shift_size), slice(-shift_size, None)) 213 | cnt = 0 214 | for d in segs: 215 | img_mask[:, d, :] = cnt 216 | cnt += 1 217 | mask_windows = window_partition(img_mask, window_size) # nW, ws, 1 218 | mask_windows = mask_windows.squeeze(-1) # nW, ws 219 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 220 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 221 | return attn_mask 222 | 223 | 224 | class SwinTransformerBlock_1D(nn.Module): 225 | r""" Swin Transformer Block. 226 | Args: 227 | dim (int): Number of input channels. 228 | input_resolution (int): Input resulotion. 229 | num_heads (int): Number of attention heads. 230 | window_size (int): Window size. 231 | shift_size (int): Shift size for SW-MSA. 232 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 233 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 234 | drop (float, optional): Dropout rate. Default: 0.0 235 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 236 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 237 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 238 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 239 | pretrained_window_size (int): Window size in pre-training. 
240 | """ 241 | 242 | def __init__(self, dim, num_heads, window_size=7, shift_size=0, 243 | mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., 244 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): 245 | super().__init__() 246 | self.dim = dim 247 | self.num_heads = num_heads 248 | self.window_size = window_size 249 | self.shift_size = shift_size 250 | self.mlp_ratio = mlp_ratio 251 | 252 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 253 | 254 | self.norm1 = norm_layer(dim) 255 | self.attn = WindowAttention_1D( 256 | dim, window_size=self.window_size, num_heads=num_heads, 257 | qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, 258 | pretrained_window_size=pretrained_window_size) 259 | 260 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 261 | self.norm2 = norm_layer(dim) 262 | mlp_hidden_dim = int(dim * mlp_ratio) 263 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 264 | 265 | 266 | def forward(self, x): 267 | B, L, C = x.shape 268 | 269 | attn_mask = compute_mask(L, self.window_size, self.shift_size).to(x.device) 270 | 271 | shortcut = x 272 | # x = x.view(B, L, C) 273 | 274 | # padding x 275 | pad_r = (self.window_size - L % self.window_size) % self.window_size 276 | x = F.pad(x, (0, 0, 0, pad_r)) 277 | _, Lp, _ = x.shape 278 | 279 | # cyclic shift 280 | if self.shift_size > 0: 281 | shifted_x = torch.roll(x, shifts=(-self.shift_size), dims=(1)) 282 | else: 283 | shifted_x = x 284 | 285 | # partition windows 286 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, C 287 | x_windows = x_windows.view(-1, self.window_size, C) # nW*B, window_siz, C 288 | 289 | # W-MSA/SW-MSA 290 | attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size, C 291 | 292 | # merge windows 293 | attn_windows = attn_windows.view(-1, self.window_size, C) 294 | shifted_x = window_reverse(attn_windows, self.window_size, Lp) # B L' C 295 | 296 | # reverse cyclic shift 297 | if self.shift_size > 0: 298 | x = torch.roll(shifted_x, shifts=(self.shift_size), dims=(1)) 299 | else: 300 | x = shifted_x 301 | x = x.view(B, Lp, C) 302 | # reverse padding x 303 | x = x[:, :L, :].contiguous() 304 | x = shortcut + self.drop_path(self.norm1(x)) 305 | 306 | # FFN 307 | x = x + self.drop_path(self.norm2(self.mlp(x))) 308 | 309 | return x 310 | 311 | 312 | class PatchMerging(nn.Module): 313 | """ Patch Merging Layer 314 | Args: 315 | dim (int): Number of input channels. 316 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 317 | """ 318 | def __init__(self, dim, norm_layer=nn.LayerNorm): 319 | super().__init__() 320 | self.dim = dim 321 | # self.reduction = nn.Linear(2 * dim, dim, bias=False) 322 | # self.norm = norm_layer(2 * dim) 323 | 324 | def forward(self, x): 325 | """ Forward function. 326 | Args: 327 | x: Input feature, tensor size (B, L, C). 328 | """ 329 | B, L, C = x.shape 330 | 331 | # padding 332 | # pad_input = (L % 2 == 1) 333 | # if pad_input: 334 | # x = F.pad(x, (0, 0, 0, L % 2)) 335 | x = F.pad(x, (0, 0, 0, L % 2)) 336 | 337 | x0 = x[:, 0::2, :] # B L/2 C 338 | x1 = x[:, 1::2, :] # B L/2 C 339 | # x = torch.cat([x0, x1], -1) # B L/2 2*C 340 | 341 | # x = self.norm(x) 342 | # x = self.reduction(x) 343 | x = torch.maximum(x0, x1) 344 | 345 | return x 346 | 347 | 348 | class BasicLayer(nn.Module): 349 | """ A basic Swin Transformer layer for one stage. 
350 | Args: 351 | dim (int): Number of input channels. 352 | input_resolution (int): Input resolution. 353 | depth (int): Number of blocks. 354 | num_heads (int): Number of attention heads. 355 | window_size (int): Local window size. 356 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 357 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 358 | drop (float, optional): Dropout rate. Default: 0.0 359 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 360 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 361 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 362 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 363 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 364 | pretrained_window_size (int): Local window size in pre-training. 365 | """ 366 | 367 | def __init__(self, dim, depth, num_heads, window_size, 368 | mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., 369 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, 370 | pretrained_window_size=0): 371 | 372 | super().__init__() 373 | self.dim = dim 374 | self.depth = depth 375 | self.use_checkpoint = use_checkpoint 376 | 377 | # build blocks 378 | self.blocks = nn.ModuleList([ 379 | SwinTransformerBlock_1D(dim=dim, 380 | num_heads=num_heads, window_size=window_size, 381 | shift_size=0 if (i % 2 == 0) else window_size // 2, 382 | mlp_ratio=mlp_ratio, 383 | qkv_bias=qkv_bias, 384 | drop=drop, attn_drop=attn_drop, 385 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 386 | norm_layer=norm_layer, 387 | pretrained_window_size=pretrained_window_size) 388 | for i in range(depth)]) 389 | 390 | # patch merging layer 391 | if downsample is not None: 392 | self.downsample = downsample(dim=dim, norm_layer=norm_layer) 393 | else: 394 | self.downsample = None 395 | 396 | def forward(self, x): 397 | for blk in self.blocks: 398 | if self.use_checkpoint: 399 | x = checkpoint.checkpoint(blk, x) 400 | else: 401 | x = blk(x) 402 | 403 | proposal = x 404 | if self.downsample is not None: 405 | x = self.downsample(x) 406 | return x, proposal 407 | 408 | def _init_respostnorm(self): 409 | for blk in self.blocks: 410 | nn.init.constant_(blk.norm1.bias, 0) 411 | nn.init.constant_(blk.norm1.weight, 0) 412 | nn.init.constant_(blk.norm2.bias, 0) 413 | nn.init.constant_(blk.norm2.weight, 0) 414 | 415 | 416 | class PatchEmbed1D(nn.Module): 417 | """ Video to Patch Embedding. 418 | Args: 419 | seq_len (int): Sequence length. 420 | patch_size (int): Patch token size. Default: 4. 421 | in_chans (int): Number of input video channels. Default: 32. 422 | embed_dim (int): Number of linear projection output channels. Default: 128. 423 | norm_layer (nn.Module, optional): Normalization layer.
Default: None 424 | """ 425 | def __init__(self, patch_size=4, in_chans=32, embed_dim=128, norm_layer=None): 426 | super().__init__() 427 | self.patch_size = patch_size 428 | 429 | self.in_chans = in_chans 430 | self.embed_dim = embed_dim 431 | 432 | self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 433 | 434 | if norm_layer is not None: 435 | self.norm = norm_layer(embed_dim) 436 | else: 437 | self.norm = None 438 | 439 | def forward(self, x): 440 | """Forward function.""" 441 | # padding 442 | _, _, L = x.size() 443 | pad_r = (self.patch_size - L % self.patch_size) % self.patch_size 444 | x = F.pad(x, (0, pad_r)) 445 | x = self.proj(x) # B C Wl 446 | if self.norm is not None: 447 | # Wl = x.size(2) 448 | x = x.transpose(1, 2) 449 | x = self.norm(x) 450 | # x = x.transpose(1, 2).view(-1, self.embed_dim, Wl) 451 | 452 | return x 453 | 454 | 455 | class SwinTransformerV2_1D(nn.Module): 456 | 457 | def __init__(self, 458 | patch_size=4, 459 | in_chans=32, 460 | embed_dim=96, 461 | depths=[2, 2, 6, 2], 462 | num_heads=[3, 6, 12, 24], 463 | window_size=7, 464 | mlp_ratio=4., 465 | qkv_bias=True, 466 | drop_rate=0., 467 | attn_drop_rate=0., 468 | drop_path_rate=0.1, 469 | norm_layer=nn.LayerNorm, 470 | patch_norm=True, 471 | use_checkpoint=False, 472 | pretrained_window_sizes=[0, 0, 0, 0], 473 | **kwargs): 474 | super().__init__() 475 | 476 | self.num_layers = len(depths) 477 | self.embed_dim = embed_dim 478 | self.patch_norm = patch_norm 479 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 480 | self.mlp_ratio = mlp_ratio 481 | 482 | # split image into non-overlapping patches 483 | self.patch_embed = PatchEmbed1D( 484 | patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 485 | norm_layer=norm_layer if self.patch_norm else None) 486 | 487 | self.pos_drop = nn.Dropout(p=drop_rate) 488 | # stochastic depth 489 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 490 | 491 | # build layers 492 | self.layers = nn.ModuleList() 493 | for i_layer in range(self.num_layers): 494 | layer = BasicLayer(dim=embed_dim, 495 | depth=depths[i_layer], 496 | num_heads=num_heads[i_layer], 497 | window_size=window_size[i_layer], 498 | mlp_ratio=self.mlp_ratio, 499 | qkv_bias=qkv_bias, 500 | drop=drop_rate, attn_drop=attn_drop_rate, 501 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 502 | norm_layer=norm_layer, 503 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 504 | use_checkpoint=use_checkpoint, 505 | pretrained_window_size=pretrained_window_sizes[i_layer] 506 | ) 507 | self.layers.append(layer) 508 | 509 | self.apply(self._init_weights) 510 | for bly in self.layers: 511 | bly._init_respostnorm() 512 | 513 | def _init_weights(self, m): 514 | if isinstance(m, nn.Linear): 515 | trunc_normal_(m.weight, std=.02) 516 | if isinstance(m, nn.Linear) and m.bias is not None: 517 | nn.init.constant_(m.bias, 0) 518 | elif isinstance(m, nn.LayerNorm): 519 | nn.init.constant_(m.bias, 0) 520 | nn.init.constant_(m.weight, 1.0) 521 | 522 | @torch.jit.ignore 523 | def no_weight_decay(self): 524 | return {'absolute_pos_embed'} 525 | 526 | @torch.jit.ignore 527 | def no_weight_decay_keywords(self): 528 | return {"cpb_mlp", "logit_scale", 'relative_position_bias_table'} 529 | 530 | def forward_features(self, x): 531 | x = self.patch_embed(x) 532 | x = self.pos_drop(x) 533 | 534 | proposals = list() 535 | for layer in self.layers: 536 | x, proposal = layer(x) 537 | 
proposals.append(proposal) 538 | 539 | return proposals 540 | 541 | def forward(self, x): 542 | return self.forward_features(x) --------------------------------------------------------------------------------
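For reference, a minimal, self-contained sketch of how this 1D Swin backbone can be exercised on its own. The hyper-parameters below are illustrative and are not taken from conf/soonet_mad.yaml. Note that SwinTransformerV2_1D indexes window_size per stage (window_size[i_layer]), so a list with one entry per stage is expected despite the scalar default in the signature, and the input layout is (batch, in_chans, length) as required by the Conv1d inside PatchEmbed1D. Assuming the repository root is on PYTHONPATH:

import torch
from src.models.swin_transformer import SwinTransformerV2_1D

# Illustrative settings only: 4 stages over a sequence of 500 snippet features of dimension 512.
backbone = SwinTransformerV2_1D(
    patch_size=1,
    in_chans=512,
    embed_dim=512,
    depths=[2, 2, 2, 2],
    num_heads=[8, 8, 8, 8],
    window_size=[8, 8, 8, 8],   # one entry per stage; the constructor does window_size[i_layer]
)

x = torch.randn(2, 512, 500)    # (batch, channels, length)
proposals = backbone(x)         # forward() returns forward_features(): one tensor per stage
for p in proposals:
    print(p.shape)              # roughly (2, 500 / 2**stage, 512); PatchMerging halves the length between stages

The channel width stays at embed_dim across stages because this PatchMerging reduces the sequence by an element-wise maximum of neighbouring positions rather than the usual concatenate-and-project, so the num_features attribute computed in __init__ overstates the width of the final stage.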