├── utils ├── __init__.py ├── logger.py ├── imports.py ├── metric_logger.py ├── misc.py ├── video_list.py ├── model_serialization.py ├── box_utils.py ├── comm.py ├── checkpoint.py └── bounding_box.py ├── config ├── __init__.py └── defaults.py ├── assets ├── idea.png └── framework.png ├── models ├── .DS_Store ├── grounding_model │ ├── __init__.py │ ├── position_encoding.py │ └── modal_encoder.py ├── language_model │ ├── __init__.py │ ├── lstm.py │ └── bert.py ├── vision_model │ ├── __init__.py │ ├── position_encoding.py │ └── backbone.py ├── __init__.py ├── post_processor.py ├── net_utils.py ├── pipeline.py ├── bert_model │ └── bert_module.py ├── map2d_head.py └── criterion.py ├── datasets ├── __init__.py ├── samplers │ ├── __init__.py │ ├── iteration_based_batch_sampler.py │ └── grouped_batch_sampler.py ├── collate_batch.py ├── evaluation │ ├── __init__.py │ ├── hcstvg_eval.py │ └── vidstg_eval.py ├── gaussion_hm.py ├── words.py ├── build.py ├── transforms.py └── hcstvg.py ├── requirements.txt ├── engine ├── __init__.py ├── optimizer.py ├── evaluate.py └── lr_scheduler.py ├── experiments ├── hcstvg2.yaml ├── hcstvg.yaml └── vidstg.yaml └── scripts └── test_net.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import _C as cfg -------------------------------------------------------------------------------- /assets/idea.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HengLan/CGSTVG/HEAD/assets/idea.png -------------------------------------------------------------------------------- /models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HengLan/CGSTVG/HEAD/models/.DS_Store -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HengLan/CGSTVG/HEAD/assets/framework.png -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import make_data_loader, build_transforms, build_dataset 2 | from .evaluation import build_evaluator -------------------------------------------------------------------------------- /datasets/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .grouped_batch_sampler import GroupedBatchSampler 2 | from .iteration_based_batch_sampler import IterationBasedBatchSampler 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | yacs 3 | torchtext==0.15.2 4 | timm 5 | tensorboard 6 | pytorch-pretrained-bert 7 | transformers==4.5.1 8 | Pillow 9 | opencv_python 10 | ffmpeg_python 11 | scipy 12 | cython 13 | tqdm 14 | packaging==21.3 15 | ftfy -------------------------------------------------------------------------------- /engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .optimizer import make_optimizer 2 | from 
.optimizer import make_lr_scheduler, update_ema 3 | from .lr_scheduler import WarmupMultiStepLR, WarmupReduceLROnPlateau, WarmupPolyLR, adjust_learning_rate 4 | from .evaluate import do_eval 5 | -------------------------------------------------------------------------------- /models/grounding_model/__init__.py: -------------------------------------------------------------------------------- 1 | from .modal_encoder import CrossModalEncoder 2 | from .query_decoder import QueryDecoder 3 | 4 | def build_encoder(cfg): 5 | return CrossModalEncoder(cfg) 6 | 7 | def build_decoder(cfg): 8 | return QueryDecoder(cfg) -------------------------------------------------------------------------------- /datasets/collate_batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from utils.misc import NestedTensor 4 | 5 | 6 | def collate_fn(batch): 7 | transposed_batch = list(zip(*batch)) 8 | videos = transposed_batch[0] 9 | texts = transposed_batch[1] 10 | targets = transposed_batch[2] 11 | 12 | batch_dict = {} 13 | batch_dict['durations'] = [video.shape[0] for video in videos] 14 | batch_dict['videos'] = NestedTensor.from_tensor_list(videos) 15 | batch_dict['texts'] = [text for text in texts] 16 | batch_dict['targets'] = [target for target in targets] 17 | 18 | return batch_dict 19 | -------------------------------------------------------------------------------- /models/language_model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .bert import BERT, Roberta 3 | from .lstm import RNNEncoder 4 | 5 | def build_text_encoder(cfg): 6 | if cfg.MODEL.USE_LSTM: 7 | language_encoder = RNNEncoder( 8 | cfg.GLOVE_DIR, 9 | cfg.MODEL.LSTM.HIDDEN_SIZE // 2 if cfg.MODEL.LSTM.BIDIRECTIONAL \ 10 | else cfg.MODE.LSTM.HIDDEN_SIZE, 11 | cfg.MODEL.LSTM.BIDIRECTIONAL, 12 | cfg.MODEL.LSTM.DROPOUT, 13 | cfg.MODEL.LSTM_NUM_LAYERS, 14 | cfg.MODEL.LSTM.NAME 15 | ) 16 | else: 17 | language_encoder = Roberta( 18 | cfg.MODEL.TEXT_MODEL.NAME, 19 | cfg.MODEL.CG.HIDDEN, 20 | cfg.MODEL.TEXT_MODEL.FREEZE 21 | ) 22 | return language_encoder -------------------------------------------------------------------------------- /datasets/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .vidstg_eval import VidSTGEvaluator 2 | from .hcstvg_eval import HCSTVGEvaluator 3 | 4 | def build_evaluator(cfg, logger, mode): 5 | if cfg.DATASET.NAME == 'VidSTG': 6 | return VidSTGEvaluator( 7 | logger, 8 | cfg.DATA_DIR, 9 | mode, 10 | iou_thresholds=[0.3, 0.5], 11 | save_pred=(mode=='test'), 12 | save_dir=cfg.OUTPUT_DIR, 13 | ) 14 | elif cfg.DATASET.NAME == 'HC-STVG': 15 | return HCSTVGEvaluator( 16 | logger, 17 | cfg.DATA_DIR, 18 | mode, 19 | iou_thresholds=[0.3, 0.5], 20 | save_pred=(mode=='test'), 21 | save_dir=cfg.OUTPUT_DIR, 22 | ) 23 | else: 24 | raise NotImplementedError -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt"): 6 | logger = logging.getLogger(name) 7 | logger.setLevel(logging.DEBUG) 8 | # don't log results for the non-master process 9 | if distributed_rank > 0: 10 | return logger 11 | ch = logging.StreamHandler(stream=sys.stdout) 12 | ch.setLevel(logging.DEBUG) 13 | formatter = 
logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 14 | ch.setFormatter(formatter) 15 | logger.addHandler(ch) 16 | 17 | if save_dir: 18 | fh = logging.FileHandler(os.path.join(save_dir, filename)) 19 | fh.setLevel(logging.DEBUG) 20 | fh.setFormatter(formatter) 21 | logger.addHandler(fh) 22 | 23 | return logger 24 | -------------------------------------------------------------------------------- /models/vision_model/__init__.py: -------------------------------------------------------------------------------- 1 | from .position_encoding import build_position_encoding 2 | from .backbone import GroupNormBackbone, Backbone, Joiner 3 | 4 | 5 | def build_vis_encoder(cfg): 6 | position_embedding = build_position_encoding(cfg) 7 | train_backbone = cfg.SOLVER.VIS_BACKBONE_LR > 0 8 | backbone_name = cfg.MODEL.VISION_BACKBONE.NAME 9 | if backbone_name in ("resnet50-gn", "resnet101-gn"): 10 | backbone = GroupNormBackbone( 11 | backbone_name, 12 | train_backbone, 13 | False, 14 | cfg.MODEL.VISION_BACKBONE.DILATION 15 | ) 16 | else: 17 | backbone = Backbone( 18 | backbone_name, 19 | train_backbone, 20 | False, 21 | cfg.MODEL.VISION_BACKBONE.DILATION 22 | ) 23 | model = Joiner(backbone, position_embedding) 24 | model.num_channels = backbone.num_channels 25 | return model -------------------------------------------------------------------------------- /utils/imports.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import torch 3 | 4 | if torch._six.PY3: 5 | import importlib 6 | import importlib.util 7 | import sys 8 | 9 | 10 | # from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa 11 | def import_file(module_name, file_path, make_importable=False): 12 | spec = importlib.util.spec_from_file_location(module_name, file_path) 13 | module = importlib.util.module_from_spec(spec) 14 | spec.loader.exec_module(module) 15 | if make_importable: 16 | sys.modules[module_name] = module 17 | return module 18 | else: 19 | import imp 20 | 21 | def import_file(module_name, file_path, make_importable=None): 22 | module = imp.load_source(module_name, file_path) 23 | return module 24 | -------------------------------------------------------------------------------- /experiments/hcstvg2.yaml: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR: data/hc-stvg2/checkpoints/ 2 | DATA_DIR: data/hc-stvg2/ 3 | TENSORBOARD_DIR: data/hc-stvg2/checkpoints/ 4 | 5 | INPUT: 6 | RESOLUTION: 420 7 | FLIP_PROB_TRAIN: 0.7 8 | TEMP_CROP_PROB: 0.5 9 | SAMPLE_FPS: 3.2 10 | 11 | MODEL: 12 | WEIGHT: model_zoo/pretrained_resnet101_checkpoint.pth 13 | VISION_BACKBONE: 14 | NAME: resnet101 15 | POS_ENC: sine 16 | TEXT_MODEL: 17 | NAME: roberta-base 18 | CG: 19 | FROM_SCRATCH: True 20 | USE_LEARN_TIME_EMBED: False 21 | USE_ACTION: True 22 | TEMP_THETA: 0.5 23 | SPAT_GT_THETA: 0.7 24 | SPAT_THETA: 0.8 25 | 26 | DATASET: 27 | NAME: HC-STVG 28 | 29 | DATALOADER: 30 | NUM_WORKERS: 8 31 | ASPECT_RATIO_GROUPING: False 32 | 33 | SOLVER: 34 | MAX_EPOCH: 90 35 | BATCH_SIZE: 1 36 | BBOX_COEF: 5 37 | GIOU_COEF: 4 38 | TEMP_COEF: 10 39 | ATTN_COEF: 1 40 | CONF_COEF: 1 41 | ACTIONESS_COEF: 2 42 | EOS_COEF: 0.3 43 | SIGMA: 2.0 44 | BASE_LR: 3e-4 45 | TEXT_LR: 5e-5 46 | VIS_BACKBONE_LR: 2e-5 47 | TEMP_LR: 1e-4 48 | OPTIMIZER: adamw 49 | VAL_PERIOD: 500 50 | CHECKPOINT_PERIOD: 500 51 | 
SHUFFLE: True 52 | SCHEDULE: 53 | TYPE: multistep_with_warmup 54 | DROP_STEP: [50, 90] 55 | PRE_VAL: False 56 | -------------------------------------------------------------------------------- /experiments/hcstvg.yaml: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR: data/hc-stvg/checkpoints/ 2 | DATA_DIR: data/hc-stvg/ 3 | TENSORBOARD_DIR: data/hc-stvg/checkpoints/tensorboard/ 4 | 5 | INPUT: 6 | RESOLUTION: 420 7 | FLIP_PROB_TRAIN: 0.7 8 | TEMP_CROP_PROB: 0.5 9 | SAMPLE_FPS: 3.2 10 | 11 | MODEL: 12 | WEIGHT: model_zoo/pretrained_resnet101_checkpoint.pth 13 | VISION_BACKBONE: 14 | NAME: resnet101 15 | POS_ENC: sine 16 | TEXT_MODEL: 17 | NAME: roberta-base 18 | CG: 19 | FROM_SCRATCH: True 20 | USE_LEARN_TIME_EMBED: False 21 | USE_ACTION: True 22 | TEMP_THETA: 0.5 23 | SPAT_GT_THETA: 0.8 24 | SPAT_THETA: 0.8 25 | 26 | DATASET: 27 | NAME: HC-STVG 28 | 29 | DATALOADER: 30 | NUM_WORKERS: 8 31 | ASPECT_RATIO_GROUPING: False 32 | 33 | SOLVER: 34 | MAX_EPOCH: 90 35 | BATCH_SIZE: 1 36 | BBOX_COEF: 5 37 | GIOU_COEF: 4 38 | TEMP_COEF: 10 39 | ATTN_COEF: 1 40 | CONF_COEF: 1 41 | ACTIONESS_COEF: 2 42 | EOS_COEF: 0.3 43 | SIGMA: 2.0 44 | BASE_LR: 3e-4 45 | TEXT_LR: 5e-5 46 | VIS_BACKBONE_LR: 2e-5 47 | TEMP_LR: 1e-4 48 | OPTIMIZER: adamw 49 | VAL_PERIOD: 500 50 | CHECKPOINT_PERIOD: 2000 51 | SHUFFLE: True 52 | SCHEDULE: 53 | TYPE: multistep_with_warmup 54 | DROP_STEP: [50, 90] 55 | PRE_VAL: False 56 | -------------------------------------------------------------------------------- /experiments/vidstg.yaml: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR: data/vidstg/checkpoints/ 2 | DATA_DIR: data/vidstg/ 3 | TENSORBOARD_DIR: data/vidstg/checkpoints/tensorboard/ 4 | 5 | INPUT: 6 | RESOLUTION: 420 7 | FLIP_PROB_TRAIN: 0.7 8 | TEMP_CROP_PROB: 0.5 9 | TRAIN_SAMPLE_NUM: 64 10 | 11 | MODEL: 12 | WEIGHT: model_zoo/pretrained_resnet101_checkpoint.pth 13 | VISION_BACKBONE: 14 | NAME: resnet101 15 | POS_ENC: sine 16 | TEXT_MODEL: 17 | NAME: roberta-base 18 | CG: 19 | FROM_SCRATCH: True 20 | USE_LEARN_TIME_EMBED: False 21 | USE_ACTION: True 22 | TEMP_THETA: -0.5 23 | SPAT_GT_THETA: 0.5 24 | SPAT_THETA: 0.5 25 | 26 | DATASET: 27 | NAME: VidSTG 28 | 29 | DATALOADER: 30 | NUM_WORKERS: 8 31 | ASPECT_RATIO_GROUPING: False 32 | 33 | SOLVER: 34 | MAX_EPOCH: 10 35 | BATCH_SIZE: 1 36 | BBOX_COEF: 5 37 | GIOU_COEF: 3 38 | TEMP_COEF: 1 39 | ATTN_COEF: 1 40 | CONF_COEF: 1 41 | ACTIONESS_COEF: 2 42 | EOS_COEF: 0.3 43 | SIGMA: 2.0 44 | BASE_LR: 3e-4 45 | TEXT_LR: 5e-5 46 | VIS_BACKBONE_LR: 1e-5 47 | TEMP_LR: 1e-4 48 | OPTIMIZER: adamw 49 | VAL_PERIOD: 260000 50 | CHECKPOINT_PERIOD: 500 51 | SHUFFLE: True 52 | SCHEDULE: 53 | TYPE: multistep_with_warmup_all 54 | DROP_STEP: [8,10] 55 | PRE_VAL: False 56 | -------------------------------------------------------------------------------- /datasets/samplers/iteration_based_batch_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
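#
# Minimal usage sketch for the sampler defined in this file (the dataset, batch size
# and iteration count below are illustrative placeholders, not values taken from the
# configs above): the wrapper re-iterates an ordinary BatchSampler until a fixed
# number of batches has been produced, which is how an iteration-based training loop
# is fed.
#
#   from torch.utils.data import DataLoader, RandomSampler, BatchSampler
#   from datasets.collate_batch import collate_fn
#
#   sampler = RandomSampler(dataset)                       # any torch Sampler works here
#   batch_sampler = BatchSampler(sampler, batch_size=1, drop_last=True)
#   batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iterations=10000)
#   loader = DataLoader(dataset, batch_sampler=batch_sampler,
#                       num_workers=8, collate_fn=collate_fn)
#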
2 | from torch.utils.data.sampler import BatchSampler 3 | 4 | 5 | class IterationBasedBatchSampler(BatchSampler): 6 | """ 7 | Wraps a BatchSampler, resampling from it until 8 | a specified number of iterations have been sampled 9 | """ 10 | 11 | def __init__(self, batch_sampler, num_iterations, start_iter=0): 12 | self.batch_sampler = batch_sampler 13 | self.num_iterations = num_iterations 14 | self.start_iter = start_iter 15 | 16 | def __iter__(self): 17 | iteration = self.start_iter 18 | while iteration <= self.num_iterations: 19 | # if the underlying sampler has a set_epoch method, like 20 | # DistributedSampler, used for making each process see 21 | # a different split of the dataset, then set it 22 | if hasattr(self.batch_sampler.sampler, "set_epoch"): 23 | self.batch_sampler.sampler.set_epoch(iteration) 24 | for batch in self.batch_sampler: 25 | iteration += 1 26 | if iteration > self.num_iterations: 27 | break 28 | yield batch 29 | 30 | def __len__(self): 31 | return self.num_iterations 32 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline import CGSTVG 2 | from .criterion import VideoSTGLoss 3 | from .post_processor import PostProcess 4 | 5 | def build_model(cfg): 6 | """ 7 | Prepare the model architecture and 8 | """ 9 | model = CGSTVG(cfg) 10 | 11 | weight_dict = { 12 | "loss_bbox": cfg.SOLVER.BBOX_COEF, 13 | "loss_giou": cfg.SOLVER.GIOU_COEF, 14 | "loss_sted": cfg.SOLVER.TEMP_COEF, 15 | "loss_conf": cfg.SOLVER.CONF_COEF 16 | } 17 | 18 | if cfg.MODEL.CG.USE_ACTION: 19 | weight_dict["loss_actioness"] = cfg.SOLVER.ACTIONESS_COEF 20 | 21 | if cfg.SOLVER.USE_ATTN: 22 | weight_dict["loss_guided_attn"] = cfg.SOLVER.ATTN_COEF 23 | 24 | if cfg.SOLVER.USE_AUX_LOSS: 25 | aux_weight_dict = {} 26 | for i in range(cfg.MODEL.CG.DEC_LAYERS - 1): 27 | aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) 28 | weight_dict.update(aux_weight_dict) 29 | 30 | losses = ["boxes", "sted", "conf"] 31 | if cfg.SOLVER.USE_ATTN: 32 | losses += ["guided_attn"] 33 | if cfg.MODEL.CG.USE_ACTION: 34 | losses += ["actioness"] 35 | 36 | loss_model = VideoSTGLoss(cfg, losses) 37 | 38 | return model, loss_model, weight_dict 39 | 40 | 41 | def build_postprocessors(): 42 | return PostProcess() -------------------------------------------------------------------------------- /models/post_processor.py: -------------------------------------------------------------------------------- 1 | # Adapted from 2 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. 
All Rights Reserved 3 | from faulthandler import dump_traceback 4 | from typing import Dict 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | from utils.box_utils import box_cxcywh_to_xyxy 11 | 12 | 13 | class PostProcess(nn.Module): 14 | """ This module converts the model's output into the format expected by the coco api""" 15 | 16 | @torch.no_grad() 17 | def forward(self, outputs, target_sizes, frames_id, durations): 18 | """Perform the computation for inference evaluation 19 | """ 20 | out_sted, out_bbox = outputs["pred_sted"], outputs["pred_boxes"] 21 | assert len(out_bbox) == len(target_sizes) 22 | 23 | boxes = box_cxcywh_to_xyxy(out_bbox) 24 | img_h, img_w = target_sizes.unbind(1) 25 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) 26 | pred_boxes = boxes * scale_fct 27 | # Avoid x1 y1 < 0 28 | pred_boxes = pred_boxes.clamp(min=0) 29 | 30 | b, t, _ = out_sted.shape 31 | device = out_sted.device 32 | temp_prob_map = torch.zeros(b,t,t).to(device) 33 | inf = -1e32 34 | for i_b in range(len(durations)): 35 | duration = durations[i_b] 36 | sted_prob = (torch.ones(t, t) * inf).tril(0).to(device) 37 | sted_prob[duration:,:] = inf 38 | sted_prob[:,duration:] = inf 39 | temp_prob_map[i_b,:,:] = sted_prob 40 | 41 | temp_prob_map += F.log_softmax(out_sted[:, :, 0], dim=1).unsqueeze(2) + \ 42 | F.log_softmax(out_sted[:, :, 1], dim=1).unsqueeze(1) 43 | 44 | pred_steds = [] 45 | for i_b in range(b): 46 | prob_map = temp_prob_map[i_b] # [T * T] 47 | frame_id_seq = frames_id[i_b] 48 | prob_seq = prob_map.flatten(0) 49 | max_tstamp = prob_seq.max(dim=0)[1].item() 50 | start_idx = max_tstamp // t 51 | end_idx = max_tstamp % t 52 | pred_sted = [frame_id_seq[start_idx], frame_id_seq[end_idx]+1] 53 | pred_steds.append(pred_sted) 54 | 55 | return pred_boxes, pred_steds 56 | 57 | -------------------------------------------------------------------------------- /models/grounding_model/position_encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class SeqEmbeddingLearned(nn.Module): 8 | 9 | def __init__(self, num_pos_feats, d_model=256): 10 | super().__init__() 11 | self.embed = nn.Embedding(num_pos_feats, d_model) 12 | self.reset_parameters() 13 | 14 | def reset_parameters(self): 15 | nn.init.normal_(self.embed.weight) 16 | 17 | def forward(self, ln): 18 | return self.embed.weight[:ln].unsqueeze(1) 19 | 20 | 21 | class SeqEmbeddingSine(nn.Module): 22 | 23 | def __init__(self, max_len=200, d_model=512): 24 | super().__init__() 25 | self.max_len = max_len 26 | position = torch.arange(max_len).unsqueeze(1) 27 | div_term = torch.exp( 28 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) 29 | ) 30 | te = torch.zeros(max_len, 1, d_model) 31 | te[:, 0, 0::2] = torch.sin(position * div_term) 32 | te[:, 0, 1::2] = torch.cos(position * div_term) 33 | self.register_buffer("te", te) 34 | 35 | def forward(self, ln): 36 | pos_t = self.te[:ln] 37 | return pos_t 38 | 39 | 40 | class PositionEmbeddingLearned(nn.Module): 41 | """ 42 | Absolute pos embedding, learned. 
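    Row and column indices (up to 50 each) are embedded separately and concatenated
    along the channel dimension, so the returned positional map has shape
    (B, 2 * num_pos_feats, H, W) and matches the spatial size of the input feature map.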
43 | """ 44 | 45 | def __init__(self, num_pos_feats=256): 46 | super().__init__() 47 | self.row_embed = nn.Embedding(50, num_pos_feats) 48 | self.col_embed = nn.Embedding(50, num_pos_feats) 49 | self.reset_parameters() 50 | 51 | def reset_parameters(self): 52 | nn.init.uniform_(self.row_embed.weight) 53 | nn.init.uniform_(self.col_embed.weight) 54 | 55 | def forward(self, tensor_list): 56 | x = tensor_list.tensors 57 | h, w = x.shape[-2:] 58 | i = torch.arange(w, device=x.device) 59 | j = torch.arange(h, device=x.device) 60 | x_emb = self.col_embed(i) 61 | y_emb = self.row_embed(j) 62 | pos = ( 63 | torch.cat( 64 | [ 65 | x_emb.unsqueeze(0).repeat(h, 1, 1), 66 | y_emb.unsqueeze(1).repeat(1, w, 1), 67 | ], 68 | dim=-1, 69 | ) 70 | .permute(2, 0, 1) 71 | .unsqueeze(0) 72 | .repeat(x.shape[0], 1, 1, 1) 73 | ) 74 | return pos -------------------------------------------------------------------------------- /datasets/gaussion_hm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def gaussian_radius(det_size, min_overlap=0.7): 5 | height, width = det_size 6 | 7 | a1 = 1 8 | b1 = (height + width) 9 | c1 = width * height * (1 - min_overlap) / (1 + min_overlap) 10 | sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1) 11 | r1 = (b1 + sq1) / 2 12 | 13 | a2 = 4 14 | b2 = 2 * (height + width) 15 | c2 = (1 - min_overlap) * width * height 16 | sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2) 17 | r2 = (b2 + sq2) / 2 18 | 19 | a3 = 4 * min_overlap 20 | b3 = -2 * min_overlap * (height + width) 21 | c3 = (min_overlap - 1) * width * height 22 | sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3) 23 | r3 = (b3 + sq3) / 2 24 | return min(r1, r2, r3) 25 | 26 | 27 | def gaussian2D(shape, sigma=1): 28 | m, n = [(ss - 1.) / 2. for ss in shape] 29 | y, x = np.ogrid[-m:m + 1, -n:n + 1] 30 | 31 | h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) 32 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 33 | return h 34 | 35 | 36 | def draw_umich_gaussian(heatmap, center, radius, k=1): 37 | diameter = 2 * radius + 1 38 | gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6) 39 | 40 | x, y = int(center[0]), int(center[1]) 41 | 42 | height, width = heatmap.shape[0:2] 43 | 44 | left, right = min(x, radius), min(width - x, radius + 1) 45 | top, bottom = min(y, radius), min(height - y, radius + 1) 46 | 47 | masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] 48 | masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right] 49 | if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: # TODO debug 50 | np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap) 51 | return heatmap 52 | 53 | 54 | def draw_msra_gaussian(heatmap, center, sigma): 55 | tmp_size = sigma * 3 56 | mu_x = int(center[0] + 0.5) 57 | mu_y = int(center[1] + 0.5) 58 | w, h = heatmap.shape[0], heatmap.shape[1] 59 | ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)] 60 | br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)] 61 | if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0: 62 | return heatmap 63 | size = 2 * tmp_size + 1 64 | x = np.arange(0, size, 1, np.float32) 65 | y = x[:, np.newaxis] 66 | x0 = y0 = size // 2 67 | g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) 68 | g_x = max(0, -ul[0]), min(br[0], h) - ul[0] 69 | g_y = max(0, -ul[1]), min(br[1], w) - ul[1] 70 | img_x = max(0, ul[0]), min(br[0], h) 71 | img_y = max(0, ul[1]), min(br[1], w) 72 | heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]] = np.maximum( 73 | heatmap[img_y[0]:img_y[1], 
img_x[0]:img_x[1]], 74 | g[g_y[0]:g_y[1], g_x[0]:g_x[1]]) 75 | return heatmap 76 | -------------------------------------------------------------------------------- /utils/metric_logger.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, deque 2 | 3 | import torch 4 | import torch.distributed as dist 5 | from .comm import is_dist_avail_and_initialized 6 | 7 | 8 | class SmoothedValue: 9 | """Track a series of values and provide access to smoothed values over a 10 | window or the global series average. 11 | """ 12 | 13 | def __init__(self, window_size=20, fmt=None): 14 | if fmt is None: 15 | fmt = "{median:.4f} ({global_avg:.4f})" 16 | self.deque = deque(maxlen=window_size) 17 | self.total = 0.0 18 | self.count = 1e-12 19 | self.fmt = fmt 20 | 21 | def update(self, value, num=1): 22 | self.deque.append(value) 23 | self.count += num 24 | self.total += value * num 25 | 26 | def synchronize_between_processes(self): 27 | """ 28 | Distributed synchronization of the metric 29 | Warning: does not synchronize the deque! 30 | """ 31 | if not is_dist_avail_and_initialized(): 32 | return 33 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 34 | dist.barrier() 35 | dist.all_reduce(t) 36 | t = t.tolist() 37 | self.count = int(t[0]) 38 | self.total = t[1] 39 | 40 | @property 41 | def median(self): 42 | d = torch.tensor(list(self.deque)) 43 | return d.median().item() 44 | 45 | @property 46 | def avg(self): 47 | d = torch.tensor(list(self.deque), dtype=torch.float32) 48 | return d.mean().item() 49 | 50 | @property 51 | def global_avg(self): 52 | return self.total / self.count 53 | 54 | @property 55 | def max(self): 56 | return max(self.deque) 57 | 58 | @property 59 | def value(self): 60 | return self.deque[-1] 61 | 62 | def __str__(self): 63 | return self.fmt.format( 64 | median=self.median, 65 | avg=self.avg, 66 | global_avg=self.global_avg, 67 | max=self.max, 68 | value=self.value, 69 | ) 70 | 71 | 72 | class MetricLogger(object): 73 | def __init__(self, delimiter="\t"): 74 | self.meters = defaultdict(SmoothedValue) 75 | self.delimiter = delimiter 76 | 77 | def update(self, **kwargs): 78 | for k, v in kwargs.items(): 79 | if isinstance(v, torch.Tensor): 80 | v = v.item() 81 | assert isinstance(v, (float, int)) 82 | self.meters[k].update(v) 83 | 84 | def __getattr__(self, attr): 85 | if attr in self.meters: 86 | return self.meters[attr] 87 | if attr in self.__dict__: 88 | return self.__dict__[attr] 89 | raise AttributeError( 90 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr) 91 | ) 92 | 93 | def __str__(self): 94 | loss_str = [] 95 | for name, meter in self.meters.items(): 96 | loss_str.append("{}: {}".format(name, str(meter))) 97 | return self.delimiter.join(loss_str) 98 | 99 | def synchronize_between_processes(self): 100 | for meter in self.meters.values(): 101 | meter.synchronize_between_processes() 102 | 103 | def add_meter(self, name, meter): 104 | self.meters[name] = meter -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import random 4 | import torch 5 | import numpy as np 6 | import subprocess 7 | from .comm import is_main_process 8 | 9 | 10 | def mkdir(path): 11 | try: 12 | os.makedirs(path) 13 | except OSError as e: 14 | if e.errno != errno.EEXIST: 15 | raise 16 | 17 | 18 | def set_seed(seed): 19 | print("set 
seed ",seed) 20 | random.seed(seed) 21 | np.random.seed(seed) 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed_all(seed) 24 | 25 | 26 | def save_config(cfg, path): 27 | if is_main_process(): 28 | with open(path, 'w') as f: 29 | f.write(cfg.dump()) 30 | 31 | 32 | def to_device(targets, device): 33 | transfer_keys = set(['actioness', 'start_heatmap', 'end_heatmap', 'boxs', 'iou_map', 'candidates']) 34 | for idx in range(len(targets)): 35 | for key in targets[idx].keys(): 36 | if key in transfer_keys: 37 | targets[idx][key] = targets[idx][key].to(device) 38 | return targets 39 | 40 | 41 | class NestedTensor(object): 42 | def __init__(self, tensors, mask, durations): 43 | self.tensors = tensors 44 | self.mask = mask 45 | self.durations = durations 46 | 47 | def to(self, *args, **kwargs): 48 | cast_tensor = self.tensors.to(*args, **kwargs) 49 | cast_mask = self.mask.to(*args, **kwargs) if self.mask is not None else None 50 | return type(self)(cast_tensor, cast_mask, self.durations) 51 | 52 | def decompose(self): 53 | return self.tensors, self.mask, self.durations 54 | 55 | def subsample(self, stride, start_idx=0): 56 | # Subsample the video for multi-modal Interaction 57 | sampled_tensors = [video[start_idx::stride] for video in \ 58 | torch.split(self.tensors, self.durations, dim=0)] 59 | sampled_mask = [mask[start_idx::stride] for mask in \ 60 | torch.split(self.mask, self.durations, dim=0)] 61 | 62 | sampled_durations = [tensor.shape[0] for tensor in sampled_tensors] 63 | 64 | return NestedTensor(torch.cat(sampled_tensors,dim=0), 65 | torch.cat(sampled_mask,dim=0), sampled_durations) 66 | 67 | @classmethod 68 | def from_tensor_list(cls, tensor_list): 69 | assert tensor_list[0].ndim == 4 # videos 70 | max_size = tuple(max(s) for s in zip(*[clip.shape for clip in tensor_list])) 71 | _, c, h, w = max_size 72 | 73 | dtype = tensor_list[0].dtype 74 | device = tensor_list[0].device 75 | 76 | # total number of frames in the batch 77 | durations = [clip.shape[0] for clip in tensor_list] 78 | nb_images = sum(clip.shape[0] for clip in tensor_list) 79 | tensor = torch.zeros((nb_images, c, h, w), dtype=dtype, device=device) 80 | mask = torch.ones((nb_images, h, w), dtype=torch.bool, device=device) 81 | cur_dur = 0 82 | for i_clip, clip in enumerate(tensor_list): 83 | tensor[ 84 | cur_dur : cur_dur + clip.shape[0], 85 | : clip.shape[1], 86 | : clip.shape[2], 87 | : clip.shape[3], 88 | ].copy_(clip) 89 | mask[ 90 | cur_dur : cur_dur + clip.shape[0], : clip.shape[2], : clip.shape[3] 91 | ] = False 92 | cur_dur += clip.shape[0] 93 | 94 | return cls(tensor, mask, durations) 95 | 96 | def __repr__(self): 97 | return repr(self.tensors) -------------------------------------------------------------------------------- /scripts/test_net.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | 7 | from config import cfg 8 | from utils.comm import synchronize, get_rank 9 | from utils.logger import setup_logger 10 | from utils.misc import mkdir, set_seed 11 | from utils.checkpoint import VSTGCheckpointer 12 | from datasets import make_data_loader, build_evaluator, build_dataset 13 | from models import build_model, build_postprocessors 14 | from engine import do_eval 15 | 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser(description="Spatio-Temporal Grounding Training") 19 | parser.add_argument( 20 | "--config-file", 21 | default="experiments/hcstvg.yaml", 22 | 
metavar="FILE", 23 | help="path to config file", 24 | type=str, 25 | ) 26 | parser.add_argument("--local-rank", type=int, default=0) 27 | parser.add_argument("--seed", type=int, default=42) 28 | parser.add_argument( 29 | "--use-seed", 30 | dest="use_seed", 31 | help="If use the random seed", 32 | action="store_true", 33 | ) 34 | parser.add_argument( 35 | "opts", 36 | help="Modify config options using the command-line", 37 | default=None, 38 | nargs=argparse.REMAINDER, 39 | ) 40 | 41 | args = parser.parse_args() 42 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 43 | args.distributed = num_gpus > 1 44 | 45 | if args.distributed: 46 | torch.cuda.set_device(args.local_rank) 47 | torch.distributed.init_process_group( 48 | backend="nccl", init_method="env://" 49 | ) 50 | synchronize() 51 | 52 | if args.config_file: 53 | cfg.merge_from_file(args.config_file) 54 | 55 | cfg.merge_from_list(args.opts) 56 | cfg.freeze() 57 | 58 | if args.use_seed: 59 | cudnn.benchmark = False 60 | cudnn.deterministic = True 61 | set_seed(args.seed + get_rank()) 62 | 63 | output_dir = cfg.OUTPUT_DIR 64 | if output_dir: 65 | mkdir(output_dir) 66 | 67 | logger = setup_logger("Video Grounding", output_dir, get_rank()) 68 | logger.info("Using {} GPUs".format(num_gpus)) 69 | logger.info(cfg) 70 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 71 | 72 | model, _, _ = build_model(cfg) 73 | device = torch.device(cfg.MODEL.DEVICE) 74 | model.to(device) 75 | 76 | checkpointer = VSTGCheckpointer(cfg, model, logger=logger, is_train=False) 77 | _ = checkpointer.load(cfg.MODEL.WEIGHT, with_optim=False) 78 | 79 | # Prepare the dataset cache 80 | if args.local_rank == 0: 81 | _ = build_dataset(cfg, split='test', transforms=None) 82 | 83 | synchronize() 84 | 85 | test_data_loader = make_data_loader( 86 | cfg, 87 | mode='test', 88 | is_distributed=args.distributed, 89 | ) 90 | 91 | logger.info("Start Testing") 92 | evaluator = build_evaluator(cfg, logger, mode='test') # mode = ['val','test'] 93 | postprocessor = build_postprocessors() 94 | do_eval( 95 | cfg, 96 | mode='test', 97 | logger=logger, 98 | model=model, 99 | postprocessor=postprocessor, 100 | data_loader=test_data_loader, 101 | evaluator=evaluator, 102 | device=device 103 | ) 104 | synchronize() 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /models/language_model/lstm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import torchtext 7 | 8 | 9 | class RNNEncoder(nn.Module): 10 | def __init__(self, vocab_dir, hidden_size, bidirectional=False, 11 | dropout_p=0, n_layers=1, rnn_type='lstm'): 12 | super(RNNEncoder, self).__init__() 13 | 14 | vocab = load_vocab(vocab_dir) 15 | self.embedding = nn.Embedding.from_pretrained(vocab.vectors,freeze=True) # Froezen the embedding weight 16 | word_embed_size = vocab.vectors.shape[1] 17 | 18 | self.rnn_type = rnn_type 19 | self.rnn = getattr(nn, rnn_type.upper())(word_embed_size, hidden_size, n_layers, 20 | batch_first=True, 21 | bidirectional=bidirectional, 22 | dropout=dropout_p) 23 | self.num_dirs = 2 if bidirectional else 1 24 | self.variable_lengths = True 25 | 26 | def forward(self, text_data): 27 | """ 28 | Inputs: 29 | - input word_idx (batch, seq_len) 30 | Outputs: 31 | - output : Variable float (batch, max_len, hidden_size * num_dirs) 32 | - hidden 
: Variable float (batch, num_layers * num_dirs * hidden_size) 33 | - embedded: Variable float (batch, max_len, word_vec_size) 34 | """ 35 | text_tensors = text_data.tensors 36 | text_masks = text_data.mask 37 | 38 | input_lengths = (text_masks != 0).sum(1) # Variable (batch, ) 39 | input_lengths_list = input_lengths.data.cpu().numpy().tolist() 40 | 41 | sorted_input_lengths_list = np.sort(input_lengths_list)[::-1].tolist() # list of sorted input_lengths 42 | sort_ixs = np.argsort(input_lengths_list)[::-1].tolist() # list of int sort_ixs, descending 43 | s2r = {s: r for r, s in enumerate(sort_ixs)} # O(n) 44 | recover_ixs = [s2r[s] for s in range(len(input_lengths_list))] # list of int recover ixs 45 | 46 | # move to long tensor 47 | sort_ixs = text_masks.data.new(sort_ixs).long() # Variable long 48 | recover_ixs = text_masks.data.new(recover_ixs).long() # Variable long 49 | 50 | # sort input_labels by descending order 51 | text_tensors = text_tensors[sort_ixs] 52 | text_masks = text_masks[sort_ixs] 53 | 54 | # embed 55 | embedded = self.embedding(text_tensors) # (n, seq_len, word_embedding_size) 56 | embedded = nn.utils.rnn.pack_padded_sequence(embedded, sorted_input_lengths_list, batch_first=True) 57 | # forward rnn 58 | output, hidden = self.rnn(embedded) 59 | 60 | # recover embedded 61 | embedded, _ = nn.utils.rnn.pad_packed_sequence(embedded, batch_first=True) 62 | embedded = embedded[recover_ixs] 63 | 64 | # recover rnn 65 | output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True) # (batch, max_len, hidden) 66 | output = output[recover_ixs] 67 | 68 | sent_output = [] 69 | for ii in range(output.shape[0]): 70 | sent_output.append(output[ii,int(input_lengths_list[ii]-1),:]) 71 | return torch.stack(sent_output, dim=0) 72 | 73 | 74 | def load_vocab(vocab_dir): 75 | vocab_pth = os.path.join(vocab_dir,'vocab.pth') 76 | if not os.path.exists(vocab_pth): 77 | vocab = torchtext.vocab.pretrained_aliases["glove.6B.300d"](cache=vocab_dir) 78 | vocab.itos.extend(['']) 79 | vocab.stoi[''] = vocab.vectors.shape[0] 80 | vocab.vectors = torch.cat([vocab.vectors, torch.zeros(1, vocab.dim)], dim=0) 81 | torch.save(vocab,vocab_pth) 82 | else: 83 | vocab = torch.load(vocab_pth) 84 | 85 | return vocab -------------------------------------------------------------------------------- /models/language_model/bert.py: -------------------------------------------------------------------------------- 1 | from cgitb import text 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from torch import nn 6 | from utils.video_list import NestedTensor 7 | 8 | from pytorch_pretrained_bert.modeling import BertModel 9 | from transformers import RobertaModel, RobertaTokenizerFast 10 | 11 | 12 | class BERT(nn.Module): 13 | def __init__(self, name: str, train_bert: bool, enc_num, pretrain_weight): 14 | super().__init__() 15 | if name == 'bert-base-uncased': 16 | self.num_channels = 768 17 | else: 18 | self.num_channels = 1024 19 | 20 | self.enc_num = enc_num 21 | self.bert = BertModel.from_pretrained(pretrain_weight) 22 | 23 | if not train_bert: 24 | for parameter in self.bert.parameters(): 25 | parameter.requires_grad_(False) 26 | 27 | def forward(self, tensor_list: NestedTensor): 28 | if self.enc_num > 0: 29 | all_encoder_layers, _ = self.bert(tensor_list.tensors, token_type_ids=None, attention_mask=tensor_list.mask) 30 | # use the output of the X-th transformer encoder layers 31 | xs = all_encoder_layers[self.enc_num - 1] 32 | else: 33 | xs = self.bert.embeddings.word_embeddings(tensor_list.tensors) 34 
| 35 | mask = tensor_list.mask.to(torch.bool) 36 | mask = ~mask 37 | out = NestedTensor(xs, mask) 38 | 39 | return out 40 | 41 | 42 | class Roberta(nn.Module): 43 | def __init__(self, name, outdim, freeze=False) -> None: 44 | super().__init__() 45 | self.body = RobertaModel.from_pretrained("model_zoo/roberta-base/") 46 | self.tokenizer = RobertaTokenizerFast.from_pretrained(pretrained_model_name_or_path='model_zoo/roberta/') 47 | 48 | if freeze: 49 | for p in self.body.parameters(): 50 | p.requires_grad_(False) 51 | 52 | config = self.body.config 53 | self.resizer = FeatureResizer( 54 | input_feat_size=config.hidden_size, 55 | output_feat_size=outdim, 56 | dropout=0.1, 57 | ) 58 | 59 | def forward(self, texts, device): 60 | tokenized = self.tokenizer.batch_encode_plus(texts, 61 | padding="longest", return_tensors="pt").to(device) 62 | encoded_text = self.body(**tokenized) 63 | text_cls = encoded_text.pooler_output 64 | 65 | # Transpose memory because pytorch's attention expects sequence first 66 | text_memory = encoded_text.last_hidden_state.transpose(0, 1) 67 | # Invert attention mask that we get from huggingface because its the opposite in pytorch transformer 68 | text_attention_mask = tokenized.attention_mask.ne(1).bool() 69 | 70 | # Resize the encoder hidden states to be of the same d_model as the decoder 71 | text_memory_resized = self.resizer(text_memory) 72 | text_cls_resized = self.resizer(text_cls) 73 | 74 | return (text_attention_mask, text_memory_resized, tokenized), text_cls_resized 75 | 76 | 77 | class FeatureResizer(nn.Module): 78 | """ 79 | This class takes as input a set of embeddings of dimension C1 and outputs a set of 80 | embedding of dimension C2, after a linear transformation, dropout and normalization (LN). 81 | """ 82 | 83 | def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True): 84 | super().__init__() 85 | self.do_ln = do_ln 86 | # Object feature encoding 87 | self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True) 88 | self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12) 89 | self.dropout = nn.Dropout(dropout) 90 | 91 | def forward(self, encoder_features): 92 | x = self.fc(encoder_features) 93 | if self.do_ln: 94 | x = self.layer_norm(x) 95 | output = self.dropout(x) 96 | return output -------------------------------------------------------------------------------- /engine/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .lr_scheduler import WarmupMultiStepLR, WarmupReduceLROnPlateau, WarmupPolyLR 3 | 4 | 5 | def update_ema(model, model_ema, decay): 6 | """Apply exponential moving average update. 
7 | 8 | The weights are updated in-place as follow: 9 | w_ema = w_ema * decay + (1 - decay) * w 10 | Args: 11 | model: active model that is being optimized 12 | model_ema: running average model 13 | decay: exponential decay parameter 14 | """ 15 | with torch.no_grad(): 16 | if hasattr(model, "module"): 17 | # unwrapping DDP 18 | model = model.module 19 | msd = model.state_dict() 20 | for k, ema_v in model_ema.state_dict().items(): 21 | model_v = msd[k].detach() 22 | ema_v.copy_(ema_v * decay + (1.0 - decay) * model_v) 23 | 24 | 25 | def make_optimizer(cfg, model, logger): 26 | vis_enc_param = [p for n, p in model.named_parameters() \ 27 | if (("vis_encoder" in n) and p.requires_grad)] 28 | text_enc_param = [p for n, p in model.named_parameters() \ 29 | if (("text_encoder" in n) and p.requires_grad)] 30 | temp_dec_param = [p for n, p in model.named_parameters() \ 31 | if (("ground_decoder.time_decoder" in n) and p.requires_grad)] 32 | rest_param = [p for n, p in model.named_parameters() if(('vis_encoder' not in n) and \ 33 | ('text_encoder' not in n) and ("ground_decoder.time_decoder" not in n) and p.requires_grad)] 34 | 35 | base_lr = cfg.SOLVER.BASE_LR 36 | optim_type = cfg.SOLVER.OPTIMIZER 37 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 38 | 39 | param_list = [ 40 | {"params" : rest_param}, 41 | {"params" : vis_enc_param, "lr" : cfg.SOLVER.VIS_BACKBONE_LR}, 42 | {"params" : text_enc_param, "lr" : cfg.SOLVER.TEXT_LR}, 43 | {"params" : temp_dec_param, "lr" : cfg.SOLVER.TEMP_LR}, 44 | ] 45 | 46 | # using RMSProp or AdamW 47 | if optim_type == 'rmsprop': 48 | optimizer = torch.optim.RMSprop(param_list, lr=base_lr, weight_decay=weight_decay) 49 | elif optim_type == 'adamw': 50 | optimizer = torch.optim.AdamW(param_list, lr=base_lr, weight_decay=weight_decay) 51 | elif optim_type == 'adam': 52 | optimizer = torch.optim.Adam(param_list, lr=base_lr, weight_decay=weight_decay) 53 | elif optim_type== 'sgd': 54 | optimizer = torch.optim.SGD(param_list, lr=base_lr, weight_decay=weight_decay, momentum=cfg.SOLVER.MOMENTUM) 55 | else: 56 | raise ValueError('Lr scheduler type not supportted ') 57 | 58 | return optimizer 59 | 60 | 61 | def make_lr_scheduler(cfg, optimizer, logger=None): 62 | if cfg.SOLVER.SCHEDULE.TYPE == "WarmupMultiStepLR": 63 | return WarmupMultiStepLR( 64 | optimizer, 65 | cfg.SOLVER.STEPS, 66 | cfg.SOLVER.GAMMA, 67 | warmup_factor=cfg.SOLVER.WARMUP_FACTOR, 68 | warmup_iters=cfg.SOLVER.WARMUP_ITERS, 69 | warmup_method=cfg.SOLVER.WARMUP_METHOD, 70 | ) 71 | 72 | elif cfg.SOLVER.SCHEDULE.TYPE == "WarmupReduceLROnPlateau": 73 | return WarmupReduceLROnPlateau( 74 | optimizer, 75 | cfg.SOLVER.SCHEDULE.FACTOR, 76 | warmup_factor=cfg.SOLVER.WARMUP_FACTOR, 77 | warmup_iters=cfg.SOLVER.WARMUP_ITERS, 78 | warmup_method=cfg.SOLVER.WARMUP_METHOD, 79 | patience=cfg.SOLVER.SCHEDULE.PATIENCE, 80 | threshold=cfg.SOLVER.SCHEDULE.THRESHOLD, 81 | cooldown=cfg.SOLVER.SCHEDULE.COOLDOWN, 82 | logger=logger, 83 | ) 84 | elif cfg.SOLVER.SCHEDULE.TYPE == "WarmupPolyLR": 85 | return WarmupPolyLR( 86 | optimizer, 87 | cfg.SOLVER.POWER, 88 | cfg.SOLVER.MAX_ITER, 89 | warmup_factor=cfg.SOLVER.WARMUP_FACTOR, 90 | warmup_iters=cfg.SOLVER.WARMUP_ITERS, 91 | warmup_method=cfg.SOLVER.WARMUP_METHOD, 92 | ) 93 | else: 94 | raise ValueError("Invalid Schedule Type") 95 | -------------------------------------------------------------------------------- /models/net_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | import 
torch.nn.functional as F 5 | from torchvision.ops.boxes import box_area 6 | 7 | 8 | class MLP(nn.Module): 9 | """Very simple multi-layer perceptron (also called FFN)""" 10 | 11 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout=0): 12 | super().__init__() 13 | self.num_layers = num_layers 14 | h = [hidden_dim] * (num_layers - 1) 15 | self.layers = nn.ModuleList( 16 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 17 | ) 18 | self.dropout = dropout 19 | if dropout: 20 | self.dropout = nn.Dropout(dropout) 21 | 22 | def forward(self, x): 23 | for i, layer in enumerate(self.layers): 24 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 25 | if self.dropout and i < self.num_layers: 26 | x = self.dropout(x) 27 | return x 28 | 29 | 30 | def gen_sineembed_for_position(pos_tensor): 31 | """ 32 | pos_tensor : [num_queries, batch_size, 4] 33 | """ 34 | scale = 2 * math.pi 35 | dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device) 36 | dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / 128) 37 | x_embed = pos_tensor[:, :, 0] * scale 38 | y_embed = pos_tensor[:, :, 1] * scale 39 | pos_x = x_embed[:, :, None] / dim_t 40 | pos_y = y_embed[:, :, None] / dim_t 41 | pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) 42 | pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) 43 | if pos_tensor.size(-1) == 2: 44 | pos = torch.cat((pos_y, pos_x), dim=2) 45 | elif pos_tensor.size(-1) == 4: 46 | w_embed = pos_tensor[:, :, 2] * scale 47 | pos_w = w_embed[:, :, None] / dim_t 48 | pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) 49 | 50 | h_embed = pos_tensor[:, :, 3] * scale 51 | pos_h = h_embed[:, :, None] / dim_t 52 | pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) 53 | 54 | pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) 55 | else: 56 | raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1))) 57 | return pos 58 | 59 | 60 | def inverse_sigmoid(x, eps=1e-3): 61 | x = x.clamp(min=0, max=1) 62 | x1 = x.clamp(min=eps) 63 | x2 = (1 - x).clamp(min=eps) 64 | return torch.log(x1/x2) 65 | 66 | def greater_than_indices(tensor, n): 67 | indices = torch.nonzero(tensor > n, as_tuple=False) 68 | return indices 69 | 70 | def topk_index(tensor, conf, top): 71 | tensor = tensor[tensor >= conf] 72 | if len(tensor) <= top: 73 | return torch.argsort(tensor, descending=True) 74 | return torch.argsort(tensor, descending=True)[:top] 75 | 76 | def box_cxcywh_to_xyxy(x): 77 | x_c, y_c, w, h = x.unbind(-1) 78 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] 79 | return torch.stack(b, dim=-1) 80 | 81 | def box_iou(boxes1, boxes2): 82 | area1 = box_area(boxes1) 83 | area2 = box_area(boxes2) 84 | 85 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 86 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 87 | 88 | wh = (rb - lt).clamp(min=0) # [N,M,2] 89 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 90 | 91 | union = area1[:, None] + area2 - inter 92 | 93 | iou = inter / union 94 | return iou, union 95 | 96 | def generalized_box_iou(boxes1, boxes2): 97 | """ 98 | Generalized IoU from https://giou.stanford.edu/ 99 | 100 | The boxes should be in [x0, y0, x1, y1] format 101 | 102 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 103 | and M = len(boxes2) 104 | """ 105 | # degenerate boxes gives inf / nan results 106 
| # so do an early check 107 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 108 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 109 | iou, union = box_iou(boxes1, boxes2) 110 | 111 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 112 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 113 | 114 | wh = (rb - lt).clamp(min=0) # [N,M,2] 115 | area = wh[:, :, 0] * wh[:, :, 1] 116 | 117 | return iou - (area - union) / area -------------------------------------------------------------------------------- /utils/video_list.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | 5 | 6 | class NestedTensor(object): 7 | def __init__(self, tensors, mask): 8 | self.tensors = tensors 9 | self.mask = mask 10 | 11 | def to(self, device): 12 | cast_tensor = self.tensors.to(device) 13 | mask = self.mask 14 | if mask is not None: 15 | assert mask is not None 16 | cast_mask = mask.to(device) 17 | else: 18 | cast_mask = None 19 | return NestedTensor(cast_tensor, cast_mask) 20 | 21 | def decompose(self): 22 | return self.tensors, self.mask 23 | 24 | def __repr__(self): 25 | return 'NestedTensor{}'.format(self.tensors.shape) 26 | 27 | 28 | class TargetTensor(object): 29 | 30 | def __init__(self, target : dict) -> None: 31 | self.spatial_hm = target['spatial_heatmap'] 32 | self.wh = target['wh'] 33 | self.offset = target['offset'] 34 | self.actioness = target['actioness'] 35 | self.start_hm = target['start_heatmap'] 36 | self.end_hm = target['end_heatmap'] 37 | self.target = target 38 | 39 | def to(self,device): 40 | return TargetTensor( 41 | { 42 | 'spatial_heatmap' : self.spatial_hm.to(device), 43 | 'wh' : self.wh.to(device), 44 | 'offset' : self.offset.to(device), 45 | 'actioness' : self.actioness.to(device), 46 | 'start_heatmap' : self.spatial_hm.to(device), 47 | 'end_heatmap' : self.end_hm.to(device) 48 | } 49 | ) 50 | 51 | 52 | class VideoList(object): 53 | """ 54 | Structure that holds a list of videos (of possibly 55 | varying sizes) as a single tensor. 56 | This works by padding the images to the same size, 57 | and storing in a field the original sizes of each video 58 | """ 59 | 60 | def __init__(self, tensors, video_sizes): 61 | """ 62 | Arguments: 63 | tensors (tensor) 64 | video_sizes (list[tuple[int, int]]) 65 | """ 66 | self.tensors = tensors 67 | self.video_sizes = video_sizes 68 | 69 | def to(self, *args, **kwargs): 70 | cast_tensor = self.tensors.to(*args, **kwargs) 71 | return VideoList(cast_tensor, self.video_sizes) 72 | 73 | 74 | def to_video_list(tensors, size_divisible=0): 75 | """ 76 | tensors can be an VideoList, a torch.Tensor or 77 | an iterable of Tensors. It can't be a numpy array. 
78 | When tensors is an iterable of Tensors, it pads 79 | the Tensors with zeros so that they have the same 80 | shape 81 | """ 82 | 83 | if isinstance(tensors, torch.Tensor) and size_divisible > 0: 84 | tensors = [tensors] 85 | 86 | if isinstance(tensors, VideoList): 87 | return tensors 88 | elif isinstance(tensors, torch.Tensor): 89 | # single tensor shape can be inferred 90 | if tensors.dim() == 4: # T * C * H * W 91 | tensors = tensors[None] 92 | assert tensors.dim() == 5 93 | image_sizes = [tensor.shape[-2:] for tensor in tensors] 94 | return VideoList(tensors, image_sizes) 95 | 96 | elif isinstance(tensors, (tuple, list)): 97 | max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors])) 98 | 99 | # TODO Ideally, just remove this and let me model handle arbitrary 100 | # input sizs 101 | if size_divisible > 0: 102 | import math 103 | 104 | stride = size_divisible 105 | max_size = list(max_size) 106 | max_size[2] = int(math.ceil(max_size[2] / stride) * stride) 107 | max_size[3] = int(math.ceil(max_size[3] / stride) * stride) 108 | max_size = tuple(max_size) 109 | 110 | batch_shape = (len(tensors),) + max_size 111 | batched_imgs = tensors[0].new(*batch_shape).zero_() 112 | for img, pad_img in zip(tensors, batched_imgs): 113 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2], : img.shape[3]].copy_(img) 114 | 115 | image_sizes = [im.shape[-2:] for im in tensors] 116 | 117 | return VideoList(batched_imgs, image_sizes) 118 | else: 119 | raise TypeError("Unsupported type for to_video_list: {}".format(type(tensors))) 120 | -------------------------------------------------------------------------------- /models/pipeline.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .net_utils import MLP 3 | from .vision_model import build_vis_encoder 4 | from .language_model import build_text_encoder 5 | from .grounding_model import build_encoder, build_decoder 6 | from utils.misc import NestedTensor 7 | from .vidswin.video_swin_transformer import vidswin_model 8 | 9 | class CGSTVG(nn.Module): 10 | def __init__(self, cfg): 11 | super(CGSTVG, self).__init__() 12 | self.cfg = cfg.clone() 13 | self.max_video_len = cfg.INPUT.MAX_VIDEO_LEN 14 | self.use_attn = cfg.SOLVER.USE_ATTN 15 | 16 | self.use_aux_loss = cfg.SOLVER.USE_AUX_LOSS # use the output of each transformer layer 17 | self.use_actioness = cfg.MODEL.CG.USE_ACTION 18 | self.query_dim = cfg.MODEL.CG.QUERY_DIM 19 | 20 | self.vis_encoder = build_vis_encoder(cfg) 21 | vis_fea_dim = self.vis_encoder.num_channels 22 | 23 | self.text_encoder = build_text_encoder(cfg) 24 | 25 | self.ground_encoder = build_encoder(cfg) 26 | self.ground_decoder = build_decoder(cfg) 27 | 28 | hidden_dim = cfg.MODEL.CG.HIDDEN 29 | self.input_proj = nn.Conv2d(vis_fea_dim, hidden_dim, kernel_size=1) 30 | self.temp_embed = MLP(hidden_dim, hidden_dim, 2, 2, dropout=0.3) 31 | self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) 32 | 33 | self.vid = vidswin_model("video_swin_t_p4w7", "video_swin_t_p4w7_k400_1k") 34 | self.input_proj2 = nn.Conv2d(768, hidden_dim, kernel_size=1) 35 | for param in self.vid.parameters(): 36 | param.requires_grad = False 37 | 38 | self.action_embed = None 39 | if self.use_actioness: 40 | self.action_embed = MLP(hidden_dim, hidden_dim, 1, 2, dropout=0.3) 41 | 42 | self.ground_decoder.time_embed2 = self.action_embed 43 | 44 | # add the iteration anchor update 45 | self.ground_decoder.decoder.bbox_embed = self.bbox_embed 46 | 47 | def forward(self, videos, texts, targets, 
iteration_rate=-1): 48 | # Visual Feature 49 | vis_outputs, vis_pos_embed = self.vis_encoder(videos) 50 | vis_features, vis_mask, vis_durations = vis_outputs.decompose() 51 | vis_features = self.input_proj(vis_features) 52 | vis_outputs = NestedTensor(vis_features, vis_mask, vis_durations) 53 | 54 | vid_features = self.vid(videos.tensors, len(videos.tensors)) 55 | vid_features = self.input_proj2(vid_features['3']) 56 | 57 | # Textual Feature 58 | device = vis_features.device 59 | text_outputs, _ = self.text_encoder(texts, device) 60 | 61 | # Multimodal Feature Encoding 62 | encoded_info = self.ground_encoder(videos=vis_outputs, vis_pos=vis_pos_embed, texts=text_outputs, vid_features=vid_features) 63 | encoded_info["iteration_rate"] = iteration_rate 64 | encoded_info["videos"] = videos 65 | # Query-based Decoding 66 | outputs_pos, outputs_time = self.ground_decoder(encoded_info=encoded_info, vis_pos=vis_pos_embed, targets=targets) 67 | 68 | out = {} 69 | 70 | # the final decoder embeddings and the refer anchors 71 | ############### predict bounding box ############### 72 | refer_anchors, anchors_conf, fake_anchors = outputs_pos 73 | outputs_coord = refer_anchors.flatten(1,2) # [num_layers, T, 4] 74 | out.update({"pred_boxes": outputs_coord[-1]}) 75 | out.update({"boxes_conf": anchors_conf[-1]}) 76 | out.update({"fake_boxes": fake_anchors}) 77 | ###################################################### 78 | 79 | ####### predict the start and end probability ####### 80 | time_hiden_state = outputs_time 81 | outputs_time = self.temp_embed(time_hiden_state) # [num_layers, b, T, 2] 82 | out.update({"pred_sted": outputs_time[-1]}) 83 | ####################################################### 84 | 85 | if self.use_actioness: 86 | outputs_actioness = self.action_embed(time_hiden_state) # [num_layers, b, T, 1] 87 | out.update({"pred_actioness": outputs_actioness[-1]}) 88 | 89 | if self.use_aux_loss: 90 | out["aux_outputs"] = [ 91 | { 92 | "pred_sted": a, 93 | "pred_boxes": b, 94 | "boxes_conf": c 95 | } 96 | for a, b, c in zip(outputs_time[:-1], outputs_coord[:-1], anchors_conf[:-1]) 97 | ] 98 | for i_aux in range(len(out["aux_outputs"])): 99 | if self.use_actioness: 100 | out["aux_outputs"][i_aux]["pred_actioness"] = outputs_actioness[i_aux] 101 | 102 | return out -------------------------------------------------------------------------------- /utils/model_serialization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from collections import OrderedDict 3 | import logging 4 | 5 | import torch 6 | 7 | 8 | def align_and_update_state_dicts(model_state_dict, loaded_state_dict, load_mapping, logger): 9 | """ 10 | Strategy: suppose that the models that we will create will have prefixes appended 11 | to each of its keys, for example due to an extra level of nesting that the original 12 | pre-trained weights from ImageNet won't contain. For example, model.state_dict() 13 | might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains 14 | res2.conv1.weight. We thus want to match both parameters together. 15 | For that, we look for each model weight, look among all loaded keys if there is one 16 | that is a suffix of the current weight name, and use it if that's the case. 17 | If multiple matches exist, take the one with longest size 18 | of the corresponding name. 
For example, for the same model as before, the pretrained 19 | weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case, 20 | we want to match backbone[0].body.conv1.weight to conv1.weight, and 21 | backbone[0].body.res2.conv1.weight to res2.conv1.weight. 22 | """ 23 | current_keys = sorted(list(model_state_dict.keys())) 24 | loaded_keys = sorted(list(loaded_state_dict.keys())) 25 | # get a matrix of string matches, where each (i, j) entry correspond to the size of the 26 | # loaded_key string, if it matches 27 | # NOTE: Kaihua Tang, since some modules of current model will be initialized from assigned layer of 28 | # loaded model, we use load_mapping to do such operation 29 | mapped_current_keys = current_keys.copy() 30 | for i, key in enumerate(mapped_current_keys): 31 | for source_key, target_key in load_mapping.items(): 32 | if source_key in key: 33 | mapped_current_keys[i] = key.replace(source_key, target_key) 34 | logger.info("MAPPING {} in current model to {} in loaded model.".format(key, mapped_current_keys[i])) 35 | 36 | match_matrix = [ 37 | len(j) if i.endswith(j) else 0 for i in mapped_current_keys for j in loaded_keys 38 | ] 39 | match_matrix = torch.as_tensor(match_matrix).view( 40 | len(current_keys), len(loaded_keys) 41 | ) 42 | 43 | max_match_size, idxs = match_matrix.max(1) 44 | # remove indices that correspond to no-match 45 | idxs[max_match_size == 0] = -1 46 | 47 | # used for logging 48 | max_size = max([len(key) for key in current_keys]) if current_keys else 1 49 | max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1 50 | log_str_template = "REMATCHING! {: <{}} loaded from {: <{}} of shape {}" 51 | 52 | for idx_new, idx_old in enumerate(idxs.tolist()): 53 | if idx_old == -1: 54 | key = current_keys[idx_new] 55 | logger.info("NO-MATCHING of current module: {} of shape {}".format(key, 56 | tuple(model_state_dict[key].shape))) 57 | continue 58 | key = current_keys[idx_new] 59 | key_old = loaded_keys[idx_old] 60 | model_state_dict[key] = loaded_state_dict[key_old] 61 | # add a control gate for this logger (it's too large) 62 | 63 | if ((not key.startswith('module.')) and key != key_old) or (key.startswith('module.') and key[7:] != key_old): 64 | logger.info( 65 | log_str_template.format( 66 | key, 67 | max_size, 68 | key_old, 69 | max_size_loaded, 70 | tuple(loaded_state_dict[key_old].shape), 71 | ) 72 | ) 73 | 74 | 75 | def strip_prefix_if_present(state_dict, prefix): 76 | keys = sorted(state_dict.keys()) 77 | if not all(key.startswith(prefix) for key in keys): 78 | return state_dict 79 | stripped_state_dict = OrderedDict() 80 | for key, value in state_dict.items(): 81 | stripped_state_dict[key.replace(prefix, "")] = value 82 | return stripped_state_dict 83 | 84 | 85 | def load_state_dict(model, loaded_state_dict, load_mapping, logger): 86 | model_state_dict = model.state_dict() 87 | # if the state_dict comes from a model that was wrapped in a 88 | # DataParallel or DistributedDataParallel during serialization, 89 | # remove the "module" prefix before performing the matching 90 | loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.") 91 | align_and_update_state_dicts(model_state_dict, loaded_state_dict, load_mapping, logger) 92 | # use strict loading 93 | model.load_state_dict(model_state_dict) 94 | -------------------------------------------------------------------------------- /utils/box_utils.py: -------------------------------------------------------------------------------- 1 | # 
Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import torch 4 | import numpy as np 5 | from torchvision.ops.boxes import box_area 6 | from typing import Tuple 7 | 8 | 9 | #### Bounding box utilities imported from torchvision and converted to numpy 10 | def np_box_area(boxes: np.array) -> np.array: 11 | """ 12 | Computes the area of a set of bounding boxes, which are specified by its 13 | (x1, y1, x2, y2) coordinates. 14 | 15 | Args: 16 | boxes (Tensor[N, 4]): boxes for which the area will be computed. They 17 | are expected to be in (x1, y1, x2, y2) format with 18 | ``0 <= x1 < x2`` and ``0 <= y1 < y2``. 19 | 20 | Returns: 21 | area (Tensor[N]): area for each box 22 | """ 23 | assert boxes.ndim == 2 and boxes.shape[-1] == 4 24 | return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 25 | 26 | 27 | # implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py 28 | # with slight modifications 29 | def _box_inter_union(boxes1: np.array, boxes2: np.array) -> Tuple[np.array, np.array]: 30 | area1 = np_box_area(boxes1) 31 | area2 = np_box_area(boxes2) 32 | 33 | lt = np.maximum(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 34 | rb = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 35 | 36 | wh = (rb - lt).clip(min=0) # [N,M,2] 37 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 38 | 39 | union = area1[:, None] + area2 - inter 40 | 41 | return inter, union 42 | 43 | 44 | def np_box_iou(boxes1: np.array, boxes2: np.array) -> np.array: 45 | """ 46 | Return intersection-over-union (Jaccard index) of boxes. 47 | 48 | Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with 49 | ``0 <= x1 < x2`` and ``0 <= y1 < y2``. 50 | 51 | Args: 52 | boxes1 (Tensor[N, 4]) 53 | boxes2 (Tensor[M, 4]) 54 | 55 | Returns: 56 | iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 57 | """ 58 | inter, union = _box_inter_union(boxes1, boxes2) 59 | iou = inter / union 60 | return iou 61 | 62 | 63 | def box_cxcywh_to_xyxy(x): 64 | x_c, y_c, w, h = x.unbind(-1) 65 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] 66 | return torch.stack(b, dim=-1) 67 | 68 | 69 | def box_xyxy_to_cxcywh(x): 70 | x0, y0, x1, y1 = x.unbind(-1) 71 | b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] 72 | return torch.stack(b, dim=-1) 73 | 74 | 75 | # modified from torchvision to also return the union 76 | def box_iou(boxes1, boxes2): 77 | area1 = box_area(boxes1) 78 | area2 = box_area(boxes2) 79 | 80 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 81 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 82 | 83 | wh = (rb - lt).clamp(min=0) # [N,M,2] 84 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 85 | 86 | union = area1[:, None] + area2 - inter 87 | 88 | iou = inter / union 89 | return iou, union 90 | 91 | 92 | def generalized_box_iou(boxes1, boxes2): 93 | """ 94 | Generalized IoU from https://giou.stanford.edu/ 95 | 96 | The boxes should be in [x0, y0, x1, y1] format 97 | 98 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 99 | and M = len(boxes2) 100 | """ 101 | # degenerate boxes gives inf / nan results 102 | # so do an early check 103 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 104 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 105 | iou, union = box_iou(boxes1, boxes2) 106 | 107 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 108 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 109 | 110 | wh = (rb - lt).clamp(min=0) # 
[N,M,2] 111 | area = wh[:, :, 0] * wh[:, :, 1] 112 | 113 | return iou - (area - union) / area 114 | 115 | 116 | def masks_to_boxes(masks): 117 | """Compute the bounding boxes around the provided masks 118 | 119 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 120 | 121 | Returns a [N, 4] tensors, with the boxes in xyxy format 122 | """ 123 | if masks.numel() == 0: 124 | return torch.zeros((0, 4), device=masks.device) 125 | 126 | h, w = masks.shape[-2:] 127 | 128 | y = torch.arange(0, h, dtype=torch.float) 129 | x = torch.arange(0, w, dtype=torch.float) 130 | y, x = torch.meshgrid(y, x) 131 | 132 | x_mask = masks * x.unsqueeze(0) 133 | x_max = x_mask.flatten(1).max(-1)[0] 134 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 135 | 136 | y_mask = masks * y.unsqueeze(0) 137 | y_max = y_mask.flatten(1).max(-1)[0] 138 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 139 | 140 | return torch.stack([x_min, y_min, x_max, y_max], 1) 141 | -------------------------------------------------------------------------------- /utils/comm.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import time 3 | 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | def get_world_size(): 9 | if not dist.is_available(): 10 | return 1 11 | if not dist.is_initialized(): 12 | return 1 13 | return dist.get_world_size() 14 | 15 | 16 | def get_rank(): 17 | if not dist.is_available(): 18 | return 0 19 | if not dist.is_initialized(): 20 | return 0 21 | return dist.get_rank() 22 | 23 | 24 | def is_main_process(): 25 | return get_rank() == 0 26 | 27 | 28 | def is_dist_avail_and_initialized(): 29 | """ 30 | Returns: 31 | True if distributed training is enabled 32 | """ 33 | if not dist.is_available(): 34 | return False 35 | if not dist.is_initialized(): 36 | return False 37 | return True 38 | 39 | 40 | def synchronize(): 41 | """ 42 | Helper function to synchronize (barrier) among all processes when 43 | using distributed training 44 | """ 45 | if not dist.is_available(): 46 | return 47 | if not dist.is_initialized(): 48 | return 49 | world_size = dist.get_world_size() 50 | if world_size == 1: 51 | return 52 | dist.barrier() 53 | 54 | 55 | def all_gather(data): 56 | """ 57 | Run all_gather on arbitrary picklable data (not necessarily tensors) 58 | Args: 59 | data: any picklable object 60 | Returns: 61 | list[data]: list of data gathered from each rank 62 | """ 63 | to_device = "cuda" 64 | #to_device = torch.device("cpu") 65 | 66 | world_size = get_world_size() 67 | if world_size == 1: 68 | return [data] 69 | 70 | # serialized to a Tensor 71 | buffer = pickle.dumps(data) 72 | storage = torch.ByteStorage.from_buffer(buffer) 73 | tensor = torch.ByteTensor(storage).to(to_device) 74 | 75 | # obtain Tensor size of each rank 76 | local_size = torch.LongTensor([tensor.numel()]).to(to_device) 77 | size_list = [torch.LongTensor([0]).to(to_device) for _ in range(world_size)] 78 | dist.all_gather(size_list, local_size) 79 | size_list = [int(size.item()) for size in size_list] 80 | max_size = max(size_list) 81 | 82 | # receiving Tensor from all ranks 83 | # we pad the tensor because torch all_gather does not support 84 | # gathering tensors of different shapes 85 | tensor_list = [] 86 | for _ in size_list: 87 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to(to_device)) 88 | if local_size != max_size: 89 | padding = torch.ByteTensor(size=(max_size - 
local_size,)).to(to_device) 90 | tensor = torch.cat((tensor, padding), dim=0) 91 | dist.all_gather(tensor_list, tensor) 92 | 93 | data_list = [] 94 | for size, tensor in zip(size_list, tensor_list): 95 | buffer = tensor.cpu().numpy().tobytes()[:size] 96 | data_list.append(pickle.loads(buffer)) 97 | 98 | return data_list 99 | 100 | 101 | def reduce_dict(input_dict, average=True): 102 | """ 103 | Args: 104 | input_dict (dict): all the values will be reduced 105 | average (bool): whether to do average or sum 106 | Reduce the values in the dictionary from all processes so that process with rank 107 | 0 has the averaged results. Returns a dict with the same fields as 108 | input_dict, after reduction. 109 | """ 110 | world_size = get_world_size() 111 | if world_size < 2: 112 | return input_dict 113 | with torch.no_grad(): 114 | names = [] 115 | values = [] 116 | # sort the keys so that they are consistent across processes 117 | for k in sorted(input_dict.keys()): 118 | names.append(k) 119 | values.append(input_dict[k]) 120 | values = torch.stack(values, dim=0) 121 | dist.reduce(values, dst=0) 122 | if get_rank() == 0 and average: 123 | # only main process gets accumulated, so only divide by 124 | # world_size in this case 125 | values /= world_size 126 | reduced_dict = {k: v for k, v in zip(names, values)} 127 | return reduced_dict 128 | 129 | 130 | def reduce_loss_dict(loss_dict): 131 | """ 132 | Reduce the loss dictionary from all processes so that process with rank 133 | 0 has the averaged results. Returns a dict with the same fields as 134 | loss_dict, after reduction. 135 | """ 136 | world_size = get_world_size() 137 | if world_size < 2: 138 | return loss_dict 139 | with torch.no_grad(): 140 | loss_names = [] 141 | all_losses = [] 142 | for k in sorted(loss_dict.keys()): 143 | loss_names.append(k) 144 | all_losses.append(loss_dict[k]) 145 | all_losses = torch.stack(all_losses, dim=0) 146 | dist.reduce(all_losses, dst=0) 147 | if get_rank() == 0: 148 | # only main process gets accumulated, so only divide by 149 | # world_size in this case 150 | all_losses /= world_size 151 | reduced_losses = {k: v for k, v in zip(loss_names, all_losses)} 152 | return reduced_losses -------------------------------------------------------------------------------- /datasets/words.py: -------------------------------------------------------------------------------- 1 | replace_dict = { 2 | 'blacj' : 'black', 3 | 'plastci' : 'plastic', 4 | 'actmst' : '', 5 | 'smll' : 'small', 6 | 'cothes' : 'clothes', 7 | 'ywllow' : 'yellow', 8 | 'yelow' : 'yellow', 9 | 'awhite' : 'a white', 10 | 'halmat' : 'helmet', 11 | 'barball' : 'barbell', 12 | 'palid' : '', 13 | 'livig' : 'living', 14 | 'inwhite' : 'in white', 15 | 'nissthe' : 'nissan', 16 | 'jrans' : 'jeans', 17 | 'hwite' : 'white', 18 | 'softhe' : 'sofa', 19 | 'tabble' : 'table', 20 | 'bige' : 'big', 21 | 'speakin' : 'speaking', 22 | 'waering' : 'wearing', 23 | 'hotal' : 'hotel', 24 | 'playgrond' : 'playground', 25 | 'dimgrey' : 'gray', 26 | 'trowards' : 'towards', 27 | 'yelllow' : 'yellow', 28 | 'bowns' : 'bown', 29 | 'outsoors' : 'outdoors', 30 | 'resturant' : 'restaurant', 31 | 'coloe' : 'color', 32 | 'fatest' : 'fat', 33 | 'classrooom' : 'classroom', 34 | 'wahite' : 'white', 35 | 'bkini' : 'bikini', 36 | 'andult' : 'adult', 37 | 'woaman' : 'woman', 38 | 'touchs' : 'touch', 39 | 'adutl' : 'adult', 40 | 'palyground' : 'playground', 41 | 'ppurple' : 'purple', 42 | 'stairscase' : 'staircase', 43 | 'sungalsses' : 'sunglasses', 44 | 'inblack' : 'in black', 45 | 'abovce' : 
'above', 46 | 'evenging' : 'evening', 47 | 'ourdoors' : 'outdoors', 48 | 'ocethe' : 'ocean', 49 | 'glaasses' : 'glasses', 50 | 'woamn' : 'woman', 51 | 'fmale' : 'female', 52 | 'withsunglasses' : 'with sunglasses', 53 | 'gloden' : 'golden', 54 | 'straint' : 'straight', 55 | 'grabing' : 'grabbing', 56 | 'sittingabove' : 'sitting above', 57 | 'famle' : 'female', 58 | 'childern' : 'children', 59 | 'baby_seat' : 'baby seat', 60 | 'inin' : 'in', 61 | 'waer' : 'water', 62 | 'womthe' : 'woman', 63 | 'hoome' : 'home', 64 | 'tiget' : 'tiger', 65 | 'mthe' : 'man', 66 | 'galsses' : 'glasses', 67 | 'abvoe' : 'above', 68 | 'wristhand' : 'wristband', 69 | 'get_off' : 'get off', 70 | 'thebed' : 'the bed', 71 | 'halmet' : 'helmet', 72 | 'theroom' : 'the room', 73 | 'bibycle' : 'bicycle', 74 | 'peachpuff' : 'pink', 75 | 'cythe' : 'cyan', 76 | 'mountarn' : 'mountain', 77 | 'chidl' : 'child', 78 | 'ththe' : 'the', 79 | 'yeloow' : 'yellow', 80 | 'iscaress' : 'is caress', 81 | 'thesofa' : 'the sofa', 82 | 'surboard' : 'surfboard', 83 | 'wearig' : 'wearing', 84 | 'blone' : 'blonde', 85 | 'watche' : 'watch', 86 | 'inisde' : 'inside', 87 | 'wman' : 'woman', 88 | 'eatting' : 'eating', 89 | 'colorfuls' : 'colorful', 90 | 'whhite' : 'white', 91 | 'playgrouns' : 'playground', 92 | 'qhite' : 'white', 93 | 'roomm' : 'room', 94 | 'watchs' : 'watches', 95 | 'woodem' : 'wooden', 96 | 'insdie' : 'inside', 97 | 'whtie' : 'white', 98 | 'colth' : 'clothes', 99 | 'newbron' : 'newborn', 100 | 'sittint' : 'sitting', 101 | 'colorfu' : 'colorful', 102 | 'barthroom' : 'bathroom', 103 | 'claybank' : 'brown', 104 | '1another' : 'another', 105 | 'clorful' : 'colorful', 106 | 'blggest' : 'biggest', 107 | 'photoing' : 'photo', 108 | 'blck' : 'black', 109 | 'clthes' : 'clothes', 110 | 'insidethe' : 'inside the', 111 | 'woma' : 'woman', 112 | 'colthes' : 'clothes', 113 | 'pnik' : 'pink', 114 | 'torwards' : 'towards', 115 | 'aborad' : 'aboard', 116 | 'throwes' : 'throws', 117 | 'varrying' : 'varying', 118 | 'wathet' : 'blue', 119 | 'withfew' : 'with few', 120 | 'blcak' : 'black', 121 | 'adule' : 'adult', 122 | 'clotehs' : 'clothes', 123 | 'onth' : 'on the', 124 | 'coloful' : 'colorful', 125 | 'inred' : 'in red', 126 | 'clohtes' : 'clothes', 127 | 'scoks' : 'socks', 128 | 'carrys' : 'carry', 129 | 'ground1' : 'ground', 130 | 'pandthe' : 'panda', 131 | 'wwearing' : 'wearing', 132 | 'trouers' : 'trousers', 133 | 'babyseat' : 'baby seat', 134 | 'meetingplace' : 'meeting place', 135 | 'tellow' : 'yellow', 136 | 'mwn' : 'man', 137 | 'holdiung' : 'holding', 138 | 'woodens' : 'wooden', 139 | 'stop_sign' : 'stop sign', 140 | 'palegodenrod' : 'yellow', 141 | 'putple' : 'purple', 142 | 'waveing' : 'waving', 143 | 'theshow' : 'the show', 144 | 'whiet' : 'white', 145 | 'audlt' : 'adult', 146 | 'borwn' : 'brown', 147 | 'besidethe' : 'beside the', 148 | 'hulmet' : 'helmet', 149 | 'next_to' : 'next to', 150 | 'thegrass' : 'the grass', 151 | 'chaqueta' : 'jacket', 152 | 'smmall' : 'small', 153 | 'geay' : 'gray', 154 | 'woemen' : 'woman', 155 | 'grya' : 'gray', 156 | 'othere' : 'other', 157 | 'brwon' : 'brown', 158 | 'babt' : 'baby', 159 | 'anothe' : 'another', 160 | 'swmming' : 'swimming', 161 | 'waeing' : 'wearing', 162 | 'watarfall' : 'waterfall', 163 | 'weddding' : 'wedding', 164 | 'drowm' : 'drown', 165 | 'kiechen' : 'kitchen', 166 | 'secene' : 'scene', 167 | 'puple' : 'purple', 168 | 'straid' : 'striped' 169 | } -------------------------------------------------------------------------------- /datasets/samplers/grouped_batch_sampler.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import itertools 3 | 4 | import torch 5 | from torch.utils.data.sampler import BatchSampler 6 | from torch.utils.data.sampler import Sampler 7 | 8 | 9 | class GroupedBatchSampler(BatchSampler): 10 | """ 11 | Wraps another sampler to yield a mini-batch of indices. 12 | It enforces that elements from the same group should appear in groups of batch_size. 13 | It also tries to provide mini-batches which follows an ordering which is 14 | as close as possible to the ordering from the original sampler. 15 | 16 | Arguments: 17 | sampler (Sampler): Base sampler. 18 | batch_size (int): Size of mini-batch. 19 | drop_uneven (bool): If ``True``, the sampler will drop the batches whose 20 | size is less than ``batch_size`` 21 | 22 | """ 23 | 24 | def __init__(self, sampler, group_ids, batch_size, drop_uneven=False): 25 | if not isinstance(sampler, Sampler): 26 | raise ValueError( 27 | "sampler should be an instance of " 28 | "torch.utils.data.Sampler, but got sampler={}".format(sampler) 29 | ) 30 | self.sampler = sampler 31 | self.group_ids = torch.as_tensor(group_ids) 32 | assert self.group_ids.dim() == 1 33 | self.batch_size = batch_size 34 | self.drop_uneven = drop_uneven 35 | 36 | self.groups = torch.unique(self.group_ids).sort(0)[0] 37 | self._can_reuse_batches = False 38 | 39 | def _prepare_batches(self): 40 | dataset_size = len(self.group_ids) 41 | # get the sampled indices from the sampler 42 | sampled_ids = torch.as_tensor(list(self.sampler)) 43 | # potentially not all elements of the dataset were sampled 44 | # by the sampler (e.g., DistributedSampler). 45 | # construct a tensor which contains -1 if the element was 46 | # not sampled, and a non-negative number indicating the 47 | # order where the element was sampled. 48 | # for example. if sampled_ids = [3, 1] and dataset_size = 5, 49 | # the order is [-1, 1, -1, 0, -1] 50 | order = torch.full((dataset_size,), -1, dtype=torch.int64) 51 | order[sampled_ids] = torch.arange(len(sampled_ids)) 52 | 53 | # get a mask with the elements that were sampled 54 | mask = order >= 0 55 | 56 | # find the elements that belong to each individual cluster 57 | clusters = [(self.group_ids == i) & mask for i in self.groups] 58 | # get relative order of the elements inside each cluster 59 | # that follows the order from the sampler 60 | relative_order = [order[cluster] for cluster in clusters] 61 | # with the relative order, find the absolute order in the 62 | # sampled space 63 | permutation_ids = [s[s.sort()[1]] for s in relative_order] 64 | # permute each cluster so that they follow the order from 65 | # the sampler 66 | permuted_clusters = [sampled_ids[idx] for idx in permutation_ids] 67 | 68 | # splits each cluster in batch_size, and merge as a list of tensors 69 | splits = [c.split(self.batch_size) for c in permuted_clusters] 70 | merged = tuple(itertools.chain.from_iterable(splits)) 71 | 72 | # now each batch internally has the right order, but 73 | # they are grouped by clusters. Find the permutation between 74 | # different batches that brings them as close as possible to 75 | # the order that we have in the sampler. 
For that, we will consider the 76 | # ordering as coming from the first element of each batch, and sort 77 | # correspondingly 78 | first_element_of_batch = [t[0].item() for t in merged] 79 | # get and inverse mapping from sampled indices and the position where 80 | # they occur (as returned by the sampler) 81 | inv_sampled_ids_map = {v: k for k, v in enumerate(sampled_ids.tolist())} 82 | # from the first element in each batch, get a relative ordering 83 | first_index_of_batch = torch.as_tensor( 84 | [inv_sampled_ids_map[s] for s in first_element_of_batch] 85 | ) 86 | 87 | # permute the batches so that they approximately follow the order 88 | # from the sampler 89 | permutation_order = first_index_of_batch.sort(0)[1].tolist() 90 | # finally, permute the batches 91 | batches = [merged[i].tolist() for i in permutation_order] 92 | 93 | if self.drop_uneven: 94 | kept = [] 95 | for batch in batches: 96 | if len(batch) == self.batch_size: 97 | kept.append(batch) 98 | batches = kept 99 | return batches 100 | 101 | def __iter__(self): 102 | if self._can_reuse_batches: 103 | batches = self._batches 104 | self._can_reuse_batches = False 105 | else: 106 | batches = self._prepare_batches() 107 | self._batches = batches 108 | return iter(batches) 109 | 110 | def __len__(self): 111 | if not hasattr(self, "_batches"): 112 | self._batches = self._prepare_batches() 113 | self._can_reuse_batches = True 114 | return len(self._batches) 115 | -------------------------------------------------------------------------------- /models/vision_model/position_encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | class PositionEmbeddingSineHW(nn.Module): 7 | """ 8 | This is a more standard version of the position embedding, very similar to the one 9 | used by the Attention is all you need paper, generalized to work on images. 
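    A minimal usage sketch (shapes and names are illustrative; assumes a
    NestedTensor-like input with ``tensors`` of shape [B, C, H, W] and a
    boolean ``mask`` of shape [B, H, W]); the output carries
    2 * num_pos_feats channels:

        pos_enc = PositionEmbeddingSineHW(num_pos_feats=128, normalize=True)
        pos = pos_enc(frames)  # [B, 256, H, W]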
10 | """ 11 | def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None): 12 | super().__init__() 13 | self.num_pos_feats = num_pos_feats 14 | self.temperatureH = temperatureH 15 | self.temperatureW = temperatureW 16 | self.normalize = normalize 17 | if scale is not None and normalize is False: 18 | raise ValueError("normalize should be True if scale is passed") 19 | if scale is None: 20 | scale = 2 * math.pi 21 | self.scale = scale 22 | 23 | def forward(self, tensor_list): 24 | x = tensor_list.tensors 25 | mask = tensor_list.mask 26 | assert mask is not None 27 | not_mask = ~mask 28 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 29 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 30 | 31 | if self.normalize: 32 | eps = 1e-6 33 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 34 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 35 | 36 | dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 37 | dim_tx = self.temperatureW ** (2 * torch.div(dim_tx, 2, rounding_mode='floor') / self.num_pos_feats) 38 | pos_x = x_embed[:, :, :, None] / dim_tx 39 | 40 | dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 41 | dim_ty = self.temperatureH ** (2 * torch.div(dim_ty, 2, rounding_mode='floor') / self.num_pos_feats) 42 | pos_y = y_embed[:, :, :, None] / dim_ty 43 | 44 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 46 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 47 | 48 | return pos 49 | 50 | 51 | class PositionEmbeddingSine(nn.Module): 52 | """ 53 | This is a more standard version of the position embedding, very similar to the one 54 | used by the Attention is all you need paper, generalized to work on images. 55 | """ 56 | 57 | def __init__( 58 | self, num_pos_feats=64, temperature=10000, normalize=False, scale=None 59 | ): 60 | super().__init__() 61 | self.num_pos_feats = num_pos_feats 62 | self.temperature = temperature 63 | self.normalize = normalize 64 | if scale is not None and normalize is False: 65 | raise ValueError("normalize should be True if scale is passed") 66 | if scale is None: 67 | scale = 2 * math.pi 68 | self.scale = scale 69 | 70 | def forward(self, tensor_list): 71 | x = tensor_list.tensors 72 | mask = tensor_list.mask 73 | not_mask = ~mask 74 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 75 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 76 | if self.normalize: 77 | eps = 1e-6 78 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 79 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 80 | 81 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 82 | # dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 83 | dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.num_pos_feats) 84 | 85 | pos_x = x_embed[:, :, :, None] / dim_t 86 | pos_y = y_embed[:, :, :, None] / dim_t 87 | pos_x = torch.stack( 88 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 89 | ).flatten(3) 90 | pos_y = torch.stack( 91 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 92 | ).flatten(3) 93 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 94 | return pos 95 | 96 | 97 | class PositionEmbeddingLearned(nn.Module): 98 | """ 99 | Absolute pos embedding, learned. 
100 | """ 101 | 102 | def __init__(self, num_pos_feats=256): 103 | super().__init__() 104 | self.row_embed = nn.Embedding(50, num_pos_feats) 105 | self.col_embed = nn.Embedding(50, num_pos_feats) 106 | self.reset_parameters() 107 | 108 | def reset_parameters(self): 109 | nn.init.uniform_(self.row_embed.weight) 110 | nn.init.uniform_(self.col_embed.weight) 111 | 112 | def forward(self, tensor_list): 113 | x = tensor_list.tensors 114 | h, w = x.shape[-2:] 115 | i = torch.arange(w, device=x.device) 116 | j = torch.arange(h, device=x.device) 117 | x_emb = self.col_embed(i) 118 | y_emb = self.row_embed(j) 119 | pos = ( 120 | torch.cat( 121 | [ 122 | x_emb.unsqueeze(0).repeat(h, 1, 1), 123 | y_emb.unsqueeze(1).repeat(1, w, 1), 124 | ], 125 | dim=-1, 126 | ) 127 | .permute(2, 0, 1) 128 | .unsqueeze(0) 129 | .repeat(x.shape[0], 1, 1, 1) 130 | ) 131 | return pos 132 | 133 | 134 | def build_position_encoding(cfg): 135 | N_steps = cfg.MODEL.CG.HIDDEN // 2 136 | encode_type = cfg.MODEL.VISION_BACKBONE.POS_ENC 137 | if encode_type == "sine": 138 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 139 | elif encode_type == "sineHW": 140 | position_embedding = PositionEmbeddingSineHW(N_steps, 20, 20, normalize=True) 141 | elif encode_type == "learned": 142 | position_embedding = PositionEmbeddingLearned(N_steps) 143 | else: 144 | raise ValueError(f"not supported {encode_type}") 145 | 146 | return position_embedding 147 | -------------------------------------------------------------------------------- /models/vision_model/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | """ 4 | Backbone modules. 5 | """ 6 | from collections import OrderedDict 7 | import torch 8 | import torch.nn.functional as F 9 | import torchvision 10 | from torch import nn 11 | from torchvision.models._utils import IntermediateLayerGetter 12 | 13 | from utils.misc import NestedTensor 14 | 15 | 16 | class FrozenBatchNorm2d(torch.nn.Module): 17 | """ 18 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 19 | 20 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 21 | without which any other models than torchvision.models.resnet[18,34,50,101] 22 | produce nans. 
23 | """ 24 | 25 | def __init__(self, n): 26 | super(FrozenBatchNorm2d, self).__init__() 27 | self.register_buffer("weight", torch.ones(n)) 28 | self.register_buffer("bias", torch.zeros(n)) 29 | self.register_buffer("running_mean", torch.zeros(n)) 30 | self.register_buffer("running_var", torch.ones(n)) 31 | 32 | def _load_from_state_dict( 33 | self, 34 | state_dict, 35 | prefix, 36 | local_metadata, 37 | strict, 38 | missing_keys, 39 | unexpected_keys, 40 | error_msgs, 41 | ): 42 | num_batches_tracked_key = prefix + "num_batches_tracked" 43 | if num_batches_tracked_key in state_dict: 44 | del state_dict[num_batches_tracked_key] 45 | 46 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 47 | state_dict, 48 | prefix, 49 | local_metadata, 50 | strict, 51 | missing_keys, 52 | unexpected_keys, 53 | error_msgs, 54 | ) 55 | 56 | def forward(self, x): 57 | # move reshapes to the beginning 58 | # to make it fuser-friendly 59 | w = self.weight.reshape(1, -1, 1, 1) 60 | b = self.bias.reshape(1, -1, 1, 1) 61 | rv = self.running_var.reshape(1, -1, 1, 1) 62 | rm = self.running_mean.reshape(1, -1, 1, 1) 63 | eps = 1e-5 64 | scale = w * (rv + eps).rsqrt() 65 | bias = b - rm * scale 66 | return x * scale + bias 67 | 68 | 69 | class BackboneBase(nn.Module): 70 | def __init__( 71 | self, 72 | backbone: nn.Module, 73 | train_backbone: bool, 74 | num_channels: int, 75 | return_interm_layers: bool, 76 | ): 77 | super().__init__() 78 | for name, parameter in backbone.named_parameters(): 79 | if ( 80 | not train_backbone 81 | or "layer2" not in name 82 | and "layer3" not in name 83 | and "layer4" not in name 84 | ): 85 | parameter.requires_grad_(False) 86 | if return_interm_layers: 87 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 88 | else: 89 | return_layers = {"layer4": 0} 90 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 91 | self.num_channels = num_channels 92 | 93 | def forward(self, tensor_list): 94 | durations = tensor_list.durations 95 | xs = self.body(tensor_list.tensors) 96 | out = OrderedDict() 97 | for name, x in xs.items(): 98 | m = tensor_list.mask 99 | assert m is not None 100 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 101 | out[name] = NestedTensor(x, mask, durations) 102 | return out 103 | 104 | 105 | class Backbone(BackboneBase): 106 | """ResNet backbone with frozen BatchNorm.""" 107 | 108 | def __init__( 109 | self, 110 | name: str, 111 | train_backbone: bool, 112 | return_interm_layers: bool, 113 | dilation: bool, 114 | ): 115 | backbone = getattr(torchvision.models, name)( 116 | replace_stride_with_dilation=[False, False, dilation], 117 | pretrained=True, 118 | norm_layer=FrozenBatchNorm2d, 119 | ) 120 | num_channels = 512 if name in ("resnet18", "resnet34") else 2048 121 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 122 | 123 | 124 | class GroupNorm32(torch.nn.GroupNorm): 125 | def __init__(self, num_channels, num_groups=32, **kargs): 126 | super().__init__(num_groups, num_channels, **kargs) 127 | 128 | 129 | class GroupNormBackbone(BackboneBase): 130 | """ResNet backbone with GroupNorm with 32 channels.""" 131 | 132 | def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool): 133 | name_map = { 134 | "resnet50-gn": ("resnet50", "/checkpoint/szagoruyko/imagenet/22014122/checkpoint.pth"), 135 | "resnet101-gn": ("resnet101", "/checkpoint/szagoruyko/imagenet/22080524/checkpoint.pth"), 136 | } 137 | backbone = 
getattr(torchvision.models, name_map[name][0])( 138 | replace_stride_with_dilation=[False, False, dilation], pretrained=False, norm_layer=GroupNorm32 139 | ) 140 | checkpoint = torch.load(name_map[name][1], map_location="cpu") 141 | state_dict = {k[7:]: p for k, p in checkpoint["model"].items()} 142 | backbone.load_state_dict(state_dict) 143 | num_channels = 512 if name_map[name][0] in ("resnet18", "resnet34") else 2048 144 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 145 | 146 | 147 | class Joiner(nn.Sequential): 148 | def __init__(self, backbone, position_embedding): 149 | super().__init__(backbone, position_embedding) 150 | 151 | def forward(self, tensor_list): 152 | xs = self[0](tensor_list) 153 | out = [] 154 | pos = [] 155 | for name, x in xs.items(): 156 | out.append(x) 157 | pos.append(self[1](x).to(x.tensors.dtype)) 158 | 159 | return out[-1], pos[-1] 160 | -------------------------------------------------------------------------------- /engine/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | from typing import Dict 4 | 5 | from utils.misc import to_device 6 | from utils.comm import synchronize, is_main_process 7 | from tqdm import tqdm 8 | 9 | 10 | @torch.no_grad() 11 | def linear_interp(bbox_dict): 12 | frame_ids = sorted([fid for fid in bbox_dict]) 13 | if len(frame_ids) < 2: 14 | return bbox_dict 15 | for idx in range(0, len(frame_ids) - 1): 16 | left_fid = frame_ids[idx] 17 | right_fid = frame_ids[idx + 1] 18 | if right_fid - left_fid > 1: 19 | interval = right_fid - left_fid 20 | delta_x1 = (bbox_dict[right_fid][0][0] - bbox_dict[left_fid][0][0]) / interval 21 | delta_y1 = (bbox_dict[right_fid][0][1] - bbox_dict[left_fid][0][1]) / interval 22 | delta_x2 = (bbox_dict[right_fid][0][2] - bbox_dict[left_fid][0][2]) / interval 23 | delta_y2 = (bbox_dict[right_fid][0][3] - bbox_dict[left_fid][0][3]) / interval 24 | for step in range(1, interval): 25 | bbox_dict[left_fid + step] = [[ 26 | bbox_dict[left_fid][0][0] + step * delta_x1, 27 | bbox_dict[left_fid][0][1] + step * delta_y1, 28 | bbox_dict[left_fid][0][2] + step * delta_x2, 29 | bbox_dict[left_fid][0][3] + step * delta_y2, 30 | ]] 31 | 32 | frame_ids = sorted([fid for fid in bbox_dict]) 33 | assert max(frame_ids) - min(frame_ids) + 1 == len(frame_ids) 34 | return {fid : bbox_dict[fid] for fid in frame_ids} 35 | 36 | @torch.no_grad() 37 | def single_forward(cfg, model, videos, texts, targets, device, postprocessor): 38 | durations = videos.durations 39 | targets[0]["durations"] = durations 40 | outputs = model(videos, texts, targets) 41 | 42 | b = len(durations) 43 | t = max(durations) 44 | batch_img_size = [list(target['ori_size']) for target in targets] 45 | orig_target_sizes = [img_size for img_size in batch_img_size for _ in range(t)] 46 | orig_target_sizes = torch.tensor(orig_target_sizes,device=device) 47 | assert orig_target_sizes.shape[0] == outputs['pred_boxes'].shape[0] 48 | 49 | frames_ids = [target['frame_ids'] for target in targets] 50 | pred_boxs, pred_steds = postprocessor(outputs, orig_target_sizes, frames_ids, durations) 51 | pred_boxs = pred_boxs.view(b, t, 4) 52 | 53 | vids = [target['item_id'] for target in targets] 54 | bbox_pred, temp_pred = {}, {} 55 | 56 | for i_b in range(b): 57 | frames_id = frames_ids[i_b] 58 | bbox_pred[vids[i_b]] = {} 59 | assert durations[i_b] == len(frames_id) 60 | for idx in range(durations[i_b]): 61 | bbox_pred[vids[i_b]][frames_id[idx]] = 
[pred_boxs[i_b][idx].detach().cpu().tolist()] 62 | 63 | if cfg.DATASET.NAME == 'VidSTG': 64 | qtypes = [target['qtype'] for target in targets] 65 | assert len(pred_steds) == len(qtypes) 66 | for i_b in range(b): 67 | temp_pred[vids[i_b]] = { 68 | "sted": pred_steds[i_b], 69 | "qtype": qtypes[i_b], 70 | } 71 | else: 72 | for i_b in range(b): 73 | temp_pred[vids[i_b]] = { 74 | "sted": pred_steds[i_b] 75 | } 76 | 77 | return bbox_pred, temp_pred 78 | 79 | 80 | @torch.no_grad() 81 | def do_eval(cfg, mode, logger, model, postprocessor, data_loader, evaluator, device): 82 | """ 83 | Video Spatial-Temporal Grounding Evaluation 84 | """ 85 | model.eval() 86 | logger.info("Start evaluation on the {} split of {} dataset".format(mode, cfg.DATASET.NAME)) 87 | 88 | for _, batch_dict in enumerate(tqdm(data_loader)): 89 | videos = batch_dict['videos'].to(device) 90 | texts = batch_dict['texts'] 91 | targets = to_device(batch_dict["targets"], device) 92 | 93 | for i in range(len(targets)): 94 | if 'qtype' not in targets[i]: 95 | targets[i]['qtype'] = 'none' 96 | 97 | videos1 = videos.subsample(2, start_idx=0) 98 | targets1 = [{'item_id': target['item_id'], 'ori_size': target['ori_size'], 99 | 'qtype': target['qtype'], 'frame_ids': target['frame_ids'][0::2], "boxs":target['boxs'].bbox.clone(), 'actioness':target['actioness'][0::2], "eval":True} for target in targets] 100 | 101 | videos2 = videos.subsample(2, start_idx=1) 102 | targets2 = [{'item_id': target['item_id'], 'ori_size': target['ori_size'], 103 | 'qtype': target['qtype'], 'frame_ids': target['frame_ids'][1::2], "boxs":target['boxs'].bbox.clone(), 'actioness':target['actioness'][1::2], "eval":True} for target in targets] 104 | 105 | if torch.where(targets[0]["actioness"])[0][0] % 2 == 0: 106 | targets1[0]['boxs'] = targets1[0]['boxs'][0::2] 107 | targets2[0]['boxs'] = targets2[0]['boxs'][1::2] 108 | else: 109 | targets1[0]['boxs'] = targets1[0]['boxs'][1::2] 110 | targets2[0]['boxs'] = targets2[0]['boxs'][0::2] 111 | 112 | bbox_pred1, temp_pred1 = single_forward(cfg, model, videos1, texts, 113 | targets1, device, postprocessor) 114 | bbox_pred2, temp_pred2 = single_forward(cfg, model, videos2, texts, 115 | targets2, device, postprocessor) 116 | 117 | bbox_pred, temp_pred = {}, {} 118 | for vid in bbox_pred1: 119 | bbox_pred1[vid].update(bbox_pred2[vid]) 120 | bbox_pred[vid] = linear_interp(bbox_pred1[vid]) 121 | temp_pred[vid] = {'sted' : [min(temp_pred1[vid]['sted'][0], temp_pred2[vid]['sted'][0]), 122 | max(temp_pred1[vid]['sted'][1], temp_pred2[vid]['sted'][1])]} 123 | if 'qtype' in temp_pred1[vid]: 124 | temp_pred[vid]['qtype'] = temp_pred1[vid]['qtype'] 125 | 126 | evaluator.update(bbox_pred) 127 | evaluator.video_update(temp_pred) 128 | 129 | synchronize() 130 | evaluator.synchronize_between_processes() 131 | if is_main_process(): 132 | logger.info(f"Complete the inference on {mode} split of {cfg.DATASET.NAME}") 133 | 134 | res = evaluator.summarize() 135 | return res -------------------------------------------------------------------------------- /datasets/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import bisect 4 | import copy 5 | 6 | 7 | import torch 8 | import torch.utils.data 9 | from utils.comm import get_world_size 10 | 11 | from torch.utils.data import DistributedSampler 12 | 13 | from . import samplers 14 | from . 
import transforms as T 15 | from .vidstg import VidSTGDataset 16 | from .hcstvg import HCSTVGDataset 17 | from .collate_batch import collate_fn 18 | 19 | 20 | def build_transforms(cfg, is_train=True): 21 | imsize = cfg.INPUT.RESOLUTION 22 | max_size = 720 23 | if is_train: 24 | flip_horizontal_prob = cfg.INPUT.FLIP_PROB_TRAIN 25 | 26 | scales = [] 27 | if cfg.INPUT.AUG_SCALE: 28 | for i in range(4): 29 | scales.append(imsize - 32 * i) 30 | else: 31 | scales = [imsize] 32 | 33 | transform = T.Compose( 34 | [ 35 | T.RandomHorizontalFlip(flip_horizontal_prob), 36 | T.RandomSelect( 37 | T.RandomResize(scales, max_size=max_size), 38 | T.Compose( 39 | [ 40 | T.RandomResize([400, 500, 600]), 41 | T.RandomSizeCrop(384, 600), 42 | T.RandomResize(scales, max_size=max_size), 43 | ] 44 | ), 45 | ), 46 | T.Normalize( 47 | mean=cfg.INPUT.PIXEL_MEAN, 48 | std=cfg.INPUT.PIXEL_STD 49 | ), 50 | ] 51 | ) 52 | 53 | else: 54 | transform = T.Compose( 55 | [ 56 | T.RandomResize(imsize, max_size=max_size), 57 | T.Normalize( 58 | mean=cfg.INPUT.PIXEL_MEAN, 59 | std=cfg.INPUT.PIXEL_STD 60 | ), 61 | ] 62 | ) 63 | 64 | return transform 65 | 66 | 67 | def build_dataset(cfg, split, transforms): 68 | dataset_name = cfg.DATASET.NAME 69 | if dataset_name == 'VidSTG': 70 | return VidSTGDataset( 71 | cfg, 72 | split, 73 | transforms 74 | ) 75 | elif dataset_name == 'HC-STVG': 76 | return HCSTVGDataset( 77 | cfg, 78 | split, 79 | transforms 80 | ) 81 | else: 82 | raise ValueError("{} is not Supported".format(dataset_name)) 83 | 84 | 85 | def make_data_sampler(dataset, shuffle, distributed): 86 | if distributed: 87 | return DistributedSampler(dataset, shuffle=shuffle) 88 | if shuffle: 89 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 90 | else: 91 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 92 | return sampler 93 | 94 | 95 | def _quantize(x, bins): 96 | bins = copy.copy(bins) 97 | bins = sorted(bins) 98 | quantized = list(map(lambda y: bisect.bisect_right(bins, y), x)) 99 | return quantized 100 | 101 | 102 | def _compute_aspect_ratios(dataset): 103 | aspect_ratios = [] 104 | for i in range(len(dataset)): 105 | video_info = dataset.get_video_info(i) 106 | aspect_ratio = float(video_info["height"]) / float(video_info["width"]) 107 | aspect_ratios.append(aspect_ratio) 108 | return aspect_ratios 109 | 110 | 111 | def _count_frame_size(dataset): 112 | img_sizes = dict() 113 | for i in range(len(dataset)): 114 | video_info = dataset.get_video_info(i) 115 | img_sizes.setdefault((video_info['width'],video_info['height']),0) 116 | img_sizes[(video_info['width'],video_info['height'])] += 1 117 | 118 | 119 | def make_batch_data_sampler( 120 | dataset, sampler, aspect_grouping, batch_size, num_iters=None, start_iter=0, is_train=True 121 | ): 122 | if aspect_grouping: 123 | if not isinstance(aspect_grouping, (list, tuple)): 124 | aspect_grouping = [aspect_grouping] 125 | aspect_ratios = _compute_aspect_ratios(dataset) 126 | group_ids = _quantize(aspect_ratios, aspect_grouping) 127 | batch_sampler = samplers.GroupedBatchSampler( 128 | sampler, group_ids, batch_size, drop_uneven=False 129 | ) 130 | else: 131 | batch_sampler = torch.utils.data.sampler.BatchSampler( 132 | sampler, batch_size, drop_last=True if is_train else False 133 | ) 134 | if num_iters is not None: 135 | batch_sampler = samplers.IterationBasedBatchSampler( 136 | batch_sampler, num_iters, start_iter 137 | ) 138 | return batch_sampler 139 | 140 | 141 | def make_data_loader(cfg, mode='train', is_distributed=False, start_iter=0): 142 | assert 
mode in {'train', 'val', 'test'} 143 | num_gpus = get_world_size() 144 | is_train = mode == 'train' 145 | 146 | transforms = build_transforms(cfg, is_train) 147 | dataset = build_dataset(cfg, mode, transforms) 148 | 149 | if is_train: 150 | videos_per_batch = cfg.SOLVER.BATCH_SIZE * num_gpus 151 | assert cfg.SOLVER.BATCH_SIZE == 1, "Each GPU should only take 1 video." 152 | videos_per_gpu = cfg.SOLVER.BATCH_SIZE 153 | shuffle = True 154 | num_epochs = cfg.SOLVER.MAX_EPOCH 155 | num_iters = num_epochs * math.ceil(len(dataset) / videos_per_batch) 156 | else: 157 | assert cfg.SOLVER.BATCH_SIZE == 1, "Each GPU should only take 1 video." 158 | videos_per_gpu = cfg.SOLVER.BATCH_SIZE 159 | shuffle = False 160 | num_iters = None 161 | start_iter = 0 162 | 163 | # group videos which have similar aspect ratio. In this case, we only 164 | # group in two cases: those with width / height > 1, and the other way around, 165 | # but the code supports more general grouping strategy 166 | aspect_grouping = [1] if cfg.DATALOADER.ASPECT_RATIO_GROUPING else [] 167 | 168 | sampler = make_data_sampler(dataset, shuffle, is_distributed) 169 | batch_sampler = make_batch_data_sampler( 170 | dataset, sampler, aspect_grouping, videos_per_gpu, num_iters, start_iter, is_train=is_train 171 | ) 172 | num_workers = cfg.DATALOADER.NUM_WORKERS 173 | 174 | data_loader = torch.utils.data.DataLoader( 175 | dataset, 176 | num_workers=num_workers, 177 | batch_sampler=batch_sampler, 178 | collate_fn=collate_fn, 179 | ) 180 | 181 | return data_loader -------------------------------------------------------------------------------- /config/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | # ----------------------------------------------------------------------------- 4 | # Config definition 5 | # ----------------------------------------------------------------------------- 6 | _C = CN() 7 | _C.FROM_SCRATCH = True 8 | _C.DATA_TRUNK = None 9 | 10 | _C.OUTPUT_DIR = '' 11 | _C.DATA_DIR = '' 12 | _C.GLOVE_DIR = '' 13 | _C.TENSORBOARD_DIR = '' 14 | 15 | 16 | # ----------------------------------------------------------------------------- 17 | # INPUT 18 | # ----------------------------------------------------------------------------- 19 | _C.INPUT = CN() 20 | _C.INPUT.MAX_QUERY_LEN = 26 21 | _C.INPUT.MAX_VIDEO_LEN = 200 22 | 23 | # The input frame number (For VidSTG) 24 | _C.INPUT.TRAIN_SAMPLE_NUM = 64 25 | # The input frame rate (For HC_STVG, 20s per input video) 26 | _C.INPUT.SAMPLE_FPS = 3.2 27 | 28 | # The input video resolution 29 | _C.INPUT.RESOLUTION = 224 30 | # Values to be used for image normalization 31 | _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406] 32 | # Values to be used for image normalization 33 | _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225] 34 | # If perform multiscale training 35 | _C.INPUT.AUG_SCALE = True 36 | # If perform translate augumentation training 37 | _C.INPUT.AUG_TRANSLATE = False 38 | 39 | # Image ColorJitter 40 | _C.INPUT.FLIP_PROB_TRAIN = 0.5 41 | _C.INPUT.TEMP_CROP_PROB = 0.5 42 | 43 | # ----------------------------------------------------------------------------- 44 | # Model Config 45 | # ----------------------------------------------------------------------------- 46 | _C.MODEL = CN() 47 | _C.MODEL.DEVICE = "cuda" 48 | _C.MODEL.WEIGHT = "" 49 | _C.MODEL.EMA = True 50 | _C.MODEL.EMA_DECAY = 0.9998 51 | _C.MODEL.QUERY_NUM = 1 # each frame a single query 52 | _C.MODEL.DOWN_RATIO = 4 53 | 54 | # 
----------------------------------------------------------------------------- 55 | # Vision Encoder options 56 | # ----------------------------------------------------------------------------- 57 | 58 | _C.MODEL.VISION_BACKBONE = CN() 59 | _C.MODEL.VISION_BACKBONE.NAME = 'resnet101' # resnet50 or resnet101 60 | _C.MODEL.VISION_BACKBONE.POS_ENC = 'sine' # sine, sineHW or learned 61 | _C.MODEL.VISION_BACKBONE.DILATION = False # If true, we replace stride with dilation in the last convolutional block (DC5) 62 | _C.MODEL.VISION_BACKBONE.FREEZE = False # If true, freeze the vision backbone parameters 63 | 64 | 65 | # ----------------------------------------------------------------------------- 66 | # Language Encoder Config 67 | # ----------------------------------------------------------------------------- 68 | _C.MODEL.TEXT_MODEL = CN() 69 | _C.MODEL.TEXT_MODEL.NAME = 'roberta-base' # "bert-base", "roberta-large" 70 | _C.MODEL.TEXT_MODEL.FREEZE = False 71 | 72 | # If true, use LSTM as the text encoder 73 | _C.MODEL.USE_LSTM = False 74 | _C.MODEL.LSTM = CN() 75 | _C.MODEL.LSTM.NAME = 'lstm' 76 | _C.MODEL.LSTM.HIDDEN_SIZE = 512 77 | _C.MODEL.LSTM.BIDIRECTIONAL = True 78 | _C.MODEL.LSTM.DROPOUT = 0 79 | _C.MODEL.LSTM_NUM_LAYERS = 2 80 | 81 | 82 | # ----------------------------------------------------------------------------- 83 | # CG Pipeline Config 84 | # ----------------------------------------------------------------------------- 85 | _C.MODEL.CG = CN() 86 | _C.MODEL.CG.HIDDEN = 256 87 | _C.MODEL.CG.QUERY_DIM = 4 # the anchor dim 88 | _C.MODEL.CG.ENC_LAYERS = 6 89 | _C.MODEL.CG.DEC_LAYERS = 6 90 | _C.MODEL.CG.FFN_DIM = 2048 91 | _C.MODEL.CG.DROPOUT = 0.1 92 | _C.MODEL.CG.HEADS = 8 93 | _C.MODEL.CG.USE_LEARN_TIME_EMBED = False 94 | _C.MODEL.CG.USE_ACTION = True # use the actioness head by default 95 | _C.MODEL.CG.FROM_SCRATCH = True 96 | 97 | # For 2D-Map prediction 98 | _C.MODEL.CG.TEMP_PRED_LAYERS = 6 99 | _C.MODEL.CG.CONV_LAYERS = 4 100 | _C.MODEL.CG.TEMP_HEAD = 'attn' # attn or conv 101 | _C.MODEL.CG.KERNAL_SIZE = 9 102 | _C.MODEL.CG.MAX_MAP_SIZE = 128 103 | _C.MODEL.CG.POOLING_COUNTS = [15,8,8,8] 104 | 105 | _C.MODEL.CG.TEMP_THETA = 0. 106 | _C.MODEL.CG.SPAT_GT_THETA = 0. 107 | _C.MODEL.CG.SPAT_THETA = 0. 108 | 109 | # ----------------------------------------------------------------------------- 110 | # DATASET related params 111 | # ----------------------------------------------------------------------------- 112 | _C.DATASET = CN() 113 | _C.DATASET.NAME = 'VidSTG' 114 | _C.DATASET.NUM_CLIP_FRAMES = 32 115 | # The minimum gt frames in a sampled clip 116 | _C.DATASET.MIN_GT_FRAME = 4 117 | 118 | 119 | # ----------------------------------------------------------------------------- 120 | # DataLoader 121 | # ----------------------------------------------------------------------------- 122 | _C.DATALOADER = CN() 123 | # Number of data loading threads 124 | _C.DATALOADER.NUM_WORKERS = 4 125 | _C.DATALOADER.SIZE_DIVISIBILITY = 0 126 | _C.DATALOADER.ASPECT_RATIO_GROUPING = False 127 | 128 | # ---------------------------------------------------------------------------- # 129 | # Solver 130 | # ---------------------------------------------------------------------------- # 131 | _C.SOLVER = CN() 132 | _C.SOLVER.MAX_EPOCH = 30 133 | _C.SOLVER.BATCH_SIZE = 1 # The video number per GPU, should be set 1. 
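# NOTE (how this is consumed, see datasets/build.py): make_data_loader asserts
# SOLVER.BATCH_SIZE == 1 and computes the effective batch as
# SOLVER.BATCH_SIZE * num_gpus, i.e. one video per GPU per iteration.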
134 | _C.SOLVER.SHUFFLE = True 135 | _C.SOLVER.BASE_LR = 2e-5 136 | _C.SOLVER.VIS_BACKBONE_LR = 1e-5 137 | _C.SOLVER.TEXT_LR = 2e-5 138 | _C.SOLVER.TEMP_LR = 1e-4 139 | _C.SOLVER.OPTIMIZER = 'adamw' 140 | _C.SOLVER.MAX_GRAD_NORM = 0.1 141 | 142 | # loss weight hyper-parameter 143 | _C.SOLVER.BBOX_COEF = 5 144 | _C.SOLVER.GIOU_COEF = 2 145 | _C.SOLVER.TEMP_COEF = 2 146 | _C.SOLVER.ATTN_COEF = 1 147 | _C.SOLVER.ACTIONESS_COEF = 2 148 | _C.SOLVER.CONF_COEF = 1 149 | 150 | _C.SOLVER.MOMENTUM = 0.9 151 | _C.SOLVER.WEIGHT_DECAY = 0.0001 152 | _C.SOLVER.GAMMA = 0.1 153 | _C.SOLVER.POWER = 0.9 # For Poly LRScheduler 154 | _C.SOLVER.STEPS = (30000,) 155 | 156 | _C.SOLVER.WARMUP_FACTOR = 1.0 / 3 157 | _C.SOLVER.WARMUP_ITERS = 500 158 | 159 | _C.SOLVER.WARMUP_PROP = 0.01 160 | _C.SOLVER.WARMUP_METHOD = "linear" 161 | 162 | _C.SOLVER.SCHEDULE = CN() 163 | _C.SOLVER.SCHEDULE.TYPE = "linear_with_warmup" 164 | _C.SOLVER.SCHEDULE.DROP_STEP = [8,12] 165 | 166 | # the following paramters are only used for WarmupReduceLROnPlateau 167 | _C.SOLVER.SCHEDULE.PATIENCE = 2 168 | _C.SOLVER.SCHEDULE.THRESHOLD = 1e-4 169 | _C.SOLVER.SCHEDULE.COOLDOWN = 1 170 | _C.SOLVER.SCHEDULE.FACTOR = 0.5 171 | _C.SOLVER.SCHEDULE.MAX_DECAY_STEP = 7 172 | 173 | _C.SOLVER.PRE_VAL = False 174 | _C.SOLVER.TO_VAL = True 175 | _C.SOLVER.VAL_PERIOD = 3000 # every 10% training iterations completed, start a avaluation 176 | _C.SOLVER.CHECKPOINT_PERIOD = 5000 177 | 178 | 179 | _C.SOLVER.USE_ATTN = False # whether to use the guided attention loss, to compare with TubeDETR 180 | _C.SOLVER.SIGMA = 2.0 # standard deviation for the quantized gaussian law used for the kullback leibler divergence loss 181 | _C.SOLVER.USE_AUX_LOSS = True # whether to use auxiliary decoding losses (loss at each layer) 182 | _C.SOLVER.EOS_COEF = 0.1 # The coeff for negative sample -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
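"""
Video transforms used by the datasets.

All transforms below operate on an ``input_dict``; a sketch of the expected
keys, inferred from the transforms in this file: ``'frames'`` is a float
tensor of shape [T, C, H, W], ``'boxs'`` is a ``BoxList`` from
``utils.bounding_box``, and ``'text'`` is the query string. ``NormalizeAndPad``
additionally writes an integer ``'mask'`` of shape [size, size].
"""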
2 | import random 3 | 4 | import torch 5 | import torchvision 6 | import random 7 | from torchvision.transforms import functional as F 8 | import torchvision.transforms as T 9 | from utils.bounding_box import BoxList 10 | 11 | 12 | class Compose(object): 13 | def __init__(self, transforms): 14 | self.transforms = transforms 15 | 16 | def __call__(self, input_dict): 17 | for t in self.transforms: 18 | input_dict = t(input_dict) 19 | return input_dict 20 | 21 | def __repr__(self): 22 | format_string = self.__class__.__name__ + "(" 23 | for t in self.transforms: 24 | format_string += "\n" 25 | format_string += " {0}".format(t) 26 | format_string += "\n)" 27 | return format_string 28 | 29 | 30 | class ColorJitter(object): 31 | def __init__(self,brightness=0,contrast=0,saturation=0,hue=0): 32 | self.color_jitter = torchvision.transforms.ColorJitter( 33 | brightness=brightness, 34 | contrast=contrast, 35 | saturation=saturation, 36 | hue=hue,) 37 | 38 | def __call__(self, input_dict): 39 | if random.random() < 0.8: 40 | frames = input_dict['frames'] 41 | frames = self.color_jitter(frames) 42 | input_dict['frames'] = frames 43 | 44 | return input_dict 45 | 46 | 47 | class RandomHorizontalFlip(object): 48 | def __init__(self, prob=0.5): 49 | self.prob = prob 50 | 51 | def __call__(self, input_dict): 52 | if random.random() < self.prob: 53 | frames = input_dict['frames'] 54 | boxs = input_dict['boxs'] 55 | text = input_dict['text'] 56 | 57 | frames = F.hflip(frames) 58 | boxs = boxs.transpose(0) 59 | text = text.replace('right','*&^special^&*').replace('left','right').replace('*&^special^&*','left') 60 | 61 | input_dict['frames'] = frames 62 | input_dict['boxs'] = boxs 63 | input_dict['text'] = text 64 | 65 | return input_dict 66 | 67 | 68 | class RandomSelect(object): 69 | """ 70 | Randomly selects between transforms1 and transforms2, 71 | with probability p for transforms1 and (1 - p) for transforms2 72 | """ 73 | 74 | def __init__(self, transforms1, transforms2, p=0.5): 75 | self.transforms1 = transforms1 76 | self.transforms2 = transforms2 77 | self.p = p 78 | 79 | def __call__(self, input_dict): 80 | # if random.random() < self.p: 81 | # return self.transforms1(input_dict) 82 | return self.transforms1(input_dict) 83 | 84 | 85 | class RandomResize(object): 86 | def __init__(self, min_size, max_size=None): 87 | if not isinstance(min_size, (list, tuple)): 88 | min_size = (min_size,) 89 | self.min_size = min_size 90 | self.max_size = max_size 91 | 92 | def get_size(self, image_size): 93 | h, w = image_size 94 | size = random.choice(self.min_size) 95 | max_size = self.max_size 96 | if max_size is not None: 97 | min_original_size = float(min((w, h))) 98 | max_original_size = float(max((w, h))) 99 | if max_original_size / min_original_size * size > max_size: 100 | size = int(round(max_size * min_original_size / max_original_size)) 101 | 102 | if (w <= h and w == size) or (h <= w and h == size): 103 | return (h, w) 104 | 105 | if w < h: 106 | ow = size 107 | oh = int(size * h / w) 108 | else: 109 | oh = size 110 | ow = int(size * w / h) 111 | 112 | return (oh, ow) 113 | 114 | def __call__(self, input_dict): 115 | frames = input_dict['frames'] 116 | boxs = input_dict['boxs'] 117 | img_size = (frames.shape[2],frames.shape[3]) 118 | size = (frames.size(2), frames.size(3)) # self.get_size(img_size) 119 | 120 | frames = F.resize(frames, size) 121 | boxs = boxs.resize((size[1],size[0])) 122 | input_dict['frames'] = frames 123 | input_dict['boxs'] = boxs 124 | 125 | return input_dict 126 | 127 | 128 | class 
RandomSizeCrop(object): 129 | def __init__(self, min_size: int, max_size: int, max_try: int=50): 130 | self.min_size = min_size 131 | self.max_size = max_size 132 | self.max_try = max_try 133 | 134 | def __call__(self, input_dict): 135 | frames = input_dict['frames'] 136 | boxs = input_dict['boxs'] 137 | 138 | for _ in range(self.max_try): 139 | h = frames.shape[2] 140 | w = frames.shape[3] 141 | tw = random.randint(self.min_size, min(w, self.max_size)) 142 | th = random.randint(self.min_size, min(h, self.max_size)) 143 | 144 | region = T.RandomCrop.get_params(frames, [th, tw]) # [i, j, th, tw] 145 | if boxs.check_crop_valid(region): 146 | frames = F.crop(frames, *region) 147 | boxs = boxs.crop(region) 148 | input_dict['frames'] = frames 149 | input_dict['boxs'] = boxs 150 | return input_dict 151 | 152 | return input_dict 153 | 154 | 155 | class Normalize(object): 156 | def __init__(self, mean, std): 157 | self.mean = mean 158 | self.std = std 159 | 160 | def __call__(self, input_dict): 161 | frames = input_dict['frames'] 162 | boxs = input_dict['boxs'] 163 | frames = F.normalize(frames, mean=self.mean, std=self.std) 164 | assert boxs.size == (frames.shape[3],frames.shape[2]) # (w, h) 165 | boxs = boxs.normalize() 166 | input_dict['frames'] = frames 167 | input_dict['boxs'] = boxs 168 | return input_dict 169 | 170 | 171 | class NormalizeAndPad(object): 172 | def __init__(self, mean, std, size, aug_translate=False): 173 | self.mean = mean 174 | self.std = std 175 | self.size = size 176 | self.aug_translate = aug_translate 177 | 178 | def __call__(self, input_dict): 179 | frames = input_dict['frames'] 180 | frames = F.normalize(frames, mean=self.mean, std=self.std) 181 | 182 | t, _, h, w = frames.shape 183 | dw = self.size - w 184 | dh = self.size - h 185 | 186 | if self.aug_translate: 187 | top = random.randint(0, dh) 188 | left = random.randint(0, dw) 189 | else: 190 | top = round(dh / 2.0 - 0.1) 191 | left = round(dw / 2.0 - 0.1) 192 | 193 | out_frames = torch.zeros((t,3,self.size,self.size)).float() 194 | out_mask = torch.ones((self.size, self.size)).int() 195 | 196 | out_frames[:, :, top:top+h, left:left+w] = frames 197 | out_mask[top:top+h, left:left+w] = 0 198 | 199 | input_dict['frames'] = out_frames 200 | input_dict['mask'] = out_mask 201 | 202 | if 'boxs' in input_dict.keys(): 203 | boxs = input_dict['boxs'] 204 | boxs = boxs.shift((self.size,self.size),left,top) 205 | input_dict['boxs'] = boxs 206 | 207 | return input_dict -------------------------------------------------------------------------------- /models/grounding_model/modal_encoder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import Tensor, nn 5 | from typing import List, Optional, Tuple 6 | 7 | from utils.misc import NestedTensor 8 | from .position_encoding import SeqEmbeddingLearned, SeqEmbeddingSine 9 | 10 | 11 | class CrossModalEncoder(nn.Module): 12 | 13 | def __init__(self, cfg): 14 | super().__init__() 15 | # attention configuration 16 | d_model = cfg.MODEL.CG.HIDDEN 17 | nhead = cfg.MODEL.CG.HEADS 18 | dim_feedforward = cfg.MODEL.CG.FFN_DIM 19 | dropout = cfg.MODEL.CG.DROPOUT 20 | activation = "relu" 21 | num_layers = cfg.MODEL.CG.ENC_LAYERS 22 | self.d_model = d_model 23 | 24 | encoder_layer = TransformerEncoderLayer( 25 | d_model, nhead, dim_feedforward, dropout, activation 26 | ) 27 | encoder_norm = None 28 | self.encoder = SpatialTemporalEncoder(cfg, encoder_layer, num_layers, 
encoder_norm) 29 | self.fusion = nn.Linear(d_model, d_model) 30 | 31 | # The position embedding for feature map 32 | # self.spatial_embed = PositionEmbeddingLearned(d_model // 2) 33 | self._reset_parameters() 34 | 35 | def _reset_parameters(self): 36 | for p in self.parameters(): 37 | if p.dim() > 1: 38 | nn.init.xavier_uniform_(p) 39 | 40 | def forward(self, videos: NestedTensor = None, vis_pos=None, texts: Tuple = None, vid_features=None): 41 | vis_features, vis_mask, vis_durations = videos.decompose() 42 | assert vis_pos.shape[0] == sum(vis_durations), "{} != {}".format(vis_pos.shape[0], sum(vis_durations)) 43 | 44 | vis_mask[:, 0, 0] = False # avoid empty masks 45 | 46 | _, _, H, W = vis_features.shape 47 | # n_frames x c x h x w => hw x n_frames x c 48 | vis_features = vis_features.flatten(2).permute(2, 0, 1) # torch.Size([156, 64, 256]) 49 | vid_features = vid_features.flatten(2).permute(2, 0, 1) 50 | vis_pos = vis_pos.flatten(2).permute(2, 0, 1) 51 | vis_mask = vis_mask.flatten(1) 52 | 53 | # prepare the text encodings 54 | text_mask, text_features, _ = texts 55 | 56 | # expand the attention mask and text token from [b, len] to [n_frames, len] 57 | frame_length = vis_durations[0] 58 | text_mask = text_mask.expand(frame_length, text_mask.size(-1)) 59 | text_features = text_features.expand(text_features.size(0), frame_length, text_features.size(-1)) # [text_len, n_frames, d_model] 60 | 61 | # concat visual and text features and Pad the vis_pos with 0 for the text tokens 62 | features = torch.cat([vis_features, text_features, vid_features], dim=0) 63 | mask = torch.cat([vis_mask, text_mask, vis_mask], dim=1) 64 | vis_pos = torch.cat([vis_pos, torch.zeros_like(text_features), vis_pos], dim=0) 65 | 66 | # perfrom cross-modality interaction 67 | encoded_feature, frames_cls, videos_cls = self.encoder( 68 | features, 69 | src_key_padding_mask=mask, 70 | pos=vis_pos, 71 | ) 72 | 73 | memory_cache = { 74 | "encoded_feature": encoded_feature, # 75 | "encoded_mask": mask, # batch first 76 | "frames_cls" : frames_cls, # n_frame, d_model 77 | "videos_cls" : videos_cls, # b , d_model 78 | "durations": vis_durations, 79 | "fea_map_size": (H, W) 80 | } 81 | 82 | return memory_cache 83 | 84 | 85 | class SpatialTemporalEncoder(nn.Module): 86 | def __init__(self, cfg, encoder_layer, num_layers, norm=None, return_weights=False): 87 | super().__init__() 88 | self.spatial_layers = _get_clones(encoder_layer, num_layers) 89 | self.temporal_layers = _get_clones(encoder_layer, num_layers) 90 | video_max_len = cfg.INPUT.MAX_VIDEO_LEN 91 | d_model = cfg.MODEL.CG.HIDDEN 92 | self.d_model = d_model 93 | 94 | # The position embedding of global tokens 95 | if cfg.MODEL.CG.USE_LEARN_TIME_EMBED: 96 | self.time_embed = SeqEmbeddingLearned(video_max_len + 1 , d_model) 97 | else: 98 | self.time_embed = SeqEmbeddingSine(video_max_len + 1, d_model) 99 | 100 | # The position embedding of local frame tokens 101 | self.local_pos_embed = nn.Embedding(1, d_model) # the learned pos embed for frame cls token 102 | 103 | # The learnd local and global embedding 104 | self.frame_cls = nn.Embedding(1, d_model) # the frame level local cls token 105 | self.video_cls = nn.Embedding(1, d_model) # the video level global cls token 106 | 107 | self.num_layers = num_layers 108 | self.norm = nn.LayerNorm(d_model) 109 | self.return_weights = return_weights 110 | 111 | def forward( 112 | self, 113 | src, 114 | mask: Optional[Tensor] = None, 115 | src_key_padding_mask: Optional[Tensor] = None, 116 | pos: Optional[Tensor] = None, 117 | ): 118 | 
output = src 119 | 120 | for i_layer, layer in enumerate(self.spatial_layers): 121 | # spatial interaction on each single frame 122 | output = layer( 123 | output, 124 | src_mask=mask, 125 | src_key_padding_mask=src_key_padding_mask, 126 | pos=pos, 127 | ) 128 | 129 | 130 | if self.norm is not None: 131 | output = self.norm(output) 132 | 133 | frame_src = torch.mean(output, dim=0) 134 | video_src = torch.mean(frame_src, dim=0) 135 | return output, frame_src, video_src 136 | 137 | 138 | class TransformerEncoderLayer(nn.Module): 139 | def __init__( 140 | self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu" 141 | ): 142 | super().__init__() 143 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 144 | # Implementation of Feedforward model 145 | self.linear1 = nn.Linear(d_model, dim_feedforward) 146 | self.dropout = nn.Dropout(dropout) 147 | self.linear2 = nn.Linear(dim_feedforward, d_model) 148 | 149 | self.norm1 = nn.LayerNorm(d_model) 150 | self.norm2 = nn.LayerNorm(d_model) 151 | self.dropout1 = nn.Dropout(dropout) 152 | self.dropout2 = nn.Dropout(dropout) 153 | 154 | self.activation = _get_activation_fn(activation) 155 | 156 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 157 | return tensor if pos is None else tensor + pos 158 | 159 | def forward( 160 | self, 161 | src, 162 | src_mask: Optional[Tensor] = None, 163 | src_key_padding_mask: Optional[Tensor] = None, 164 | pos: Optional[Tensor] = None, 165 | ): 166 | q = k = self.with_pos_embed(src, pos) 167 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] 168 | src = src + self.dropout1(src2) 169 | src = self.norm1(src) 170 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 171 | src = src + self.dropout2(src2) 172 | src = self.norm2(src) 173 | return src 174 | 175 | 176 | def _get_clones(module, N): 177 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 178 | 179 | 180 | def _get_activation_fn(activation): 181 | """Return an activation function given a string""" 182 | if activation == "relu": 183 | return F.relu 184 | if activation == "gelu": 185 | return F.gelu 186 | if activation == "glu": 187 | return F.glu 188 | raise RuntimeError(f"activation should be relu/gelu, not {activation}.") 189 | -------------------------------------------------------------------------------- /models/bert_model/bert_module.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import copy 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from easydict import EasyDict as edict 9 | import logging 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | def gelu(x): 14 | """Implementation of the gelu activation function. 15 | For information: OpenAI GPT"s gelu is slightly different (and gives slightly different results): 16 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 17 | Also see https://arxiv.org/abs/1606.08415 18 | """ 19 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 20 | 21 | def gelu_new(x): 22 | """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). 
23 | Also see https://arxiv.org/abs/1606.08415 24 | """ 25 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 26 | 27 | class BertLayerNorm(nn.Module): 28 | def __init__(self, hidden_size, eps=1e-12): 29 | """Construct a layernorm module in the TF style (epsilon inside the square root). 30 | """ 31 | super(BertLayerNorm, self).__init__() 32 | self.weight = nn.Parameter(torch.ones(hidden_size)) 33 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 34 | self.variance_epsilon = eps 35 | 36 | def forward(self, x): 37 | u = x.mean(-1, keepdim=True) 38 | s = (x - u).pow(2).mean(-1, keepdim=True) 39 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 40 | return self.weight * x + self.bias 41 | 42 | 43 | class BertSelfAttention(nn.Module): 44 | def __init__(self, config): 45 | super(BertSelfAttention, self).__init__() 46 | if config.hidden_size % config.num_attention_heads != 0: 47 | raise ValueError( 48 | "The hidden size (%d) is not a multiple of the number of attention " 49 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 50 | self.num_attention_heads = config.num_attention_heads 51 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 52 | self.all_head_size = self.num_attention_heads * self.attention_head_size 53 | 54 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 55 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 56 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 57 | 58 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 59 | 60 | def transpose_for_scores(self, x): 61 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) # (N, L, nh, dh) 62 | x = x.view(*new_x_shape) 63 | return x.permute(0, 2, 1, 3) # (N, nh, L, dh) 64 | 65 | def forward(self, query_states, key_states, value_states, attention_mask=None): 66 | """ 67 | Args: 68 | query_states: (N, Lq, D) 69 | key_states: (N, L, D) 70 | value_states: (N, L, D) 71 | attention_mask: (N, Lq, L) 72 | 73 | Returns: 74 | 75 | """ 76 | # only need to mask the dimension where the softmax (last dim) is applied, as another dim (second last) 77 | # will be ignored in future computation anyway 78 | if attention_mask is not None: 79 | attention_mask = (1 - attention_mask.unsqueeze(1)) * -10000. # (N, 1, Lq, L) 80 | mixed_query_layer = self.query(query_states) 81 | mixed_key_layer = self.key(key_states) 82 | mixed_value_layer = self.value(value_states) 83 | 84 | query_layer = self.transpose_for_scores(mixed_query_layer) # (N, nh, Lq, dh) 85 | key_layer = self.transpose_for_scores(mixed_key_layer) # (N, nh, L, dh) 86 | value_layer = self.transpose_for_scores(mixed_value_layer) # (N, nh, L, dh) 87 | 88 | # Take the dot product between "query" and "key" to get the raw attention scores. 89 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # (N, nh, Lq, L) 90 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 91 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 92 | if attention_mask is not None: 93 | attention_scores = attention_scores + attention_mask 94 | 95 | # Normalize the attention scores to probabilities. 96 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 97 | 98 | # This is actually dropping out entire tokens to attend to, which might 99 | # seem a bit unusual, but is taken from the original Transformer paper. 
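        # Added shape note: attention_probs is (N, nh, Lq, L); after dropout it is
        # multiplied with value_layer (N, nh, L, dh) to give context_layer
        # (N, nh, Lq, dh), which is then permuted and reshaped back to (N, Lq, D).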
100 | attention_probs = self.dropout(attention_probs) 101 | 102 | context_layer = torch.matmul(attention_probs, value_layer) 103 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 104 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 105 | context_layer = context_layer.view(*new_context_layer_shape) 106 | return context_layer 107 | 108 | 109 | class BertSelfOutput(nn.Module): 110 | def __init__(self, config): 111 | super(BertSelfOutput, self).__init__() 112 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 113 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 114 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 115 | 116 | def forward(self, hidden_states, input_tensor): 117 | hidden_states = self.dense(hidden_states) 118 | hidden_states = self.dropout(hidden_states) 119 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 120 | return hidden_states 121 | 122 | class BertIntermediate(nn.Module): 123 | def __init__(self, config): 124 | super(BertIntermediate, self).__init__() 125 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 126 | self.intermediate_act_fn = gelu 127 | 128 | def forward(self, hidden_states): 129 | hidden_states = self.dense(hidden_states) 130 | hidden_states = self.intermediate_act_fn(hidden_states) 131 | return hidden_states 132 | 133 | 134 | class BertOutput(nn.Module): 135 | def __init__(self, config): 136 | super(BertOutput, self).__init__() 137 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 138 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 139 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 140 | 141 | def forward(self, hidden_states, input_tensor): 142 | hidden_states = self.dense(hidden_states) 143 | hidden_states = self.dropout(hidden_states) 144 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 145 | return hidden_states 146 | 147 | 148 | class BertAttention_Cross(nn.Module): 149 | def __init__(self, config): 150 | super(BertAttention_Cross, self).__init__() 151 | self.self = BertSelfAttention(config) 152 | self.output = BertSelfOutput(config) 153 | 154 | def forward(self, q, kv): 155 | self_output = self.self(q, kv, kv) 156 | attention_output = self.output(self_output, q) 157 | return attention_output 158 | 159 | 160 | class BertLayer_Cross(nn.Module): 161 | def __init__(self, config): 162 | super(BertLayer_Cross, self).__init__() 163 | self.config = config 164 | self.attention = BertAttention_Cross(config) 165 | self.hidden_intermediate = BertIntermediate(config) 166 | self.memory_intermediate = BertIntermediate(config) 167 | self.output = BertOutput(config) 168 | 169 | def forward(self, q, kv): 170 | attention_output = self.attention(q, kv) # (N, L, D) 171 | intermediate_output = self.hidden_intermediate(attention_output) # (N, L, D) 172 | layer_output = self.output(intermediate_output, attention_output) # (N, L, D) 173 | return layer_output -------------------------------------------------------------------------------- /utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
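# Usage sketch for the checkpointer defined below (added illustration; the
# argument values and the checkpoint path are assumptions, not taken from the
# training script):
#
#   checkpointer = VSTGCheckpointer(
#       cfg, model, model_ema=model_ema, optimizer=optimizer,
#       save_dir=cfg.OUTPUT_DIR, save_to_disk=is_main_process(),
#       logger=logger, is_train=True,
#   )
#   extra_data = checkpointer.load("pretrained/mdetr.pth")  # resumes from last_checkpoint if present
#   ...
#   checkpointer.save("model_epoch_{:02d}".format(epoch), epoch=epoch)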
2 | import os 3 | from copy import deepcopy 4 | 5 | import torch 6 | from torch.hub import load_state_dict_from_url 7 | from utils.comm import is_main_process 8 | 9 | import ssl 10 | ssl._create_default_https_context = ssl._create_unverified_context 11 | 12 | 13 | model_urls = { 14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 19 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 20 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 21 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 22 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 23 | } 24 | 25 | 26 | class VSTGCheckpointer(object): 27 | def __init__( 28 | self, 29 | cfg, 30 | model, 31 | model_ema=None, 32 | optimizer=None, 33 | save_dir="", 34 | save_to_disk=None, 35 | logger=None, 36 | is_train=True 37 | ): 38 | self.cfg = cfg 39 | self.model = model 40 | self.model_ema = model_ema 41 | self.optimizer = optimizer 42 | self.save_dir = save_dir 43 | self.save_to_disk = save_to_disk 44 | self.logger = logger 45 | self.is_train = is_train 46 | 47 | def save(self, name, **kwargs): 48 | if not self.save_dir: 49 | return 50 | 51 | if not self.save_to_disk: 52 | return 53 | 54 | data = {} 55 | data["model"] = self.model.state_dict() 56 | if self.model_ema is not None: 57 | data["model_ema"] = self.model_ema.state_dict() 58 | if self.optimizer is not None: 59 | data["optimizer"] = self.optimizer.state_dict() 60 | data.update(kwargs) 61 | 62 | save_file = os.path.join(self.save_dir, "{}.pth".format(name)) 63 | self.logger.info("Saving checkpoint to {}".format(save_file)) 64 | torch.save(data, save_file) 65 | 66 | self.tag_last_checkpoint(save_file) 67 | 68 | def load(self, f=None, with_optim=True, load_mapping={}): 69 | if self.has_checkpoint() and self.is_train: 70 | # override argument with existing checkpoint 71 | f = self.get_checkpoint_file() 72 | if not f: 73 | # no checkpoint could be found 74 | self.logger.info("No checkpoint found. 
Initializing model from ImageNet") 75 | return {} 76 | 77 | self.logger.info("Loading checkpoint from {}".format(f)) 78 | checkpoint = self._load_file(f) 79 | self._load_model(checkpoint) 80 | 81 | if with_optim: 82 | if "optimizer" in checkpoint and self.optimizer: 83 | self.logger.info("Loading optimizer from {}".format(f)) 84 | self.optimizer.load_state_dict(checkpoint.pop("optimizer")) 85 | 86 | # return any further checkpoint data 87 | return checkpoint 88 | 89 | def has_checkpoint(self): 90 | save_file = os.path.join(self.save_dir, "last_checkpoint") 91 | return os.path.exists(save_file) 92 | 93 | def get_checkpoint_file(self): 94 | save_file = os.path.join(self.save_dir, "last_checkpoint") 95 | try: 96 | with open(save_file, "r") as f: 97 | last_saved = f.read() 98 | last_saved = last_saved.strip() 99 | except IOError: 100 | # if file doesn't exist, maybe because it has just been 101 | # deleted by a separate process 102 | last_saved = "" 103 | return last_saved 104 | 105 | def tag_last_checkpoint(self, last_filename): 106 | save_file = os.path.join(self.save_dir, "last_checkpoint") 107 | with open(save_file, "w") as f: 108 | f.write(last_filename) 109 | 110 | def _load_file(self, f): 111 | # download url files 112 | if f.startswith("http"): 113 | # if the file is a url path, download it and cache it 114 | self.logger.info("loading checking point from {}".format(f)) 115 | loaded = load_state_dict_from_url(model_urls[self.cfg.MODEL.RESNETS.NAME]) 116 | else: 117 | # load native pytorch checkpoint 118 | loaded = torch.load(f, map_location=torch.device("cpu")) 119 | 120 | return loaded 121 | 122 | def _load_mdetr_weight(self, weight_dict): 123 | load_mapping = {} 124 | current_keys = sorted(list(self.model.state_dict().keys())) 125 | 126 | for cur_key in current_keys: 127 | 128 | if cur_key.startswith('vis_encoder'): 129 | load_mapping[cur_key] = cur_key.replace('vis_encoder', 'backbone') 130 | 131 | if cur_key.startswith('text_encoder'): 132 | module_names = cur_key.split('.') 133 | if 'body' in module_names: 134 | module_names.remove('body') 135 | else: 136 | module_names.remove('text_encoder') 137 | 138 | module_names.insert(0,'transformer') 139 | load_mapping[cur_key] = '.'.join(module_names) 140 | 141 | if cur_key.startswith('input_proj'): 142 | load_mapping[cur_key] = cur_key 143 | 144 | if cur_key.startswith('bbox_embed'): 145 | load_mapping[cur_key] = cur_key 146 | 147 | if cur_key.startswith('ground_encoder'): 148 | # ground_encoder.encoder.spatial_layers 149 | module_names = cur_key.split('.') 150 | if "spatial_layers" in module_names: 151 | module_names.remove("ground_encoder") 152 | module_names.insert(0,'transformer') 153 | module_names.remove("spatial_layers") 154 | module_names.insert(2,'layers') 155 | load_mapping[cur_key] = '.'.join(module_names) 156 | 157 | if cur_key.startswith('ground_decoder'): 158 | module_names = cur_key.split('.') 159 | module_names.remove("ground_decoder") 160 | module_names.insert(0,'transformer') 161 | load_mapping[cur_key] = '.'.join(module_names) 162 | 163 | loaded_dict = {} 164 | for key in load_mapping: 165 | if load_mapping[key] in weight_dict.keys(): 166 | loaded_dict[key] = weight_dict[load_mapping[key]] 167 | 168 | # for key in current_keys: 169 | # if key not in loaded_dict.keys(): 170 | # print(key) 171 | 172 | self.model.load_state_dict(loaded_dict, strict=False) 173 | 174 | def _load_pretrained(self,state_dict): 175 | model_key = 'model' 176 | if "model_ema" in state_dict: 177 | model_key = 'model_ema' 178 | 179 | if 
self.is_train: 180 | # Initialized with the pretrained model weight 181 | self._load_mdetr_weight(state_dict[model_key]) 182 | if 'args' in state_dict.keys(): 183 | state_dict.pop('args') 184 | if 'epoch' in state_dict.keys(): 185 | state_dict.pop('epoch') 186 | if 'optimizer' in state_dict.keys(): 187 | state_dict.pop('optimizer') 188 | else: 189 | # Used For Evaluation and Inference, Load trained Checkpoint 190 | self.model.load_state_dict(state_dict[model_key]) 191 | if (self.cfg.MODEL.EMA) and (self.model_ema is not None): 192 | self.model_ema.load_state_dict(deepcopy(self.model).state_dict()) 193 | 194 | def _load_model(self, checkpoint): 195 | if self.is_train and self.has_checkpoint(): # resume training 196 | self.model.load_state_dict(checkpoint["model"]) 197 | if (self.cfg.MODEL.EMA) and (self.model_ema is not None): 198 | if 'model_ema' not in checkpoint: 199 | self.model_ema.load_state_dict(deepcopy(self.model).state_dict()) 200 | else: 201 | self.model_ema.load_state_dict(checkpoint["model_ema"]) 202 | else: 203 | self._load_pretrained(checkpoint) 204 | if 'model_ema' in checkpoint: 205 | checkpoint.pop('model_ema') 206 | checkpoint.pop('model') -------------------------------------------------------------------------------- /models/map2d_head.py: -------------------------------------------------------------------------------- 1 | import re 2 | import copy 3 | from typing import List, Optional 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import Tensor, nn 7 | 8 | 9 | class Gen2DMap(nn.Module): 10 | def __init__(self, cfg): 11 | super().__init__() 12 | N = cfg.MODEL.TEMPFORMER.MAX_MAP_SIZE 13 | pooling_counts = cfg.MODEL.TEMPFORMER.POOLING_COUNTS 14 | self.map_size = N 15 | mask2d = torch.zeros(N, N, dtype=torch.bool) 16 | mask2d[range(N), range(N)] = 1 17 | 18 | stride, offset = 1, 0 19 | maskij = [] 20 | for c in pooling_counts: 21 | for _ in range(c): 22 | # fill a diagonal line 23 | offset += stride 24 | i, j = range(0, N - offset, stride), range(offset, N, stride) 25 | mask2d[i, j] = 1 26 | maskij.append((i, j)) 27 | stride *= 2 28 | 29 | poolers = [nn.MaxPool1d(2,1) for _ in range(pooling_counts[0])] 30 | for c in pooling_counts[1:]: 31 | poolers.extend( 32 | [nn.MaxPool1d(3,2)] + [nn.MaxPool1d(2,1) for _ in range(c - 1)] 33 | ) 34 | 35 | self.mask2d = mask2d.to("cuda") 36 | self.maskij = maskij 37 | self.poolers = poolers 38 | 39 | def forward(self, x): 40 | """ 41 | input : 42 | tensor : [batch, n_frames, dim] 43 | output : 44 | 2D map : [batch, n_frames, n_frames, dim] 45 | mask : [batch, n_frames, n_frames, 1] 46 | """ 47 | x = x.permute(0,2,1) 48 | b, d_model, n_frames = x.shape 49 | 50 | if n_frames > self.map_size: 51 | x = F.adaptive_avg_pool1d(x, self.map_size) 52 | 53 | x = F.adaptive_max_pool1d(x, self.map_size) 54 | 55 | N = self.map_size 56 | map2d = x.new_zeros(b, d_model, N, N) 57 | map2d[:, :, range(N), range(N)] = x 58 | for pooler, (i, j) in zip(self.poolers, self.maskij): 59 | x = pooler(x) 60 | map2d[:, :, i, j] = x 61 | 62 | return map2d 63 | 64 | 65 | class TempPredictionHead(nn.Module): 66 | """The Temporal Interaction Head""" 67 | 68 | def __init__(self, cfg): 69 | super().__init__() 70 | d_model = cfg.MODEL.TEMPFORMER.HIDDEN 71 | nhead = cfg.MODEL.TEMPFORMER.HEADS 72 | dim_feedforward = cfg.MODEL.TEMPFORMER.FFN_DIM 73 | dropout = cfg.MODEL.TEMPFORMER.DROPOUT 74 | num_layers = cfg.MODEL.TEMPFORMER.TEMP_PRED_LAYERS 75 | activation = "relu" 76 | self.temp_head = cfg.MODEL.TEMPFORMER.TEMP_HEAD 77 | self.map_maker = Gen2DMap(cfg) 
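        # Added note on Gen2DMap (defined above): it builds a 2D-TAN-style sparse
        # moment map in which cell (i, j) represents the clip spanning frames i..j.
        # The main diagonal is filled with the frame features and longer spans are
        # produced by successive max-pooling; mask2d marks the valid cells.
        # Illustrative example with assumed values (not the shipped config): for
        # MAX_MAP_SIZE = 8 and POOLING_COUNTS = [3, 2], three MaxPool1d(2, 1)
        # poolers fill diagonal offsets 1-3, then MaxPool1d(3, 2) and
        # MaxPool1d(2, 1) fill the stride-2 cells at offsets 5 and 7.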
78 | self.mask_2d = self.map_maker.mask2d 79 | 80 | self.encoder = None 81 | if self.temp_head == 'attn': 82 | encoder_layer = TransformerEncoderLayer( 83 | d_model, nhead, dim_feedforward, dropout, activation, self.mask_2d 84 | ) 85 | encoder_norm = None 86 | self.encoder = TransformerEncoder( 87 | encoder_layer, num_layers, encoder_norm 88 | ) 89 | else: 90 | kernel_size = cfg.MODEL.TEMPFORMER.KERNAL_SIZE 91 | num_conv_layers = cfg.MODEL.TEMPFORMER.CONV_LAYERS 92 | self.encoder = TempConvInteraction( 93 | d_model, kernel_size, num_conv_layers, self.mask_2d 94 | ) 95 | 96 | self._reset_parameters() 97 | self.predictor = nn.Conv2d(d_model, 1, 1) 98 | 99 | 100 | def _reset_parameters(self): 101 | for p in self.parameters(): 102 | if p.dim() > 1: 103 | nn.init.xavier_uniform_(p) 104 | 105 | def forward(self, x): 106 | """ 107 | x : [layers, b, len, d_model] 108 | """ 109 | n_layers, b , t, d_model = x.shape 110 | x = x.view(-1, t, d_model) 111 | map2d = self.map_maker(x) # n_layers * b, d_model, t, t 112 | 113 | # the segment level interaction 114 | if self.temp_head == 'attn': 115 | for i_layer in range(len(map2d)): 116 | map2d[i_layer] = self.encoder(map2d[i_layer]) 117 | else: 118 | map2d = self.encoder(map2d) 119 | 120 | scores2d = self.predictor(map2d).squeeze_() # n_layers * b, t, t 121 | _, N, N = scores2d.shape 122 | scores2d = scores2d.view(n_layers, b, N, N) 123 | 124 | if self.training: 125 | return scores2d 126 | else: 127 | return scores2d.sigmoid_() * self.mask_2d 128 | 129 | 130 | class TransformerEncoder(nn.Module): 131 | def __init__(self, encoder_layer, num_layers, norm=None): 132 | super().__init__() 133 | self.layers = _get_clones(encoder_layer, num_layers) 134 | self.num_layers = num_layers 135 | self.norm = norm 136 | 137 | def forward(self, src, pos=None): 138 | output = src 139 | for layer in self.layers: 140 | output = layer( 141 | output, 142 | pos=pos, 143 | ) 144 | 145 | if self.norm is not None: 146 | output = self.norm(output) 147 | 148 | return output 149 | 150 | 151 | class TransformerEncoderLayer(nn.Module): 152 | def __init__( 153 | self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", mask2d=None 154 | ): 155 | super().__init__() 156 | self.self_attn_row = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 157 | self.self_attn_col = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 158 | # Implementation of Feedforward model 159 | self.linear1 = nn.Linear(d_model, dim_feedforward) 160 | self.dropout = nn.Dropout(dropout) 161 | self.linear2 = nn.Linear(dim_feedforward, d_model) 162 | 163 | self.norm1 = nn.LayerNorm(d_model) 164 | self.norm2 = nn.LayerNorm(d_model) 165 | self.dropout1 = nn.Dropout(dropout) 166 | self.dropout2 = nn.Dropout(dropout) 167 | 168 | self.mask2d = mask2d 169 | self.activation = _get_activation_fn(activation) 170 | 171 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 172 | return tensor if pos is None else tensor + pos 173 | 174 | def forward( 175 | self, 176 | src, 177 | pos: Optional[Tensor] = None, 178 | ): 179 | # from d_modelxNxN to NxNxd_model 180 | src = src.permute(1,2,0) 181 | mask2d = self.mask2d 182 | # row self attention 183 | q = k = self.with_pos_embed(src, pos) 184 | src2, _ = self.self_attn_row( 185 | q, k, value=src, attn_mask=None, key_padding_mask=mask2d 186 | ) 187 | 188 | # column self attention 189 | src2 = src2.permute(1,0,2) 190 | mask2d = mask2d.permute(1,0) 191 | q = k = self.with_pos_embed(src2, pos) 192 | src2, _ = self.self_attn_col( 193 | q, k, value=src2, 
attn_mask=None, key_padding_mask=mask2d 194 | ) 195 | src2 = src2.permute(1,0,2) 196 | mask2d = mask2d.permute(1,0) 197 | 198 | src = src + self.dropout1(src2) 199 | src = self.norm1(src) 200 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 201 | src = src + self.dropout2(src2) 202 | src = self.norm2(src) 203 | 204 | src = src.permute(2,0,1) 205 | return src 206 | 207 | def _get_clones(module, N): 208 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 209 | 210 | 211 | def _get_activation_fn(activation): 212 | """Return an activation function given a string""" 213 | if activation == "relu": 214 | return F.relu 215 | if activation == "gelu": 216 | return F.gelu 217 | if activation == "glu": 218 | return F.glu 219 | raise RuntimeError(f"activation should be relu/gelu, not {activation}.") 220 | 221 | def mask2weight(mask2d, mask_kernel, padding=0): 222 | # from the feat2d.py,we can know the mask2d is 4-d 223 | weight = torch.conv2d(mask2d[None,None,:,:].float(), 224 | mask_kernel, padding=padding)[0, 0] 225 | weight[weight > 0] = 1 / weight[weight > 0] 226 | return weight 227 | 228 | class TempConvInteraction(nn.Module): 229 | def __init__(self, hidden_size, k, num_stack_layers, mask2d): 230 | super(TempConvInteraction, self).__init__() 231 | 232 | # Padding to ensure the dimension of the output map2d 233 | mask_kernel = torch.ones(1,1,k,k).to(mask2d.device) 234 | first_padding = (k - 1) * num_stack_layers // 2 235 | 236 | self.weights = [ 237 | mask2weight(mask2d, mask_kernel, padding=first_padding) 238 | ] 239 | self.convs = nn.ModuleList( 240 | [nn.Conv2d(hidden_size, hidden_size, k, padding=first_padding)] 241 | ) 242 | 243 | for _ in range(num_stack_layers - 1): 244 | self.weights.append(mask2weight(self.weights[-1] > 0, mask_kernel)) 245 | self.convs.append(nn.Conv2d(hidden_size, hidden_size, k)) 246 | 247 | def forward(self, x): 248 | for conv, weight in zip(self.convs, self.weights): 249 | x = conv(x).relu() * weight 250 | return x 251 | 252 | 253 | if __name__ == "__main__": 254 | model = Gen2DMap(64, [15, 8, 8]) -------------------------------------------------------------------------------- /datasets/evaluation/hcstvg_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, List 3 | 4 | import numpy as np 5 | from utils.comm import is_main_process, all_gather 6 | 7 | import torch 8 | from functools import reduce 9 | from utils.box_utils import np_box_iou 10 | import json 11 | 12 | def save_json(path, data): 13 | with open(path, "w") as f: 14 | return json.dump(data, f) 15 | 16 | class HCSTVGiouEvaluator: 17 | def __init__( 18 | self, 19 | vidstg_path: str, 20 | subset: str = "test", 21 | iou_thresholds: list = None, 22 | ): 23 | """ 24 | :param vidstg_path: path to VidSTG annotations 25 | :param subset: train, val or test 26 | :param iou_thresholds: IoU thresholds for the vIoU metrics 27 | """ 28 | assert subset in ["train", "test"], f"Wrong HCSTVG subset {subset}" 29 | 30 | gt_data = [] 31 | cache_dir = os.path.join(vidstg_path, 'data_cache') 32 | dataset_cache = os.path.join(cache_dir, f'hcstvg-{subset}-anno.cache') 33 | gt_data = torch.load(dataset_cache) 34 | 35 | self.vid2steds = {} # map video_id to [start, end] of the GT tube 36 | self.vid2box = {} # map video to bbox 37 | self.vid2names = {} 38 | self.vid2sents = {} 39 | 40 | for data_item in gt_data: 41 | video_id = data_item['item_id'] 42 | temp_gt = data_item['gt_temp_bound'] 43 | self.vid2names[video_id] = 
data_item['vid'] 44 | self.vid2sents[video_id] = data_item['description'] 45 | box_dict = data_item['bboxs'] 46 | self.vid2box[video_id]={key : [box_dict[key]] for key in box_dict} 47 | self.vid2steds[video_id] = temp_gt 48 | 49 | self.iou_thresholds = iou_thresholds 50 | 51 | def evaluate(self, predictions: List[Dict], video_predictions: List[Dict]): 52 | vid_metrics = {} 53 | for video_id, video_pred in video_predictions.items(): 54 | if video_id in vid_metrics: 55 | print(f"Warning, multiple predictions found for video {video_id}") 56 | continue 57 | 58 | gt_sted = self.vid2steds[video_id] 59 | pred_sted = video_pred["sted"] 60 | 61 | # compute temporal iou 62 | max_start = max(gt_sted[0], pred_sted[0]) 63 | min_end = min(gt_sted[1], pred_sted[1]) 64 | min_start = min(gt_sted[0], pred_sted[0]) 65 | max_end = max(gt_sted[1], pred_sted[1]) 66 | if min_end <= max_start: 67 | tiou = 0 68 | else: 69 | intersection = min_end - max_start 70 | gt_span = gt_sted[1] - gt_sted[0] 71 | pred_span = pred_sted[1] - pred_sted[0] 72 | union = gt_span + pred_span - intersection 73 | tiou = intersection / union 74 | 75 | # compute viou and gt_viou 76 | vid_metrics[video_id] = { 77 | "gt_sted": gt_sted, 78 | "pred_sted": pred_sted, 79 | "tiou": tiou, 80 | "img_metrics": {}, 81 | } 82 | 83 | union_predgt = set([ 84 | frame_id for frame_id in range(min_start, max_end) 85 | ]) 86 | inter_predgt = set( 87 | [frame_id for frame_id in range(max_start, min_end)] 88 | ) 89 | 90 | viou = 0 91 | gt_viou = 0 92 | prediction = predictions[video_id] 93 | 94 | for fid in self.vid2box[video_id].keys(): # iterate on all frames of the annotated moment to update GT metrics 95 | if fid not in prediction: 96 | # raise RuntimeError(f"No prediction for frame {fid}") 97 | # print(self.vid2box[video_id].keys(), fid) 98 | continue 99 | pred_boxes = prediction[fid] 100 | gt_boxes = self.vid2box[video_id][fid] 101 | iou = np_box_iou(np.array(pred_boxes), np.array(gt_boxes))[0][0] 102 | if fid in inter_predgt: 103 | viou += iou 104 | gt_viou += iou 105 | 106 | viou = viou / max(len(union_predgt), 1) 107 | vid_metrics[video_id]["viou"] = viou 108 | recalls = {thresh: 0 for thresh in self.iou_thresholds} 109 | for thresh in self.iou_thresholds: 110 | if viou > thresh: 111 | recalls[thresh] += 1 112 | vid_metrics[video_id].update( 113 | { 114 | f"viou@{thresh}": recalls[thresh] 115 | for thresh in self.iou_thresholds 116 | } 117 | ) 118 | 119 | # compute gt_viou@R 120 | gt_viou = gt_viou / max(len(self.vid2box[video_id]), 1) 121 | vid_metrics[video_id]["gt_viou"] = gt_viou 122 | gt_recalls = {thresh: 0 for thresh in self.iou_thresholds} 123 | for thresh in self.iou_thresholds: 124 | if gt_viou > thresh: 125 | gt_recalls[thresh] += 1 126 | vid_metrics[video_id].update( 127 | { 128 | f"gt_viou@{thresh}": gt_recalls[thresh] 129 | for thresh in self.iou_thresholds 130 | } 131 | ) 132 | 133 | return vid_metrics, self.vid2names, self.vid2sents 134 | 135 | 136 | class HCSTVGEvaluator(object): 137 | def __init__( 138 | self, 139 | logger, 140 | vidstg_path, 141 | subset, 142 | iou_thresholds, 143 | save_pred=False, 144 | save_dir=None 145 | ): 146 | """ 147 | :param vidstg_path: path to VidSTG annotations 148 | :param subset: train, val or test 149 | :param iou_thresholds: IoU thresholds for the vIoU metrics 150 | :param save_pred: whether to save predictions in the output of summarize 151 | """ 152 | self.evaluator = HCSTVGiouEvaluator( 153 | vidstg_path, 154 | subset=subset, 155 | iou_thresholds=iou_thresholds, 156 | ) 157 | self.predictions 
= {} 158 | self.video_predictions = {} 159 | self.video_cross_attn = {} 160 | self.results = None 161 | self.iou_thresholds = iou_thresholds 162 | self.save_pred = save_pred 163 | self.save_dir = save_dir 164 | self.logger = logger 165 | 166 | self.tsa_weights = {} 167 | self.text_weights = {} 168 | self.spatial_weights = {} 169 | self.pred_sted = {} 170 | 171 | def accumulate(self): 172 | pass 173 | 174 | def update(self, predictions): 175 | self.predictions.update(predictions) 176 | 177 | def update_cross_attn(self, cross_weights): 178 | self.video_cross_attn.update(cross_weights) 179 | 180 | def video_update(self, video_predictions): 181 | self.video_predictions.update(video_predictions) 182 | 183 | def synchronize_between_processes(self): 184 | all_predictions = all_gather(self.predictions) 185 | self.predictions = reduce(lambda a, b: a.update(b) or a, all_predictions, {}) 186 | all_video_predictions = all_gather(self.video_predictions) 187 | self.video_predictions = reduce(lambda a, b: a.update(b) or a, all_video_predictions, {}) 188 | 189 | def summarize(self): 190 | if is_main_process(): 191 | self.logger.info("####### Start Calculating the metrics ########") 192 | self.results, vid2names, vid2sents = self.evaluator.evaluate( 193 | self.predictions, self.video_predictions 194 | ) 195 | 196 | metrics = {"gt_viou": 0} 197 | metrics.update({"tiou": 0, "viou": 0}) 198 | for thresh in self.iou_thresholds: # init metrics 199 | metrics[f"viou@{thresh}"] = 0 200 | metrics[f"gt_viou@{thresh}"] = 0 201 | counter = 0 202 | result_str = '' 203 | result_str += '\n' + '=' * 100 + '\n' 204 | for x in self.results.values(): # sum results 205 | metrics["tiou"] += x["tiou"] 206 | metrics["viou"] += x["viou"] 207 | metrics["gt_viou"] += x["gt_viou"] 208 | for thresh in self.iou_thresholds: 209 | metrics[f"viou@{thresh}"] += x[f"viou@{thresh}"] 210 | metrics[f"gt_viou@{thresh}"] += x[f"gt_viou@{thresh}"] 211 | counter += 1 212 | 213 | for key in metrics: # average results 214 | metrics[key] = metrics[key] / counter 215 | result_str += f"{key}: {metrics[key]:.4f}" + '\n' 216 | 217 | result_str += '=' * 100 + '\n' 218 | self.logger.info(result_str) 219 | 220 | out = {f"{name}": metrics[name] for name in metrics} 221 | 222 | if self.save_pred: 223 | out["predictions"] = self.predictions 224 | out["gt"] = self.evaluator.vid2box 225 | out["video_predictions"] = self.video_predictions 226 | out["vid_metrics"] = self.results 227 | out['vid2names'] = vid2names 228 | out['vid2sents'] = vid2sents 229 | res_path = os.path.join(self.save_dir, 'test_results.json') 230 | save_json(res_path, out) 231 | 232 | return out 233 | 234 | return None 235 | -------------------------------------------------------------------------------- /utils/bounding_box.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # transpose 4 | FLIP_LEFT_RIGHT = 0 5 | FLIP_TOP_BOTTOM = 1 6 | 7 | 8 | class BoxList(object): 9 | """ 10 | Usage: 11 | This class represents a set of bounding boxes. 12 | The bounding boxes are represented as a Nx4 Tensor. 13 | In order to uniquely determine the bounding boxes with respect 14 | to an image, we also store the corresponding image dimensions. 
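    Illustrative example (added for clarity):

        boxes = BoxList(torch.tensor([[0., 0., 10., 10.]]), (32, 32), mode="xyxy")
        boxes = boxes.convert("xywh")   # bbox becomes [[5., 5., 10., 10.]]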
15 | """ 16 | 17 | def __init__(self, bbox, image_size, mode="xyxy"): 18 | device = bbox.device if isinstance(bbox, torch.Tensor) else torch.device("cpu") 19 | bbox = torch.as_tensor(bbox, dtype=torch.float32, device=device) 20 | if bbox.ndimension() != 2: 21 | raise ValueError( 22 | "bbox should have 2 dimensions, got {}".format(bbox.ndimension()) 23 | ) 24 | if bbox.size(-1) != 4: 25 | raise ValueError( 26 | "last dimension of bbox should have a " 27 | "size of 4, got {}".format(bbox.size(-1)) 28 | ) 29 | if mode not in ("xyxy", "xywh"): 30 | raise ValueError("mode should be 'xyxy' or 'xywh'") 31 | 32 | self.bbox = bbox 33 | self.size = image_size # (image_width, image_height) 34 | self.mode = mode 35 | 36 | def convert(self, mode): 37 | """ 38 | Args: 39 | mode : xyxy xywh 40 | """ 41 | if mode not in ("xyxy", "xywh"): 42 | raise ValueError("mode should be 'xyxy' or 'xywh'") 43 | if mode == self.mode: 44 | return self 45 | # we only have two modes, so don't need to check self.mode 46 | xmin, ymin, xmax, ymax = self._split_into_xyxy() 47 | if mode == "xyxy": 48 | bbox = torch.cat((xmin, ymin, xmax, ymax), dim=-1) 49 | bbox = BoxList(bbox, self.size, mode=mode) 50 | else: 51 | bbox = torch.cat( 52 | ((xmin + xmax) / 2, (ymin + ymax) / 2, (xmax - xmin), (ymax - ymin)), dim=-1 53 | ) 54 | bbox = BoxList(bbox, self.size, mode=mode) 55 | return bbox 56 | 57 | def _split_into_xyxy(self): 58 | if self.mode == "xyxy": 59 | xmin, ymin, xmax, ymax = self.bbox.split(1, dim=-1) 60 | return xmin, ymin, xmax, ymax 61 | elif self.mode == "xywh": 62 | xc, yc, w, h = self.bbox.split(1, dim=-1) 63 | return ( 64 | xc - 0.5 * w, 65 | yc - 0.5 * h, 66 | xc + 0.5 * w, 67 | yc + 0.5 * h 68 | ) 69 | else: 70 | raise RuntimeError("Should not be here") 71 | 72 | def shift(self, padded_size, left : int, top : int): 73 | """ 74 | Returns a shifted copy of this bounding box 75 | params: 76 | left : xshift, top : yshift 77 | """ 78 | xmin, ymin, xmax, ymax = self._split_into_xyxy() 79 | shifted_xmin, shifted_xmax = xmin + left, xmax + left 80 | shifted_ymin, shifted_ymax = ymin + top, ymax + top 81 | shifted_box = torch.cat( 82 | (shifted_xmin, shifted_ymin, shifted_xmax, shifted_ymax), dim=-1 83 | ) 84 | bbox = BoxList(shifted_box, padded_size, mode="xyxy") 85 | return bbox.convert(self.mode) 86 | 87 | 88 | def resize(self, size, *args, **kwargs): 89 | """ 90 | Returns a resized copy of this bounding box 91 | 92 | :param size: The requested size in pixels, as a 2-tuple: 93 | (width, height). 
94 | """ 95 | 96 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size)) 97 | if ratios[0] == ratios[1]: 98 | ratio = ratios[0] 99 | scaled_box = self.bbox * ratio 100 | bbox = BoxList(scaled_box, size, mode=self.mode) 101 | return bbox 102 | 103 | ratio_width, ratio_height = ratios 104 | xmin, ymin, xmax, ymax = self._split_into_xyxy() 105 | scaled_xmin = xmin * ratio_width 106 | scaled_xmax = xmax * ratio_width 107 | scaled_ymin = ymin * ratio_height 108 | scaled_ymax = ymax * ratio_height 109 | scaled_box = torch.cat( 110 | (scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1 111 | ) 112 | bbox = BoxList(scaled_box, size, mode="xyxy") 113 | 114 | return bbox.convert(self.mode) 115 | 116 | def transpose(self, method): 117 | """ 118 | Transpose bounding box (flip or rotate in 90 degree steps) 119 | :param method: One of :py:attr:`PIL.Image.FLIP_LEFT_RIGHT`, 120 | :py:attr:`PIL.Image.FLIP_TOP_BOTTOM`, :py:attr:`PIL.Image.ROTATE_90`, 121 | :py:attr:`PIL.Image.ROTATE_180`, :py:attr:`PIL.Image.ROTATE_270`, 122 | :py:attr:`PIL.Image.TRANSPOSE` or :py:attr:`PIL.Image.TRANSVERSE`. 123 | """ 124 | if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): 125 | raise NotImplementedError( 126 | "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented" 127 | ) 128 | 129 | image_width, image_height = self.size 130 | xmin, ymin, xmax, ymax = self._split_into_xyxy() 131 | if method == FLIP_LEFT_RIGHT: 132 | transposed_xmin = image_width - xmax 133 | transposed_xmax = image_width - xmin 134 | transposed_ymin = ymin 135 | transposed_ymax = ymax 136 | elif method == FLIP_TOP_BOTTOM: 137 | transposed_xmin = xmin 138 | transposed_xmax = xmax 139 | transposed_ymin = image_height - ymax 140 | transposed_ymax = image_height - ymin 141 | 142 | transposed_boxes = torch.cat( 143 | (transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1 144 | ) 145 | bbox = BoxList(transposed_boxes, self.size, mode="xyxy") 146 | return bbox.convert(self.mode) 147 | 148 | def check_crop_valid(self, region): 149 | """ 150 | box : [x_min, y_min, w, h] 151 | """ 152 | rymin, rxmin, h, w = region 153 | xmin, ymin, xmax, ymax = self._split_into_xyxy() 154 | cropped_xmin = (xmin - rxmin).clamp(min=0, max=w) 155 | cropped_ymin = (ymin - rymin).clamp(min=0, max=h) 156 | cropped_xmax = (xmax - rxmin).clamp(min=0, max=w) 157 | cropped_ymax = (ymax - rymin).clamp(min=0, max=h) 158 | 159 | valid = not any((cropped_xmin == cropped_xmax) | (cropped_ymin == cropped_ymax)) 160 | 161 | return valid 162 | 163 | def crop(self, region): 164 | """ 165 | Cropss a rectangular region from this bounding box. The box is a 166 | 4-tuple defining the left, upper, right, and lower pixel 167 | coordinate. 
168 | """ 169 | rymin, rxmin, h, w = region 170 | xmin, ymin, xmax, ymax = self._split_into_xyxy() 171 | cropped_xmin = (xmin - rxmin).clamp(min=0, max=w) 172 | cropped_ymin = (ymin - rymin).clamp(min=0, max=h) 173 | cropped_xmax = (xmax - rxmin).clamp(min=0, max=w) 174 | cropped_ymax = (ymax - rymin).clamp(min=0, max=h) 175 | 176 | cropped_box = torch.cat( 177 | (cropped_xmin, cropped_ymin, cropped_xmax, cropped_ymax), dim=-1 178 | ) 179 | bbox = BoxList(cropped_box, (w, h), mode="xyxy") 180 | return bbox.convert(self.mode) 181 | 182 | def normalize(self): 183 | xmin, ymin, xmax, ymax = self._split_into_xyxy() 184 | image_width, image_height = self.size 185 | xmin = xmin / image_width 186 | ymin = ymin / image_height 187 | xmax = xmax / image_width 188 | ymax = ymax / image_height 189 | normalized_bbox = torch.cat( 190 | (xmin, ymin, xmax, ymax), dim=-1 191 | ) 192 | bbox = BoxList(normalized_bbox, self.size, mode="xyxy") 193 | return bbox.convert("xywh") 194 | 195 | # Tensor-like methods 196 | def to(self, device): 197 | bbox = BoxList(self.bbox.to(device), self.size, self.mode) 198 | return bbox 199 | 200 | def __getitem__(self, item): 201 | bbox = BoxList(self.bbox[item], self.size, self.mode) 202 | return bbox 203 | 204 | def __len__(self): 205 | return self.bbox.shape[0] 206 | 207 | def area(self): 208 | box = self.bbox 209 | if self.mode == "xyxy": 210 | area = (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1]) 211 | elif self.mode == "xywh": 212 | area = box[:, 2] * box[:, 3] 213 | else: 214 | raise RuntimeError("Should not be here") 215 | 216 | return area 217 | 218 | def copy(self): 219 | return BoxList(self.bbox, self.size, self.mode) 220 | 221 | def __repr__(self): 222 | s = self.__class__.__name__ + "(" 223 | s += "num_boxes={}, ".format(len(self)) 224 | s += "image_width={}, ".format(self.size[0]) 225 | s += "image_height={}, ".format(self.size[1]) 226 | s += "mode={})".format(self.mode) 227 | return s 228 | 229 | 230 | if __name__ == "__main__": 231 | bbox = BoxList([[0, 0, 10, 10], [0, 0, 5, 5]], (10, 10)) 232 | s_bbox = bbox.resize((5, 5)) 233 | print(s_bbox) 234 | print(s_bbox.bbox) 235 | 236 | t_bbox = bbox.transpose(0) 237 | print(t_bbox) 238 | print(t_bbox.bbox) 239 | -------------------------------------------------------------------------------- /engine/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from bisect import bisect_right 3 | import torch 4 | from torch.optim import Optimizer 5 | 6 | 7 | # separating MultiStepLR with WarmupLR 8 | # but the current LRScheduler design doesn't allow it 9 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 10 | def __init__( 11 | self, 12 | optimizer, 13 | milestones, 14 | gamma=0.1, 15 | warmup_factor=1.0 / 3, 16 | warmup_iters=500, 17 | warmup_method="linear", 18 | last_epoch=-1, 19 | ): 20 | if not list(milestones) == sorted(milestones): 21 | raise ValueError( 22 | "Milestones should be a list of increasing integers." 
23 | "Got {}".format(milestones), 24 | ) 25 | 26 | if warmup_method not in ("constant", "linear"): 27 | raise ValueError( 28 | "Only 'constant' or 'linear' warmup_method accepted" 29 | "got {}".format(warmup_method) 30 | ) 31 | self.milestones = milestones 32 | self.gamma = gamma 33 | self.warmup_factor = warmup_factor 34 | self.warmup_iters = warmup_iters 35 | self.warmup_method = warmup_method 36 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 37 | 38 | def get_lr(self): 39 | warmup_factor = 1 40 | if self.last_epoch < self.warmup_iters: 41 | if self.warmup_method == "constant": 42 | warmup_factor = self.warmup_factor 43 | elif self.warmup_method == "linear": 44 | alpha = float(self.last_epoch) / self.warmup_iters 45 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 46 | return [ 47 | base_lr 48 | * warmup_factor 49 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 50 | for base_lr in self.base_lrs 51 | ] 52 | 53 | 54 | class WarmupPolyLR(torch.optim.lr_scheduler._LRScheduler): 55 | def __init__( 56 | self, 57 | optimizer, 58 | gamma, 59 | max_iter, 60 | warmup_factor=1.0 / 3, 61 | warmup_iters=500, 62 | warmup_method="linear", 63 | last_epoch=-1, 64 | ): 65 | if warmup_method not in ("constant", "linear"): 66 | raise ValueError( 67 | "Only 'constant' or 'linear' warmup_method accepted" 68 | "got {}".format(warmup_method) 69 | ) 70 | self.max_iter = max_iter 71 | self.gamma = gamma # The poly power 72 | self.warmup_factor = warmup_factor 73 | self.warmup_iters = warmup_iters 74 | self.warmup_method = warmup_method 75 | super(WarmupPolyLR, self).__init__(optimizer, last_epoch) 76 | 77 | def get_lr(self): 78 | warmup_factor = 1 79 | if self.last_epoch < self.warmup_iters: 80 | if self.warmup_method == "constant": 81 | warmup_factor = self.warmup_factor 82 | elif self.warmup_method == "linear": 83 | alpha = float(self.last_epoch) / self.warmup_iters 84 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 85 | assert self.last_epoch >= 0 86 | return [ 87 | base_lr 88 | * warmup_factor 89 | * ((1 - self.last_epoch / self.max_iter) ** self.gamma) 90 | for base_lr in self.base_lrs 91 | ] 92 | 93 | 94 | class WarmupReduceLROnPlateau(object): 95 | def __init__( 96 | self, 97 | optimizer, 98 | gamma=0.5, 99 | warmup_factor=1.0 / 3, 100 | warmup_iters=500, 101 | warmup_method="linear", 102 | last_epoch=-1, 103 | patience=2, 104 | threshold=1e-4, 105 | cooldown=1, 106 | logger=None, 107 | ): 108 | if warmup_method not in ("constant", "linear"): 109 | raise ValueError( 110 | "Only 'constant' or 'linear' warmup_method accepted" 111 | "got {}".format(warmup_method) 112 | ) 113 | self.gamma = gamma 114 | self.warmup_factor = warmup_factor 115 | self.warmup_iters = warmup_iters 116 | self.warmup_method = warmup_method 117 | self.patience = patience 118 | self.threshold = threshold 119 | self.cooldown = cooldown 120 | self.stage_count = 0 121 | self.best = -1e12 122 | self.num_bad_epochs = 0 123 | self.under_cooldown = self.cooldown 124 | self.logger = logger 125 | 126 | # The following code is copied from Pytorch=1.2.0 127 | # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html 128 | if not isinstance(optimizer, Optimizer): 129 | raise TypeError('{} is not an Optimizer'.format( 130 | type(optimizer).__name__)) 131 | self.optimizer = optimizer 132 | if last_epoch == -1: 133 | for group in optimizer.param_groups: 134 | group.setdefault('initial_lr', group['lr']) 135 | last_epoch = 0 136 | else: 137 | for i, group in 
enumerate(optimizer.param_groups): 138 | if 'initial_lr' not in group: 139 | raise KeyError("param 'initial_lr' is not specified " 140 | "in param_groups[{}] when resuming an optimizer".format(i)) 141 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 142 | self.last_epoch = last_epoch 143 | 144 | self.step(last_epoch) 145 | 146 | def state_dict(self): 147 | """Returns the state of the scheduler as a :class:`dict`. 148 | 149 | It contains an entry for every variable in self.__dict__ which 150 | is not the optimizer. 151 | """ 152 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 153 | 154 | def load_state_dict(self, state_dict): 155 | """Loads the schedulers state. 156 | 157 | Arguments: 158 | state_dict (dict): scheduler state. Should be an object returned 159 | from a call to :meth:`state_dict`. 160 | """ 161 | self.__dict__.update(state_dict) 162 | 163 | def get_lr(self): 164 | warmup_factor = 1 165 | # during warming up 166 | if self.last_epoch < self.warmup_iters: 167 | if self.warmup_method == "constant": 168 | warmup_factor = self.warmup_factor 169 | elif self.warmup_method == "linear": 170 | alpha = float(self.last_epoch) / self.warmup_iters 171 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 172 | # 173 | return [ 174 | base_lr 175 | * warmup_factor 176 | * self.gamma ** self.stage_count 177 | for base_lr in self.base_lrs 178 | ] 179 | 180 | def step(self, metrics, epoch=None): 181 | if epoch is None: 182 | epoch = self.last_epoch + 1 183 | self.last_epoch = epoch 184 | 185 | # The following part is modified from ReduceLROnPlateau 186 | if metrics is None: 187 | # not conduct validation yet 188 | pass 189 | else: 190 | if float(metrics) > (self.best + self.threshold): 191 | self.best = float(metrics) 192 | self.num_bad_epochs = 0 193 | else: 194 | self.num_bad_epochs += 1 195 | 196 | if self.under_cooldown > 0: 197 | self.under_cooldown -= 1 198 | self.num_bad_epochs = 0 199 | 200 | if self.num_bad_epochs >= self.patience: 201 | if self.logger is not None: 202 | self.logger.info("Trigger Schedule Decay, RL has been reduced by factor {}".format(self.gamma)) 203 | self.stage_count += 1 # this will automatically decay the learning rate 204 | self.under_cooldown = self.cooldown 205 | self.num_bad_epochs = 0 206 | 207 | 208 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 209 | param_group['lr'] = lr 210 | 211 | 212 | def adjust_learning_rate( 213 | cfg, 214 | optimizer, 215 | curr_step: int, 216 | num_training_steps: int 217 | ): 218 | """ 219 | Adjust the lr according to the schedule. 
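    Added summary of the code below: with "multistep_with_warmup", the base and
    visual-backbone groups get a 0.1 step decay at each epoch listed in
    SOLVER.SCHEDULE.DROP_STEP, while the text-encoder and temporal groups get a
    linear warmup followed by a linear decay towards zero; with
    "multistep_with_warmup_all", every group gets the warmup plus the step decay.
    The optimizer is expected to hold exactly four param groups, ordered as
    (BASE_LR, VIS_BACKBONE_LR, TEXT_LR, TEMP_LR).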
220 | """ 221 | num_warmup_steps = round(cfg.SOLVER.WARMUP_PROP * num_training_steps) 222 | iter_per_epoch = round(num_training_steps / cfg.SOLVER.MAX_EPOCH) 223 | now_epoch = curr_step // iter_per_epoch 224 | 225 | drop_step = cfg.SOLVER.SCHEDULE.DROP_STEP 226 | 227 | if cfg.SOLVER.SCHEDULE.TYPE == "multistep_with_warmup": 228 | gamma = 0.1 ** bisect_right(drop_step, now_epoch) 229 | if curr_step < num_warmup_steps: 230 | text_encoder_gamma = float(curr_step) / float(max(1, num_warmup_steps)) 231 | else: 232 | text_encoder_gamma = max( 233 | 0.0, 234 | float(num_training_steps - curr_step) 235 | / float(max(1, num_training_steps - num_warmup_steps)), 236 | ) 237 | temp_decoder_gamma = text_encoder_gamma 238 | elif cfg.SOLVER.SCHEDULE.TYPE == "multistep_with_warmup_all": 239 | if curr_step < num_warmup_steps: 240 | gamma = float(curr_step) / float(max(1, num_warmup_steps)) 241 | else: 242 | gamma = 0.1 ** bisect_right(drop_step, now_epoch) 243 | text_encoder_gamma = gamma 244 | temp_decoder_gamma = text_encoder_gamma 245 | else: 246 | raise ValueError(f"Unsupported Schedule Type : {cfg.SOLVER.SCHEDULE.TYPE}") 247 | 248 | base_lrs = [cfg.SOLVER.BASE_LR, cfg.SOLVER.VIS_BACKBONE_LR, cfg.SOLVER.TEXT_LR, cfg.SOLVER.TEMP_LR] 249 | gammas = [gamma, gamma, text_encoder_gamma, temp_decoder_gamma] 250 | assert len(optimizer.param_groups) == len(base_lrs) 251 | for param_group, lr, gamma_group in zip(optimizer.param_groups, base_lrs, gammas): 252 | param_group["lr"] = lr * gamma_group -------------------------------------------------------------------------------- /datasets/evaluation/vidstg_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, List 3 | 4 | import numpy as np 5 | from utils.comm import is_main_process, all_gather 6 | 7 | import torch 8 | from functools import reduce 9 | from utils.box_utils import np_box_iou 10 | import json 11 | 12 | def save_json(path, data): 13 | with open(path, "w") as f: 14 | return json.dump(data, f) 15 | 16 | class VidSTGiouEvaluator: 17 | def __init__( 18 | self, 19 | vidstg_path: str, 20 | subset: str = "test", 21 | iou_thresholds: list = None, 22 | ): 23 | """ 24 | :param vidstg_path: path to VidSTG annotations 25 | :param subset: train, val or test 26 | :param iou_thresholds: IoU thresholds for the vIoU metrics 27 | """ 28 | assert subset in ["train", "test", "val"], f"Wrong VidSTG subset {subset}" 29 | 30 | gt_data = [] 31 | cache_dir = os.path.join(vidstg_path, 'data_cache') 32 | dataset_cache = os.path.join(cache_dir,f'vidstd-{subset}-anno.cache') 33 | gt_data = torch.load(dataset_cache) 34 | 35 | self.vid2steds = {} # map video_id to [start, end] of the GT tube 36 | self.vid2box = {} # map video to bbox 37 | self.vid2names = {} 38 | self.vid2sents = {} 39 | 40 | for data_item in gt_data: 41 | video_id = data_item['item_id'] 42 | temp_gt = data_item['gt_temp_bound'] 43 | self.vid2names[video_id] = data_item['vid'] 44 | self.vid2sents[video_id] = data_item['description'] 45 | box_dict = data_item['bboxs'] 46 | self.vid2box[video_id]={key : [box_dict[key]] for key in box_dict} 47 | self.vid2steds[video_id] = temp_gt 48 | 49 | self.iou_thresholds = iou_thresholds 50 | 51 | def evaluate(self, predictions: List[Dict], video_predictions: List[Dict], pred_conf: List[Dict]): 52 | vid_metrics = {} 53 | for video_id, video_pred in video_predictions.items(): 54 | if video_id in vid_metrics: 55 | print(f"Warning, multiple predictions found for video {video_id}") 56 | continue 57 | 58 | 
gt_sted = self.vid2steds[video_id] 59 | pred_sted = video_pred["sted"] 60 | qtype = video_pred["qtype"] 61 | 62 | # compute temporal iou 63 | max_start = max(gt_sted[0], pred_sted[0]) 64 | min_end = min(gt_sted[1], pred_sted[1]) 65 | min_start = min(gt_sted[0], pred_sted[0]) 66 | max_end = max(gt_sted[1], pred_sted[1]) 67 | if min_end <= max_start: 68 | tiou = 0 69 | else: 70 | intersection = min_end - max_start 71 | gt_span = gt_sted[1] - gt_sted[0] 72 | pred_span = pred_sted[1] - pred_sted[0] 73 | union = gt_span + pred_span - intersection 74 | tiou = intersection / union 75 | 76 | # compute viou and gt_viou 77 | vid_metrics[video_id] = { 78 | "gt_sted": gt_sted, 79 | "pred_sted": pred_sted, 80 | "tiou": tiou, 81 | "qtype": qtype, 82 | "img_metrics": {}, 83 | } 84 | 85 | union_predgt = set([ 86 | frame_id for frame_id in range(min_start, max_end) 87 | ]) 88 | inter_predgt = set( 89 | [frame_id for frame_id in range(max_start, min_end)] 90 | ) 91 | 92 | viou = 0 93 | gt_viou = 0 94 | prediction = predictions[video_id] 95 | 96 | for fid in self.vid2box[video_id].keys(): # iterate on all frames of the annotated moment to update GT metrics 97 | if fid not in prediction: 98 | raise RuntimeError(f"No prediction for frame {fid}") 99 | else: 100 | pred_boxes = prediction[fid] 101 | gt_boxes = self.vid2box[video_id][fid] 102 | iou = np_box_iou(np.array(pred_boxes), np.array(gt_boxes))[0][0] 103 | if fid in inter_predgt: 104 | viou += iou 105 | gt_viou += iou 106 | 107 | viou = viou / max(len(union_predgt), 1) 108 | vid_metrics[video_id]["viou"] = viou 109 | recalls = {thresh: 0 for thresh in self.iou_thresholds} 110 | for thresh in self.iou_thresholds: 111 | if viou > thresh: 112 | recalls[thresh] += 1 113 | vid_metrics[video_id].update( 114 | { 115 | f"viou@{thresh}": recalls[thresh] 116 | for thresh in self.iou_thresholds 117 | } 118 | ) 119 | 120 | # compute gt_viou@R 121 | gt_viou = gt_viou / max(len(self.vid2box[video_id]), 1) 122 | vid_metrics[video_id]["gt_viou"] = gt_viou 123 | gt_recalls = {thresh: 0 for thresh in self.iou_thresholds} 124 | for thresh in self.iou_thresholds: 125 | if gt_viou > thresh: 126 | gt_recalls[thresh] += 1 127 | vid_metrics[video_id].update( 128 | { 129 | f"gt_viou@{thresh}": gt_recalls[thresh] 130 | for thresh in self.iou_thresholds 131 | } 132 | ) 133 | 134 | return vid_metrics, self.vid2names, self.vid2sents 135 | 136 | 137 | class VidSTGEvaluator(object): 138 | def __init__( 139 | self, 140 | logger, 141 | vidstg_path, 142 | subset, 143 | iou_thresholds, 144 | save_pred=False, 145 | save_dir=None 146 | ): 147 | """ 148 | :param vidstg_path: path to VidSTG annotations 149 | :param subset: train, val or test 150 | :param iou_thresholds: IoU thresholds for the vIoU metrics 151 | :param save_pred: whether to save predictions in the output of summarize 152 | """ 153 | self.evaluator = VidSTGiouEvaluator( 154 | vidstg_path, 155 | subset=subset, 156 | iou_thresholds=iou_thresholds, 157 | ) 158 | self.predictions = {} 159 | self.fake_predictions = {} 160 | self.confs = {} 161 | self.video_predictions = {} 162 | self.video_cross_attn = {} 163 | self.results = None 164 | self.iou_thresholds = iou_thresholds 165 | self.save_pred = save_pred 166 | self.save_dir = save_dir 167 | self.logger = logger 168 | 169 | self.tsa_weights = {} 170 | self.text_weights = {} 171 | self.spatial_weights = {} 172 | self.pred_sted = {} 173 | 174 | def accumulate(self): 175 | pass 176 | 177 | def update(self, predictions): 178 | self.predictions.update(predictions) 179 | 180 | def 
fake_update(self, predictions): 181 | self.fake_predictions.update(predictions) 182 | 183 | def update_conf(self, confs): 184 | self.confs.update(confs) 185 | 186 | def update_cross_attn(self, cross_weights): 187 | self.video_cross_attn.update(cross_weights) 188 | 189 | def video_update(self, video_predictions): 190 | self.video_predictions.update(video_predictions) 191 | 192 | def synchronize_between_processes(self): 193 | all_predictions = all_gather(self.predictions) 194 | self.predictions = reduce(lambda a, b: a.update(b) or a, all_predictions, {}) 195 | all_predictions = all_gather(self.fake_predictions) 196 | self.fake_predictions = reduce(lambda a, b: a.update(b) or a, all_predictions, {}) 197 | all_confs = all_gather(self.confs) 198 | self.confs = reduce(lambda a, b: a.update(b) or a, all_confs, {}) 199 | all_video_predictions = all_gather(self.video_predictions) 200 | self.video_predictions = reduce(lambda a, b: a.update(b) or a, all_video_predictions, {}) 201 | 202 | def summarize(self): 203 | if is_main_process(): 204 | self.logger.info("####### Start Calculating the metrics ########") 205 | self.results, vid2names, vid2sents = self.evaluator.evaluate( 206 | self.predictions, self.video_predictions, self.confs 207 | ) 208 | categories = set(x["qtype"] for x in self.results.values()) 209 | metrics = {} 210 | counter = {} 211 | 212 | for category in categories: # init metrics 213 | metrics[category] = {"gt_viou": 0} 214 | metrics[category].update({"tiou": 0, "viou": 0}) 215 | for thresh in self.iou_thresholds: 216 | metrics[category][f"viou@{thresh}"] = 0 217 | metrics[category][f"gt_viou@{thresh}"] = 0 218 | counter[category] = 0 219 | 220 | for x in self.results.values(): # sum results 221 | qtype = x["qtype"] 222 | metrics[qtype]["tiou"] += x["tiou"] 223 | metrics[qtype]["viou"] += x["viou"] 224 | metrics[qtype]["gt_viou"] += x["gt_viou"] 225 | for thresh in self.iou_thresholds: 226 | metrics[qtype][f"viou@{thresh}"] += x[f"viou@{thresh}"] 227 | metrics[qtype][f"gt_viou@{thresh}"] += x[f"gt_viou@{thresh}"] 228 | counter[qtype] += 1 229 | 230 | result_str = '' 231 | result_str += '\n' + '=' * 100 + '\n' 232 | for category in categories: # average results per category 233 | for key in metrics[qtype]: 234 | metrics[category][key] = metrics[category][key] / counter[category] 235 | result_str += f"{category} {key}: {metrics[category][key]:.4f}" + '\n' 236 | 237 | result_str += '=' * 100 + '\n' 238 | self.logger.info(result_str) 239 | 240 | out = { 241 | f"{qtype}_{name}": metrics[qtype][name] 242 | for qtype in metrics 243 | for name in metrics[qtype] 244 | } 245 | 246 | if self.save_pred: 247 | out["predictions"] = self.predictions 248 | out["fake_predictions"] = self.fake_predictions 249 | out["confs"] = self.confs 250 | out["video_predictions"] = self.video_predictions 251 | out["vid_metrics"] = self.results 252 | out['vid2names'] = vid2names 253 | out['vid2sents'] = vid2sents 254 | res_path = os.path.join(self.save_dir,'test_results.json') 255 | save_json(res_path, out) 256 | 257 | return out 258 | 259 | return None 260 | -------------------------------------------------------------------------------- /models/criterion.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | import torch 3 | import torch.distributed 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | from utils.box_utils import generalized_box_iou, box_cxcywh_to_xyxy, box_iou 8 | from utils.comm import is_dist_avail_and_initialized, 
get_world_size 9 | 10 | 11 | class VideoSTGLoss(nn.Module): 12 | """This class computes the losses for the VideoSTG model. 13 | The process happens in two steps: 14 | 1) compare the predicted boxes with the ground-truth boxes 15 | 2) compare the predicted start/end (sted) distributions with the ground-truth temporal segment 16 | """ 17 | 18 | def __init__(self, cfg, losses): 19 | """Create the criterion. 20 | """ 21 | super().__init__() 22 | self.cfg = cfg 23 | self.losses = losses 24 | self.eos_coef = cfg.SOLVER.EOS_COEF 25 | 26 | def loss_boxes(self, outputs, targets, num_boxes): 27 | """Compute the losses related to the bounding boxes: the L1 regression loss and the GIoU loss. 28 | targets dicts must contain the key "boxs" holding a BoxList whose bbox tensor has dim [nb_target_boxes, 4]. 29 | The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. 30 | """ 31 | assert "pred_boxes" in outputs 32 | 33 | src_boxes = outputs["pred_boxes"] 34 | target_boxes = torch.cat([target["boxs"].bbox for target in targets], dim=0) 35 | loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none") 36 | 37 | losses = {} 38 | losses["loss_bbox"] = loss_bbox.sum() / max(num_boxes, 1) 39 | 40 | loss_giou = 1 - torch.diag( 41 | generalized_box_iou(box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes)) 42 | ) 43 | losses["loss_giou"] = loss_giou.sum() / max(num_boxes, 1) 44 | return losses 45 | 46 | def loss_conf(self, outputs, targets, num_boxes, gt_index): 47 | """Compute the box confidence loss: the predicted confidence of each box kept by gt_index is 48 | supervised with its IoU against the corresponding ground-truth box via binary cross-entropy with logits. 49 | targets dicts must contain the key "boxs", as in loss_boxes. 50 | """ 51 | assert "boxes_conf" in outputs 52 | losses = {} 53 | src_boxes = outputs["pred_boxes"] 54 | target_boxes = torch.cat([target["boxs"].bbox for target in targets], dim=0) 55 | iou, _ = box_iou(box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes)) 56 | iou = torch.diag(iou) 57 | conf = outputs['boxes_conf'][gt_index] 58 | # v1 and v2 59 | losses["loss_conf"] = nn.BCEWithLogitsLoss()(conf, iou) 60 | # vid 61 | # losses["loss_conf"] = F.smooth_l1_loss(conf, iou, reduction="none").sum() / max(num_boxes, 1) 62 | return losses 63 | 64 | def loss_actioness(self, outputs, targets, gt_temp_bound, time_mask=None): 65 | assert "pred_actioness" in outputs 66 | losses = {} 67 | pred_actioness = outputs['pred_actioness'].squeeze(-1) 68 | target_actioness = torch.stack([target["actioness"] for target in targets], dim=0).float() 69 | weight = torch.full(pred_actioness.shape, self.eos_coef, device=pred_actioness.device) 70 | 71 | for i_b in range(len(weight)): 72 | temp_bound = gt_temp_bound[i_b] 73 | weight[i_b][temp_bound[0] : temp_bound[1] + 1] = 1 74 | 75 | loss_actioness = F.binary_cross_entropy_with_logits(pred_actioness, \ 76 | target_actioness, weight=weight, reduction='none') 77 | 78 | loss_actioness = loss_actioness * time_mask 79 | losses["loss_actioness"] = loss_actioness.mean() 80 | return losses 81 | 82 | def loss_sted(self, outputs, num_boxes, gt_temp_bound, positive_map, time_mask=None): 83 | assert "pred_sted" in outputs 84 | sted = outputs["pred_sted"] 85 | losses = {} 86 | 87 | target_start = torch.tensor([x[0] for x in gt_temp_bound], dtype=torch.long).to(sted.device) 88 | target_end = torch.tensor([x[1] for x in gt_temp_bound], dtype=torch.long).to(sted.device) 89 | sted = sted.masked_fill(~time_mask[:, :, None], -1e32) # put very
low probability on the padded positions before softmax 90 | eps = 1e-6 91 | 92 | sigma = self.cfg.SOLVER.SIGMA 93 | start_distrib = ( 94 | -( 95 | ( 96 | torch.arange(sted.shape[1])[None, :].to(sted.device) 97 | - target_start[:, None] 98 | ) 99 | ** 2 100 | ) 101 | / (2 * sigma ** 2) 102 | ).exp() # gaussian target 103 | start_distrib = F.normalize(start_distrib + eps, p=1, dim=1) 104 | pred_start_prob = (sted[:, :, 0]).softmax(1) 105 | loss_start = ( 106 | pred_start_prob * ((pred_start_prob + eps) / start_distrib).log() 107 | ) 108 | loss_start = loss_start * time_mask 109 | end_distrib = ( 110 | -( 111 | ( 112 | torch.arange(sted.shape[1])[None, :].to(sted.device) 113 | - target_end[:, None] 114 | ) 115 | ** 2 116 | ) 117 | / (2 * sigma ** 2) 118 | ).exp() # gaussian target 119 | end_distrib = F.normalize(end_distrib + eps, p=1, dim=1) 120 | pred_end_prob = (sted[:, :, 1]).softmax(1) 121 | loss_end = ( 122 | pred_end_prob * ((pred_end_prob + eps) / end_distrib).log() 123 | ) 124 | loss_end = loss_end * time_mask 125 | loss_sted = loss_start + loss_end 126 | losses["loss_sted"] = loss_sted.mean() 127 | return losses 128 | 129 | def loss_guided_attn( 130 | self, outputs, num_boxes, gt_temp_bound, positive_map, time_mask=None 131 | ): 132 | """Compute guided attention loss 133 | targets dicts must contain the key "weights" containing a tensor of attention matrices of dim [B, T, T] 134 | """ 135 | weights = outputs["weights"] # BxTxT 136 | 137 | positive_map = positive_map + (~time_mask) # the padded positions also have to be taken out 138 | eps = 1e-6 # avoid log(0) and division by 0 139 | 140 | loss = -(1 - weights + eps).log() 141 | loss = loss.masked_fill(positive_map[:, :, None], 0) 142 | nb_neg = (~positive_map).sum(1) + eps 143 | loss = loss.sum(2) / nb_neg[:, None] # sum on the column 144 | loss = loss.sum(1) # mean on the line normalized by the number of negatives 145 | loss = loss.mean() # mean on the batch 146 | 147 | losses = {"loss_guided_attn": loss} 148 | return losses 149 | 150 | def get_loss( 151 | self, loss, outputs, targets, num_boxes, gt_temp_bound, positive_map, time_mask, gt_bbox_slice, **kwargs, 152 | ): 153 | loss_map = { 154 | "boxes": self.loss_boxes, 155 | "sted": self.loss_sted, 156 | "guided_attn": self.loss_guided_attn, 157 | "actioness": self.loss_actioness, 158 | "conf": self.loss_conf 159 | } 160 | assert loss in loss_map, f"do you really want to compute {loss} loss?" 161 | if loss in ["sted", "guided_attn"]: 162 | return loss_map[loss]( 163 | outputs, num_boxes, gt_temp_bound, positive_map, time_mask, **kwargs 164 | ) 165 | if loss == "actioness": 166 | return loss_map[loss](outputs, targets, gt_temp_bound, time_mask, **kwargs) 167 | if loss == "conf": 168 | return loss_map[loss](outputs, targets, num_boxes, gt_bbox_slice) 169 | 170 | return loss_map[loss](outputs, targets, num_boxes, **kwargs) 171 | 172 | def forward(self, outputs, targets, durations): 173 | """This performs the loss computation. 174 | Parameters: 175 | outputs: dict of tensors, see the output specification of the model for the format 176 | targets: list of dicts, such that len(targets) == batch_size. 
177 | The expected keys in each dict depend on the losses applied, see each loss's doc 178 | """ 179 | max_duration = max(durations) 180 | device = outputs["pred_boxes"].device 181 | gt_bbox_slice, gt_temp_bound = [], [] 182 | 183 | for i_dur, (duration, target) in enumerate(zip(durations, targets)): 184 | inter = torch.where(target['actioness'])[0].cpu().numpy().tolist() 185 | gt_temp_bound.append([inter[0],inter[-1]]) 186 | gt_bbox_slice.extend(list(range(i_dur * max_duration + inter[0], i_dur * max_duration + inter[-1] + 1))) 187 | 188 | gt_bbox_slice = torch.LongTensor(gt_bbox_slice).to(device) 189 | outputs["pred_boxes"] = outputs["pred_boxes"][gt_bbox_slice] 190 | 191 | # Compute the average number of target boxes across all nodes, for normalization purposes 192 | num_boxes = sum(len(target['boxs']) for target in targets) 193 | num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device) 194 | if is_dist_avail_and_initialized(): 195 | torch.distributed.all_reduce(num_boxes) 196 | num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() 197 | 198 | # compute the temporal mask, used for guided-attn 199 | b = len(durations) 200 | time_mask = torch.zeros(b, max(durations)).bool().to(device) 201 | for i_dur, duration in enumerate(durations): 202 | time_mask[i_dur, :duration] = True 203 | 204 | positive_map = torch.zeros(time_mask.shape, dtype=torch.bool) 205 | for k, idx in enumerate(gt_temp_bound): 206 | if idx[0] < 0: # empty intersection 207 | continue 208 | positive_map[k][idx[0] : idx[1] + 1].fill_(True) 209 | 210 | positive_map = positive_map.to(time_mask.device) 211 | 212 | # Compute all the requested losses 213 | losses = {} 214 | for loss in self.losses: 215 | losses.update(self.get_loss(loss, outputs, targets, num_boxes, gt_temp_bound, positive_map, time_mask, gt_bbox_slice)) 216 | 217 | # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
218 | if "aux_outputs" in outputs: 219 | for i_aux in range(len(outputs["aux_outputs"])): 220 | outputs["aux_outputs"][i_aux]["pred_boxes"] = outputs["aux_outputs"][i_aux]["pred_boxes"][gt_bbox_slice] 221 | for i, aux_outputs in enumerate(outputs["aux_outputs"]): 222 | for loss in self.losses: 223 | kwargs = {} 224 | l_dict = self.get_loss(loss, aux_outputs, targets, num_boxes, gt_temp_bound, positive_map, time_mask, gt_bbox_slice, **kwargs) 225 | l_dict = {k + f"_{i}": v for k, v in l_dict.items()} 226 | losses.update(l_dict) 227 | 228 | return losses -------------------------------------------------------------------------------- /datasets/hcstvg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from copy import deepcopy 4 | import torch 5 | import random 6 | 7 | from tqdm import tqdm 8 | import torch.utils.data as data 9 | import numpy as np 10 | from PIL import Image 11 | import ffmpeg 12 | 13 | from torchvision.transforms import ToTensor, ToPILImage, Resize 14 | from utils.bounding_box import BoxList 15 | from .data_utils import make_hcstvg_input_clip 16 | 17 | 18 | import cv2 19 | import torch 20 | import numpy as np 21 | 22 | class HCSTVGDataset(data.Dataset): 23 | 24 | def __init__(self, cfg, split, transforms=None) -> None: 25 | super(HCSTVGDataset,self).__init__() 26 | assert split in ['train', 'test'] 27 | self.cfg = cfg.clone() 28 | self.split = split 29 | self.transforms = transforms 30 | 31 | self.data_dir = cfg.DATA_DIR 32 | self.anno_dir = os.path.join(self.data_dir,'annos/hcstvg_v2' if 'hc-stvg2' in self.data_dir else 'annos/hcstvg_v1') 33 | self.sent_file = os.path.join(self.anno_dir, f'{split}.json') # split 34 | self.epsilon = 1e-10 35 | 36 | self.all_gt_data = self.load_data() 37 | self.clean_miss() 38 | self.vocab = None 39 | 40 | if cfg.DATA_TRUNK is not None: 41 | self.all_gt_data = self.all_gt_data[:cfg.DATA_TRUNK] 42 | 43 | def clean_miss(self): 44 | miss_name = '10__Gvp-cj3bmIY.mp4' 45 | for item in self.all_gt_data: 46 | if item['vid'] == miss_name: 47 | self.all_gt_data.remove(item) 48 | break 49 | 50 | miss_name = '1_aMYcLyh9OhU.mkv' 51 | for item in self.all_gt_data: 52 | if item['vid'] == miss_name: 53 | self.all_gt_data.remove(item) 54 | break 55 | 56 | def get_video_info(self,index): 57 | video_info = {} 58 | data_item = self.all_gt_data[index] 59 | video_info['height'] = data_item['height'] 60 | video_info['width'] = data_item['width'] 61 | return video_info 62 | 63 | def load_frames(self, data_item, load_video=True): 64 | video_name = data_item['vid'] 65 | frame_ids = data_item['frame_ids'] 66 | patience = 20 67 | max_rate = 1.4 68 | 69 | if load_video: 70 | video_path = os.path.join(self.data_dir,'v2_video' if 'hc-stvg2' in self.data_dir else 'v1_video', video_name) 71 | h, w = data_item['height'], data_item['width'] 72 | succ_flag = False 73 | for _ in range(patience): 74 | try: 75 | out, _ = ( 76 | ffmpeg 77 | .input(video_path) 78 | .output('pipe:', format='rawvideo', pix_fmt='rgb24') 79 | .run(capture_stdout=True, quiet=True) 80 | ) 81 | frames = np.frombuffer(out, np.uint8).reshape([-1, h, w, 3]) 82 | succ_flag = True 83 | if succ_flag: 84 | break 85 | except Exception: 86 | # print(video_name) 87 | aa = 0 88 | 89 | if not succ_flag: 90 | print("video load wrong", video_path) 91 | frames = np.ones((1000, self.cfg.INPUT.RESOLUTION, int(self.cfg.INPUT.RESOLUTION*max_rate), 3), dtype=np.uint8) 92 | # raise RuntimeError("Load Video Error") 93 | try: 94 | frames = frames[frame_ids] 95 
| except: 96 | print("frame_ids wrong", video_path) 97 | frames = np.ones((1000, self.cfg.INPUT.RESOLUTION, int(self.cfg.INPUT.RESOLUTION*max_rate), 3), dtype=np.uint8) 98 | frames = frames[frame_ids] 99 | 100 | rate = frames.shape[2] / frames.shape[1] 101 | frames = [Resize((self.cfg.INPUT.RESOLUTION, min(int(self.cfg.INPUT.RESOLUTION*rate), int(self.cfg.INPUT.RESOLUTION*max_rate))), antialias=True)(ToTensor()(frame)) for frame in frames] 102 | 103 | frames = torch.stack(frames) 104 | else: 105 | raise NotImplementedError("Not Implement load from frames") 106 | 107 | return frames 108 | 109 | def __getitem__(self, index: int): 110 | """ 111 | Usage: 112 | In training, sample a random clip from video 113 | In testing, chunk the video to a set of clips 114 | """ 115 | video_data = deepcopy(self.all_gt_data[index]) 116 | 117 | data_item = make_hcstvg_input_clip(self.cfg, self.split, video_data) 118 | 119 | frames = self.load_frames(data_item) # T * C * H * W 120 | 121 | # load the sampled gt bounding box 122 | frame_ids = data_item['frame_ids'] 123 | temp_gt = data_item['gt_temp_bound'] 124 | action_idx = np.where(data_item['actioness'])[0] 125 | start_idx, end_idx = action_idx[0], action_idx[-1] 126 | bbox_idx = [frame_ids[idx] - temp_gt[0] for idx in range(start_idx,end_idx + 1)] 127 | bboxs = torch.from_numpy(data_item['bboxs'][bbox_idx]).reshape(-1, 4) 128 | assert bboxs.shape[0] == len(action_idx) 129 | 130 | w, h = data_item['width'], data_item['height'] 131 | bboxs = BoxList(bboxs, (w, h), 'xyxy') 132 | 133 | sentence = data_item['description'] 134 | sentence = sentence.lower() 135 | input_dict = {'frames': frames, 'boxs': bboxs, 'text': sentence, \ 136 | 'actioness' : data_item['actioness']} 137 | 138 | if self.transforms is not None: 139 | input_dict = self.transforms(input_dict) 140 | 141 | targets = { 142 | 'item_id' : data_item['item_id'], 143 | 'frame_ids' : data_item['frame_ids'], 144 | 'actioness' : torch.from_numpy(data_item['actioness']) , 145 | 'start_heatmap' : torch.from_numpy(data_item['start_heatmap']), 146 | 'end_heatmap' : torch.from_numpy(data_item['end_heatmap']), 147 | 'boxs' : input_dict['boxs'], 148 | 'img_size' : input_dict['frames'].shape[2:], 149 | 'ori_size' : (h, w) 150 | } 151 | 152 | return input_dict['frames'], sentence, targets 153 | 154 | def __len__(self) -> int: 155 | return len(self.all_gt_data) 156 | 157 | def load_data(self): 158 | """ 159 | Prepare the Input Data Cache and the evaluation data groundtruth 160 | """ 161 | cache_dir = os.path.join(self.data_dir,'data_cache') 162 | if not os.path.exists(cache_dir): 163 | os.makedirs(cache_dir) 164 | 165 | # Used for Model Input 166 | dataset_cache = os.path.join(cache_dir, f'hcstvg-{self.split}-input.cache') 167 | # Used For Evaluateion 168 | gt_anno_cache = os.path.join(cache_dir, f'hcstvg-{self.split}-anno.cache') 169 | 170 | if os.path.exists(dataset_cache): 171 | data = torch.load(dataset_cache) 172 | return data 173 | 174 | gt_data, gt_anno = [], [] 175 | vstg_anno = self.preprocess(self.sent_file) 176 | 177 | for anno_id in tqdm(vstg_anno): 178 | gt_file = vstg_anno[anno_id] 179 | frame_nums = gt_file['frame_count'] 180 | video_name = gt_file['vid'] 181 | 182 | start_fid = 0 183 | end_fid = frame_nums - 1 184 | temp_gt_begin = max(0, gt_file['tube_start_frame']) 185 | temp_gt_end = min(gt_file['tube_end_frame'], end_fid) 186 | 187 | assert len(gt_file['target_bboxs']) == temp_gt_end - temp_gt_begin + 1 188 | 189 | frame_ids = [] 190 | for frame_id in range(start_fid, end_fid): 191 | 
frame_ids.append(frame_id) 192 | 193 | actioness = np.array([int(fid <= temp_gt_end and fid >= temp_gt_begin) for fid in frame_ids]) 194 | 195 | # prepare the temporal heatmap 196 | action_idx = np.where(actioness)[0] 197 | start_idx, end_idx = action_idx[0], action_idx[-1] 198 | 199 | start_heatmap = np.ones(actioness.shape) * self.epsilon 200 | pesudo_prob = (1 - (start_heatmap.shape[0] - 3) * self.epsilon - 0.5) / 2 201 | 202 | start_heatmap[start_idx] = 0.5 203 | if start_idx > 0: 204 | start_heatmap[start_idx-1] = pesudo_prob 205 | if start_idx < actioness.shape[0] - 1: 206 | start_heatmap[start_idx+1] = pesudo_prob 207 | 208 | end_heatmap = np.ones(actioness.shape) * self.epsilon 209 | end_heatmap[end_idx] = 0.5 210 | if end_idx > 0: 211 | end_heatmap[end_idx-1] = pesudo_prob 212 | if end_idx < actioness.shape[0] - 1: 213 | end_heatmap[end_idx+1] = pesudo_prob 214 | 215 | bbox_array = [] 216 | for idx in range(len(gt_file['target_bboxs'])): 217 | bbox = gt_file['target_bboxs'][idx] 218 | x1, y1, w, h = bbox 219 | bbox_array.append(np.array([x1,y1,min(x1+w, gt_file['width']), min(y1+h, gt_file['height'])])) 220 | assert x1 <= gt_file['width'] and x1 + w <= gt_file['width'] 221 | assert y1 <= gt_file['height'] and y1 + h <= gt_file['height'] 222 | 223 | bbox_array = np.array(bbox_array) 224 | assert bbox_array.shape[0] == temp_gt_end - temp_gt_begin + 1 225 | 226 | gt_bbox_dict = {fid : bbox_array[fid - temp_gt_begin].tolist() \ 227 | for fid in range(temp_gt_begin, temp_gt_end + 1)} 228 | 229 | gt_item = { 230 | 'item_id' : gt_file['id'], 231 | 'vid' : video_name, 232 | 'bboxs' : gt_bbox_dict, 233 | 'description' : gt_file['sentence'], 234 | 'gt_temp_bound' : [temp_gt_begin, temp_gt_end], 235 | 'frame_count' : gt_file['frame_count'] 236 | } 237 | 238 | item = { 239 | 'item_id' : gt_file['id'], 240 | 'vid' : video_name, 241 | 'frame_ids' : frame_ids, 242 | 'width' : gt_file['width'], 243 | 'height' : gt_file['height'], 244 | 'start_heatmap': start_heatmap, 245 | 'end_heatmap': end_heatmap, 246 | 'actioness': actioness, 247 | 'bboxs' : bbox_array, 248 | 'gt_temp_bound' : [temp_gt_begin, temp_gt_end], 249 | 'description' : gt_file['sentence'], 250 | 'object' : 'person', 251 | 'frame_count' : gt_file['frame_count'] 252 | } 253 | 254 | gt_data.append(item) 255 | gt_anno.append(gt_item) 256 | 257 | random.shuffle(gt_data) 258 | torch.save(gt_data, dataset_cache) 259 | torch.save(gt_anno, gt_anno_cache) 260 | return gt_data 261 | 262 | def preprocess(self,anno_file): 263 | """ 264 | preoprocess from the original annotation 265 | """ 266 | pair_cnt = 0 267 | print(f"Prepare {self.split} Data") 268 | 269 | with open(anno_file, 'r') as fr: 270 | hcstvg_anno = json.load(fr) 271 | 272 | proc_hcstvg_anno = {} 273 | for vid in tqdm(hcstvg_anno): 274 | anno = hcstvg_anno[vid] 275 | data_pairs = {} 276 | data_pairs['vid'] = vid 277 | data_pairs['width'] = anno['width'] 278 | data_pairs['height'] = anno['height'] 279 | data_pairs['frame_count'] = anno['img_num'] 280 | data_pairs['tube_start_frame'] = anno['st_frame'] - 1 281 | data_pairs['tube_end_frame'] = data_pairs['tube_start_frame'] + len(anno['bbox']) - 1 282 | data_pairs['tube_start_time'] = anno['st_time'] 283 | data_pairs['tube_end_time'] = anno['ed_time'] 284 | data_pairs['id'] = pair_cnt 285 | data_pairs['sentence'] = anno['caption'] 286 | data_pairs['target_bboxs'] = anno['bbox'] 287 | proc_hcstvg_anno[pair_cnt] = data_pairs 288 | pair_cnt += 1 289 | 290 | print(f'{self.split} pair number : {pair_cnt}') 291 | return proc_hcstvg_anno 292 | 
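Illustrative sketch (not part of hcstvg.py): HCSTVGDataset.load_data builds soft start/end heatmaps by giving the ground-truth boundary frame probability 0.5, splitting most of the remaining mass between its two immediate neighbours, and leaving a small epsilon on every other frame. A minimal standalone version of that scheme is sketched below; the helper name boundary_heatmap and the toy call are assumptions made for illustration.

import numpy as np

def boundary_heatmap(num_frames, boundary_idx, epsilon=1e-10):
    # epsilon everywhere, 0.5 on the boundary frame, the leftover mass split between its neighbours
    heatmap = np.full(num_frames, epsilon)
    neighbour_prob = (1 - (num_frames - 3) * epsilon - 0.5) / 2
    heatmap[boundary_idx] = 0.5
    if boundary_idx > 0:
        heatmap[boundary_idx - 1] = neighbour_prob
    if boundary_idx < num_frames - 1:
        heatmap[boundary_idx + 1] = neighbour_prob
    return heatmap

# e.g. a 6-frame clip whose tube starts at frame 2: ~[eps, 0.25, 0.5, 0.25, eps, eps]
start_heatmap = boundary_heatmap(6, 2)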
--------------------------------------------------------------------------------
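Illustrative sketch (not part of the repository): the start/end ("sted") loss in models/criterion.py above turns the predicted start and end logits into distributions over time with a softmax and pulls each towards an L1-normalized Gaussian centred on the ground-truth boundary frame, i.e. it minimizes KL(pred || Gaussian target). The standalone sketch below reproduces that computation for a single boundary, without the time_mask handling; the function name sted_kl_loss, the default sigma, and the toy tensors are assumptions made for illustration.

import torch
import torch.nn.functional as F

def sted_kl_loss(logits, target_idx, sigma=2.0, eps=1e-6):
    # logits: (B, T) start (or end) scores; target_idx: (B,) ground-truth frame indices
    T = logits.shape[1]
    t = torch.arange(T, device=logits.device)[None, :]                     # (1, T) frame grid
    target = (-((t - target_idx[:, None]) ** 2) / (2 * sigma ** 2)).exp()  # Gaussian bump per sample
    target = F.normalize(target + eps, p=1, dim=1)                         # normalize to a distribution
    pred = logits.softmax(dim=1)                                           # predicted distribution over time
    # KL(pred || target), summed over time, averaged over the batch
    return (pred * ((pred + eps) / target).log()).sum(dim=1).mean()

# toy usage: batch of 2 clips, 10 frames each, ground-truth boundaries at frames 3 and 7
loss = sted_kl_loss(torch.randn(2, 10), torch.tensor([3, 7]))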