├── .gitignore ├── README.md ├── RF_ActionVOS ├── __init__.py ├── actionvos.py ├── configs.md ├── criterion.py ├── inference_actionvos.py ├── main_actionvos.py ├── opts.py ├── referformer.py ├── segmentation.py ├── test_actionvos.sh ├── train_actionvos.sh └── transforms_video_actionvos.py ├── actionvos_metrics.py ├── annotations ├── 00000.png ├── EPIC_100_train.csv └── EPIC_100_validation.csv ├── copy_rf_actionvos_files.py ├── data_prepare_visor.py ├── data_prepare_vost.py ├── data_prepare_vscos.py ├── dataset_visor └── ImageSets │ ├── val_human.json │ └── val_novel.json ├── demo_path ├── ImageSets │ └── expression_file.json └── JPEGImages_Sparse │ └── val │ ├── 00000012_P01_107_put-into_bag:cereal │ ├── frame_0000002521.jpg │ └── frame_0000002559.jpg │ └── 00000223_P02_09_pick-up_spoon │ ├── frame_0000040843.jpg │ └── frame_0000040873.jpg └── figures ├── ActionVOS.png ├── method.png └── weights.png /.gitignore: -------------------------------------------------------------------------------- 1 | annotations/visor_hos_train.json 2 | annotations/visor_hos_val.json 3 | dataset_visor/ImageSets/train.json 4 | dataset_visor/ImageSets/val.json 5 | dataset_visor/ImageSets/train_objects_category.json 6 | dataset_visor/ImageSets/val_objects_category.json 7 | dataset_visor/ImageSets/train_meta_expressions_promptaction.json 8 | dataset_visor/ImageSets/val_meta_expressions_promptaction.json 9 | ReferFormer/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **ActionVOS: Actions as Prompts for Video Object Segmentation** 2 | 3 | Our [paper](https://arxiv.org/abs/2407.07402) is accepted by **ECCV-2024** as [**oral**](https://eccv.ecva.net/virtual/2024/oral/1604) **(2.3%)** presentation! 4 | 5 |
ActionVOS
6 | 7 | **Picture:** *Overview of the proposed ActionVOS setting.* 8 | 9 |
method
10 | 11 | **Picture:** *The proposed method in our paper.* 12 | 13 | --- 14 | 15 | This repository contains the official PyTorch implementation of the following paper: 16 | 17 | > **ActionVOS: Actions as Prompts for Video Object Segmentation**
18 | Liangyang Ouyang, Ruicong Liu, Yifei Huang, Ryosuke Furuta, and Yoichi Sato
19 | > 20 | >**Abstract:** 21 | Delving into the realm of egocentric vision, the advancement of referring video object segmentation (RVOS) stands as pivotal in understanding human activities. However, existing RVOS task primarily relies on static attributes such as object names to segment target objects, posing challenges in distinguishing target objects from background objects and in identifying objects undergoing state changes. To address these problems, this work proposes a novel action-aware RVOS setting called ActionVOS, aiming at segmenting only active objects in egocentric videos using human actions as a key language prompt. This is because human actions precisely describe the behavior of humans, thereby helping to identify the objects truly involved in the interaction and to understand possible state changes. We also build a method tailored to work under this specific setting. Specifically, we develop an action-aware labeling module with an efficient action-guided focal loss. Such designs enable ActionVOS model to prioritize active objects with existing readily-available annotations. Experimental results on VISOR dataset reveal that ActionVOS significantly reduces the mis-segmentation of inactive objects, confirming that actions help the ActionVOS model understand objects' involvement. Further evaluations on VOST and VSCOS datasets show that the novel ActionVOS setting enhances segmentation performance when encountering challenging circumstances involving object state changes. 
22 | 23 | ## Resources 24 | 25 | Material related to our paper is available via the following links: 26 | 27 | - [**Paper**](https://arxiv.org/abs/2407.07402) 28 | - [**Video**](https://youtu.be/dt-zDQKzq1I) 29 | - [VISOR dataset](https://epic-kitchens.github.io/VISOR/) 30 | - [VOST dataset](https://www.vostdataset.org/data.html) 31 | - [VSCOS dataset](https://github.com/venom12138/VSCOS) 32 | - [ReferFormer Model](https://github.com/wjn922/ReferFormer) 33 | 34 | ## Requirements 35 | 36 | * Our experiments were tested with Python 3.8 and PyTorch 1.11.0. 37 | * Our experiments with ReferFormer used 4 V100 GPUs and took 6-12 hours to train 6 epochs on VISOR. 38 | * Check the **Training** instructions for the packages required by ReferFormer (RF). 39 | 40 | ## Playing with ActionVOS 41 | 42 | ### **Data preparation (Pseudo-labeling and Weight-generation)** 43 | 44 | For the videos and masks, please download the VISOR-VOS, VSCOS, and VOST datasets from the links below. We recommend downloading VISOR-VOS first, since we use it for both training and testing. 45 | 46 | - [**VISOR-VOS (28.4GB)**](https://data.bris.ac.uk/data/dataset/2v6cgv1x04ol22qp9rm9x2j6a7) 47 | - [VSCOS (20GB)](https://github.com/venom12138/VSCOS) 48 | - [VOST (50GB)](https://www.vostdataset.org/data.html) 49 | 50 | [Action narration annotations](./annotations/EPIC_100_train.csv) are obtained from [EK-100](https://github.com/epic-kitchens/epic-kitchens-100-annotations). (They are already included in this repository, so you don't need to download them.) 51 | 52 | [Hand-object annotations](./annotations/visor_hos_train.json) are obtained from [VISOR-HOS](https://github.com/epic-kitchens/VISOR-HOS). (Please download them from Google Drive, [link1](https://drive.google.com/file/d/1Op-QtoweJ-2M0nuMqtbBHAsJ4Ep-g6nU/view?usp=sharing) and [link2](https://drive.google.com/file/d/1KkQ-BOC4E0P087D2hyTN9eUxMmNPq_Ot/view?usp=sharing), and put them under /annotations.) 
53 | 54 | Then run data_prepare_visor.py to generate the data, annotations, action-aware pseudo-labels, and action-guided weights for ActionVOS. 55 | 56 | ``` 57 | python data_prepare_visor.py --VISOR_PATH your_visor_epick_path 58 | ``` 59 | 60 | Processing the data takes 1-2 hours. Afterwards, the folder dataset_visor will have the following structure: 61 | 62 | ``` 63 | - dataset_visor 64 | - Annotations_Sparse 65 | - train 66 | - 00000001_xxx 67 | - obj_masks.png 68 | - 00000002_xxx 69 | - val 70 | - JPEGImages_Sparse 71 | - train 72 | - 00000001_xxx 73 | - rgb_frames.jpg 74 | - 00000002_xxx 75 | - val 76 | - Weights_Sparse 77 | - train 78 | - 00000001_xxx 79 | - action-guided-weights.png 80 | - 00000002_xxx 81 | - val (not used) 82 | - ImageSets 83 | - train.json 84 | - val.json 85 | - val_human.json 86 | - val_novel.json 87 | ``` 88 | 89 | There are two special files, val_human.json and val_novel.json. They contain the splits used for the results in our experiments: val_human contains the actions annotated by humans, and val_novel contains actions that are unseen in the validation set. 90 | 91 | ### **How to find action-aware pseudo labels** 92 | 93 | Check [train.json](./dataset_visor/ImageSets/train.json). For each object name in each video, the json file contains a map such as {"name": "food container", "class_id": 21, "handbox": 0, "narration": 1, "positive": 1}. 94 | 95 | handbox = 1 if the object mask intersects a hand-object bounding box. 96 | 97 | narration = 1 if the object name is mentioned in the action narration. 98 | 99 | positive = 1 if the object is a pseudo-positive object. 100 | 101 | Note that the object masks under Annotations_Sparse cover all objects. We combine them with the class labels in our experiments. 102 | 103 | ### **How to find action-guided weights** 104 | 105 | Each image under Weights_Sparse is an action-guided weight map. 106 | 107 | 
weights
108 | 109 | **Picture:** *Action-guided Weights* 110 | 111 | ``` 112 | 3 (yellow) for negative obj mask. 113 | 2 (green) for hand | narration obj mask. 114 | 4 (blue) for hand & narration obj mask. 115 | 1 (red) for other areas 116 | ``` 117 | 118 | ### **Training** 119 | 120 | ActionVOS is an action-aware setting for RVOS, and any RVOS model with an extra class head can be trained for ActionVOS. In our experiments, we take ReferFormer-ResNet101 as the base RVOS model. 121 | 122 | Clone the [ReferFormer](https://github.com/wjn922/ReferFormer) repository and download their [pretrained checkpoints](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/wjn922_connect_hku_hk/EShgDd650nBBsfoNEiUbybcB84Ma5NydxOucISeCrZmzHw?e=YOSszd). 123 | ``` 124 | git clone https://github.com/wjn922/ReferFormer.git 125 | cd ReferFormer 126 | mkdir pretrained_weights 127 | # download the pretrained checkpoint from the link above into pretrained_weights/ 128 | ``` 129 | 130 | Install the necessary packages for ReferFormer. 131 | 132 | ``` 133 | cd ReferFormer 134 | pip install -r requirements.txt 135 | pip install 'git+https://github.com/facebookresearch/fvcore' 136 | pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI' 137 | cd models/ops 138 | python setup.py build install 139 | ``` 140 | 141 | Copy the modified files into the ReferFormer folders. 142 | 143 | ``` 144 | python copy_rf_actionvos_files.py 145 | ``` 146 | 147 | Run the training script. To change the training configuration, check [RF_ActionVOS/configs.md](RF_ActionVOS/configs.md). The following example trains ActionVOS on a single GPU (GPU 0). 
148 | 149 | ``` 150 | cd ReferFormer 151 | bash scripts/train_actionvos.sh actionvos_dirs/r101 pretrained_weights/r101_refytvos_joint.pth 1 0 29500 --backbone resnet101 --expression_file train_meta_expressions_promptaction.json --use_weights --use_positive_cls --actionvos_path ../dataset_visor --epochs 6 --lr_drop 3 5 --save_interval 3 152 | ``` 153 | 154 | After training, the weights are saved to actionvos_dirs/r101/checkpoint.pth. 155 | 156 | ### **Inference** 157 | 158 | For a quick start with ActionVOS models, we provide a trained RF-R101 checkpoint at [this link](https://drive.google.com/file/d/140gfK4GkI5iBSVFqoi_CAfL6d0J39nOW/view?usp=sharing). 159 | 160 | #### **Inference on VISOR** 161 | 162 | ``` 163 | cd ReferFormer 164 | bash scripts/test_actionvos.sh actionvos_dirs/r101 pretrained_weights/actionvos_rf_r101.pth 0 29500 --backbone resnet101 --expression_file val_meta_expressions_promptaction.json --use_positive_cls --pos_cls_thres 0.75 --actionvos_path ../dataset_visor 165 | ``` 166 | 167 | The output masks will be saved in ReferFormer/actionvos_dirs/r101/val. 168 | 169 | #### **Inference on your own videos and prompts** 170 | 171 | Arrange your videos and prompts into an actionvos_path laid out like this: 172 | 173 | ``` 174 | - demo_path 175 | - JPEGImages_Sparse 176 | - val 177 | - video_name 178 | - rgb_frames.jpg 179 | - ImageSets 180 | - expression_file.json 181 | ``` 182 | 183 | Check the [example json file](demo_path/ImageSets/expression_file.json) for the prompt format. 184 | 185 | ``` 186 | cd ReferFormer 187 | bash scripts/test_actionvos.sh actionvos_dirs/demo pretrained_weights/actionvos_rf_r101.pth 0 29500 --backbone resnet101 --expression_file expression_file.json --use_positive_cls --pos_cls_thres 0.75 --actionvos_path ../demo_path 188 | ``` 189 | 190 | The output masks will be saved in ReferFormer/actionvos_dirs/demo/val. 
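
For more than a couple of videos, the prompt file can be generated rather than written by hand. The following is a minimal sketch, not part of the repository. It assumes the inference script reads the same top-level shape that the training loader in `RF_ActionVOS/actionvos.py` reads: a `videos` map whose entries hold `frames` (frame names without the `.jpg` extension) and `expressions` (keyed by string ids). The helper name `write_expression_file` and its `prompts` argument are hypothetical, and the keys inside each expression entry should be copied from `demo_path/ImageSets/expression_file.json` rather than taken from this sketch.

```python
import json
from pathlib import Path


def write_expression_file(demo_path, prompts):
    """Write ImageSets/expression_file.json for the videos under JPEGImages_Sparse/val.

    prompts (hypothetical input format) maps video_name -> list of expression
    dicts, e.g. {"exp": "pick up spoon"}. The dicts are written unchanged, so
    include whatever other keys the example json file uses.
    """
    demo_path = Path(demo_path)
    videos = {}
    for video_dir in sorted((demo_path / "JPEGImages_Sparse" / "val").iterdir()):
        if not video_dir.is_dir():
            continue
        # Assumption: frames are listed by name without ".jpg", as actionvos.py
        # appends the extension itself when building image paths.
        frames = sorted(p.stem for p in video_dir.glob("*.jpg"))
        # Expression ids are string keys "0", "1", ... in insertion order.
        expressions = {
            str(i): dict(entry)
            for i, entry in enumerate(prompts.get(video_dir.name, []))
        }
        videos[video_dir.name] = {"frames": frames, "expressions": expressions}

    out_file = demo_path / "ImageSets" / "expression_file.json"
    out_file.parent.mkdir(parents=True, exist_ok=True)
    out_file.write_text(json.dumps({"videos": videos}, indent=2))
    return videos
```

Calling `write_expression_file("demo_path", {"video_name": [{"exp": "pick up spoon"}]})` writes one expression for `video_name`. Comparing the result with the shipped example file before running inference is advisable, since the exact keys expected at inference time are not confirmed here.
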
191 | 192 | #### **Evaluation Metrics** 193 | 194 | We use six metrics (p-mIoU, n-mIoU, p-cIoU, n-cIoU, gIoU, and accuracy) to evaluate ActionVOS performance on the VISOR val_human split. 195 | 196 | ``` 197 | python actionvos_metrics.py --pred_path ReferFormer/actionvos_dirs/r101/val --gt_path dataset_visor/Annotations_Sparse/val --split_json dataset_visor/ImageSets/val_human.json 198 | ``` 199 | 200 | If you generated the object masks correctly with [this checkpoint](https://drive.google.com/file/d/140gfK4GkI5iBSVFqoi_CAfL6d0J39nOW/view?usp=sharing), you should get the results below: 201 | 202 | | Model | Split | p-mIoU | n-mIoU | p-cIoU | n-cIoU | gIoU | Acc | 203 | |----------|-----------|-----------|---------|-----------|-----------|---------|-----------| 204 | | [RF_R101](https://drive.google.com/file/d/140gfK4GkI5iBSVFqoi_CAfL6d0J39nOW/view?usp=sharing) | val_human* | 66.1 | 18.6 | 72.7 | 32.2 | 71.2 | 83.0 | 205 | 206 | \* Note that the val_human split here uses only 294 videos. Check [actionvos_metrics.py](actionvos_metrics.py) for details. 207 | 208 | ## Citation 209 | 210 | If this work or code is helpful in your research, please cite: 211 | 212 | ```latex 213 | @inproceedings{ouyang2024actionvos, 214 | title={ActionVOS: Actions as Prompts for Video Object Segmentation}, 215 | author={Ouyang, Liangyang and Liu, Ruicong and Huang, Yifei and Furuta, Ryosuke and Sato, Yoichi}, 216 | booktitle={European Conference on Computer Vision}, 217 | pages={216--235}, 218 | year={2024} 219 | } 220 | ``` 221 | 222 | If you use the data and annotations from [VISOR](https://proceedings.neurips.cc/paper_files/paper/2022/hash/590a7ebe0da1f262c80d0188f5c4c222-Abstract-Datasets_and_Benchmarks.html), [VSCOS](https://openaccess.thecvf.com/content/ICCV2023/html/Yu_Video_State-Changing_Object_Segmentation_ICCV_2023_paper.html), or [VOST](https://openaccess.thecvf.com/content/CVPR2023/html/Tokmakov_Breaking_the_Object_in_Video_Object_Segmentation_CVPR_2023_paper.html), please cite their original papers. 
223 | 224 | If you are using the training, inference and evaluation code, please cite [ReferFormer](https://openaccess.thecvf.com/content/CVPR2022/html/Wu_Language_As_Queries_for_Referring_Video_Object_Segmentation_CVPR_2022_paper.html) and [GRES](https://openaccess.thecvf.com/content/CVPR2023/html/Liu_GRES_Generalized_Referring_Expression_Segmentation_CVPR_2023_paper.html). 225 | 226 | 227 | ## Contact 228 | 229 | For any questions, including algorithms and datasets, feel free to contact me by email: `oyly(at)iis.u-tokyo.ac.jp` -------------------------------------------------------------------------------- /RF_ActionVOS/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | import torchvision 3 | 4 | from .ytvos import build as build_ytvos 5 | from .davis import build as build_davis 6 | from .a2d import build as build_a2d 7 | from .jhmdb import build as build_jhmdb 8 | from .refexp import build as build_refexp 9 | from .concat_dataset import build as build_joint 10 | from .actionvos import build as build_actionvos 11 | # from .actionvos_allpos import build as build_actionvos_allpos 12 | # from .actionvos_state import build as build_state 13 | 14 | def get_coco_api_from_dataset(dataset): 15 | for _ in range(10): 16 | # if isinstance(dataset, torchvision.datasets.CocoDetection): 17 | # break 18 | if isinstance(dataset, torch.utils.data.Subset): 19 | dataset = dataset.dataset 20 | if isinstance(dataset, torchvision.datasets.CocoDetection): 21 | return dataset.coco 22 | 23 | 24 | def build_dataset(dataset_file: str, image_set: str, args): 25 | if dataset_file == 'ytvos': 26 | return build_ytvos(image_set, args) 27 | if dataset_file == 'davis': 28 | return build_davis(image_set, args) 29 | if dataset_file == 'a2d': 30 | return build_a2d(image_set, args) 31 | if dataset_file == 'jhmdb': 32 | return build_jhmdb(image_set, args) 33 | # for pretraining 34 | if dataset_file == "refcoco" or dataset_file 
== "refcoco+" or dataset_file == "refcocog": 35 | return build_refexp(dataset_file, image_set, args) 36 | # for joint training of refcoco and ytvos 37 | if dataset_file == 'joint': 38 | return build_joint(image_set, args) 39 | if dataset_file == 'actionvos': 40 | return build_actionvos(image_set, args) 41 | if dataset_file == 'actionvos_allpos': 42 | return build_actionvos_allpos(image_set, args) 43 | if dataset_file == 'vost': 44 | return build_state(image_set, args, 'vost') 45 | if dataset_file == 'vscos': 46 | return build_state(image_set, args, 'vscos') 47 | raise ValueError(f'dataset {dataset_file} not supported') 48 | -------------------------------------------------------------------------------- /RF_ActionVOS/actionvos.py: -------------------------------------------------------------------------------- 1 | """ 2 | actionvos data loader 3 | Note that we adjust the transform file (for data augmentation) 4 | # TODO check possible bug when box [0,0,0,0] goes to augmentation. 5 | """ 6 | from pathlib import Path 7 | 8 | import torch 9 | from torch.autograd.grad_mode import F 10 | from torch.utils.data import Dataset 11 | import datasets.transforms_video_actionvos as T 12 | 13 | import os 14 | from PIL import Image 15 | import json 16 | import numpy as np 17 | import random 18 | 19 | #from datasets.categories import ytvos_category_dict as category_dict 20 | 21 | 22 | class ActionVOSDataset(Dataset): 23 | """ 24 | In this version, sampling every triplet 25 | if the object is negative, the mask would be all zero 26 | """ 27 | def __init__(self, actionvos_folder: Path, ann_file: Path, transforms, return_masks: bool, 28 | num_frames: int, max_skip: int, use_weights: bool, image_set: str): 29 | self.actionvos_folder = actionvos_folder 30 | self.ann_file = ann_file 31 | self._transforms = transforms 32 | self.return_masks = return_masks # not used 33 | self.num_frames = num_frames 34 | self.max_skip = max_skip 35 | self.use_weights = use_weights 36 | self.image_set = 
image_set 37 | # create video meta data 38 | self.prepare_metas() 39 | 40 | print('\n video num: ', len(self.videos), ' clip num: ', len(self.metas)) 41 | print('\n') 42 | 43 | def prepare_metas(self): 44 | # read object information 45 | with open(os.path.join(str(self.actionvos_folder),'ImageSets', f'{self.image_set}_objects_category.json'), 'r') as f: 46 | subset_metas_by_video = json.load(f)['videos'] 47 | 48 | # read expression data 49 | with open(os.path.join(str(self.actionvos_folder),'ImageSets', self.ann_file), 'r') as f: 50 | subset_expressions_by_video = json.load(f)['videos'] 51 | self.videos = list(subset_expressions_by_video.keys()) 52 | 53 | self.metas = [] 54 | for vid in self.videos: 55 | vid_meta = subset_metas_by_video[vid] 56 | vid_data = subset_expressions_by_video[vid] 57 | vid_frames = sorted(vid_data['frames']) 58 | vid_len = len(vid_frames) 59 | for exp_id, exp_dict in vid_data['expressions'].items(): 60 | for frame_id in range(0, vid_len, self.num_frames): 61 | meta = {} 62 | meta['video'] = vid 63 | meta['exp'] = exp_dict['exp'] 64 | meta['obj_id'] = int(exp_dict['obj_id']) 65 | meta['frames'] = vid_frames 66 | meta['frame_id'] = frame_id 67 | # get object category and pos 68 | obj_id = exp_dict['obj_id'] 69 | meta['category'] = vid_meta['objects'][obj_id]['category'] 70 | meta['class_id'] = exp_dict['class_id'] 71 | meta['positive'] = exp_dict['positive'] 72 | self.metas.append(meta) 73 | #self.metas = self.metas[:10]# for debug 74 | 75 | @staticmethod 76 | def bounding_box(img): 77 | rows = np.any(img, axis=1) 78 | cols = np.any(img, axis=0) 79 | rmin, rmax = np.where(rows)[0][[0, -1]] 80 | cmin, cmax = np.where(cols)[0][[0, -1]] 81 | return rmin, rmax, cmin, cmax # y1, y2, x1, x2 82 | 83 | def __len__(self): 84 | return len(self.metas) 85 | 86 | def __getitem__(self, idx): 87 | instance_check = False 88 | while not instance_check: 89 | meta = self.metas[idx] # dict 90 | 91 | video, exp, obj_id, category, positive, frames, frame_id = \ 
92 | meta['video'], meta['exp'], meta['obj_id'], meta['category'], meta['positive'], meta['frames'], meta['frame_id'] 93 | # clean up the caption 94 | exp = " ".join(exp.lower().split()) 95 | category_id = meta['class_id'] 96 | vid_len = len(frames) 97 | 98 | num_frames = self.num_frames 99 | # random sparse sample 100 | sample_indx = [frame_id] 101 | if self.num_frames != 1: 102 | # local sample 103 | sample_id_before = random.randint(1, 3) 104 | sample_id_after = random.randint(1, 3) 105 | local_indx = [max(0, frame_id - sample_id_before), min(vid_len - 1, frame_id + sample_id_after)] 106 | sample_indx.extend(local_indx) 107 | 108 | # global sampling 109 | if num_frames > 3: 110 | all_inds = list(range(vid_len)) 111 | global_inds = all_inds[:min(sample_indx)] + all_inds[max(sample_indx):] 112 | global_n = num_frames - len(sample_indx) 113 | if len(global_inds) > global_n: 114 | select_id = random.sample(range(len(global_inds)), global_n) 115 | for s_id in select_id: 116 | sample_indx.append(global_inds[s_id]) 117 | elif vid_len >=global_n: # sample long range global frames 118 | select_id = random.sample(range(vid_len), global_n) 119 | for s_id in select_id: 120 | sample_indx.append(all_inds[s_id]) 121 | else: 122 | select_id = random.sample(range(vid_len), global_n - vid_len) + list(range(vid_len)) 123 | for s_id in select_id: 124 | sample_indx.append(all_inds[s_id]) 125 | sample_indx.sort() 126 | 127 | # read frames and masks and weights 128 | imgs, labels, boxes, masks, valid = [], [], [], [], [] 129 | positives = [] 130 | weights = [] 131 | for j in range(self.num_frames): 132 | frame_indx = sample_indx[j] 133 | frame_name = frames[frame_indx] 134 | img_path = os.path.join(str(self.actionvos_folder), 'JPEGImages_Sparse', self.image_set, video, frame_name + '.jpg') 135 | mask_path = os.path.join(str(self.actionvos_folder), 'Annotations_Sparse', self.image_set, video, frame_name + '.png') 136 | img = Image.open(img_path).convert('RGB') 137 | mask = 
Image.open(mask_path).convert('P') 138 | if self.use_weights: 139 | weight_path = os.path.join(str(self.actionvos_folder), 'Weights_Sparse', self.image_set, video, frame_name + '.png') 140 | weight = Image.open(weight_path).convert('P') 141 | weight = np.array(weight) 142 | 143 | # create the target 144 | label = torch.tensor(category_id) 145 | mask = np.array(mask) 146 | 147 | if self.use_weights: 148 | # for where weight == 3, mask should be zero 149 | # for where mask == 0, final weight should be 1 150 | mask = (mask==obj_id).astype(np.float32) # 0,1 binary 151 | weight = np.where(mask == 0, 1, weight) 152 | mask = np.where(weight == 3, 0, mask) 153 | # we could map weight to other numbers 154 | # by statistic of the training set, we set 5 to 3 and 4, 2 to hand obj 155 | weight = np.where(weight>=3, 5, weight) 156 | if positive: 157 | if (mask > 0).any(): 158 | y1, y2, x1, x2 = self.bounding_box(mask) 159 | box = torch.tensor([x1, y1, x2, y2]).to(torch.float) 160 | else: # some frame didn't contain the instance 161 | box = torch.tensor([0, 0, 0, 0]).to(torch.float) 162 | else: 163 | box = torch.tensor([0, 0, 0, 0]).to(torch.float) 164 | mask = np.zeros_like(mask) 165 | else: 166 | # check positive 167 | if positive: 168 | mask = (mask==obj_id).astype(np.float32) # 0,1 binary 169 | if (mask > 0).any(): 170 | y1, y2, x1, x2 = self.bounding_box(mask) 171 | box = torch.tensor([x1, y1, x2, y2]).to(torch.float) 172 | else: # some frame didn't contain the instance 173 | box = torch.tensor([0, 0, 0, 0]).to(torch.float) 174 | else: 175 | box = torch.tensor([0, 0, 0, 0]).to(torch.float) 176 | mask = np.zeros_like(mask) 177 | weight = np.ones_like(mask) 178 | mask = torch.from_numpy(mask) 179 | weight = torch.from_numpy(weight) 180 | 181 | # append 182 | imgs.append(img) 183 | labels.append(label) 184 | masks.append(mask) 185 | boxes.append(box) 186 | valid.append(1) 187 | positives.append(positive) 188 | weights.append(weight) 189 | 190 | # transform 191 | w, h = img.size 
192 | labels = torch.stack(labels, dim=0) 193 | boxes = torch.stack(boxes, dim=0) 194 | boxes[:, 0::2].clamp_(min=0, max=w) 195 | boxes[:, 1::2].clamp_(min=0, max=h) 196 | masks = torch.stack(masks, dim=0) 197 | weights = torch.stack(weights, dim=0) 198 | target = { 199 | 'frames_idx': torch.tensor(sample_indx), # [T,] 200 | 'labels': labels, # [T,] 201 | 'boxes': boxes, # [T, 4], xyxy 202 | 'masks': masks, # [T, H, W] 203 | 'weights': weights, # [T, H, W] 204 | 'valid': torch.tensor(valid), # [T,] 205 | 'positive': torch.tensor(positives), # [T,] 206 | 'caption': exp, 207 | 'orig_size': torch.as_tensor([int(h), int(w)]), 208 | 'size': torch.as_tensor([int(h), int(w)]) 209 | } 210 | 211 | # "boxes" normalize to [0, 1] and transform from xyxy to cxcywh in self._transform 212 | imgs, target = self._transforms(imgs, target) 213 | imgs = torch.stack(imgs, dim=0) # [T, 3, H, W] 214 | 215 | # FIXME: handle "valid", since some box may be removed due to random crop 216 | # skip this. we sample all negative samples 217 | instance_check = True 218 | ''' 219 | if torch.any(target['valid'] == 1): # at leatst one instance 220 | instance_check = True 221 | else: 222 | idx = random.randint(0, self.__len__() - 1) 223 | ''' 224 | return imgs, target 225 | 226 | 227 | def make_coco_transforms(image_set, max_size=640): 228 | normalize = T.Compose([ 229 | T.ToTensor(), 230 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 231 | ]) 232 | 233 | scales = [288, 320, 352, 392, 416, 448, 480, 512] 234 | 235 | if image_set == 'train': 236 | # edited by oyly, do not check, and transform weights 237 | return T.Compose([ 238 | T.RandomHorizontalFlip(), 239 | T.PhotometricDistort(), 240 | T.RandomSelect( 241 | T.Compose([ 242 | T.RandomResize(scales, max_size=max_size), 243 | #T.Check(), 244 | ]), 245 | T.Compose([ 246 | T.RandomResize([400, 500, 600]), 247 | T.RandomSizeCrop(384, 600), 248 | T.RandomResize(scales, max_size=max_size), 249 | #T.Check(), 250 | ]) 251 | ), 252 | 
normalize, 253 | ]) 254 | 255 | # we do not use the 'val' set since the annotations are inaccessible 256 | if image_set == 'val' or image_set == 'test': 257 | return T.Compose([ 258 | T.RandomResize([360], max_size=640), 259 | normalize, 260 | ]) 261 | 262 | raise ValueError(f'unknown {image_set}') 263 | 264 | 265 | def build(image_set, args): 266 | root = Path(args.actionvos_path) 267 | assert root.exists(), f'provided ActionVOS path {root} does not exist' 268 | ann_file = args.expression_file 269 | print('you are building actionvos {} set with {} , {}'.format(image_set,args.actionvos_path,ann_file)) 270 | dataset = ActionVOSDataset(args.actionvos_path, ann_file, transforms=make_coco_transforms(image_set, max_size=args.max_size), return_masks=args.masks, 271 | num_frames=args.num_frames, max_skip=args.max_skip, use_weights=args.use_weights, image_set=image_set) 272 | return dataset 273 | -------------------------------------------------------------------------------- /RF_ActionVOS/configs.md: -------------------------------------------------------------------------------- 1 | **Instructions of ReferFormer Configs** 2 | 3 | ``` 4 | bash scripts/train_actionvos.sh actionvos_dirs/r101 pretrained_weights/r101_refytvos_joint.pth 1 0 29500 --backbone resnet101 --expression_file train_meta_expressions_promptaction.json --use_weights --actionvos_path ../dataset_visor --epochs 6 --lr_drop 3 5 --save_interval 3 5 | ``` 6 | 7 | where **actionvos_dirs/r101** is the path for saving training logs and checkpoints. 8 | 9 | **pretrained_weights/r101_refytvos_joint.pth** is the base RVOS model path. 10 | 11 | **1** is number of gpus. 12 | **0** is visible gpus. E.g., use **4 0,1,2,3** when using 4 gpus. 13 | 14 | **29500** is the port index. If you are running parallel scripts, change this index to any other number. 15 | 16 | For all other parameters, check [opts.py](opts.py) for explanations. 
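
The positional arguments described above (output directory, base checkpoint, number of GPUs, visible GPU ids, port) can be assembled programmatically when launching several runs. The sketch below is illustrative only: the helper name `build_train_command` is hypothetical and not part of this repository, and it simply derives the GPU count and the comma-separated id list from one Python list so the two cannot disagree.

```python
import shlex


def build_train_command(output_dir, pretrained, gpu_ids, port=29500, extra_args=()):
    """Return the train_actionvos.sh command line as a single shell string.

    Hypothetical helper: gpu_ids is a list such as [0, 1, 2, 3]; extra_args
    holds the remaining flags (e.g. "--backbone", "resnet101") unchanged.
    """
    ids = [str(g) for g in gpu_ids]
    cmd = [
        "bash", "scripts/train_actionvos.sh",
        output_dir,          # where logs and checkpoints are saved
        pretrained,          # base RVOS model path
        str(len(ids)),       # number of GPUs
        ",".join(ids),       # visible GPUs, e.g. "0,1,2,3"
        str(port),           # port index; change it for parallel runs
        *extra_args,
    ]
    return shlex.join(cmd)


print(build_train_command(
    "actionvos_dirs/r101",
    "pretrained_weights/r101_refytvos_joint.pth",
    [0, 1, 2, 3],
    extra_args=["--backbone", "resnet101", "--use_weights"],
))
```

Giving each concurrently running job a different `port` value follows the note above about parallel scripts.
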
17 | -------------------------------------------------------------------------------- /RF_ActionVOS/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from util import box_ops 6 | from util.misc import (NestedTensor, nested_tensor_from_tensor_list, 7 | accuracy, get_world_size, interpolate, 8 | is_dist_avail_and_initialized, inverse_sigmoid) 9 | 10 | from .segmentation import (dice_loss, sigmoid_focal_loss, sigmoid_focal_loss_weighted) 11 | 12 | from einops import rearrange 13 | 14 | class SetCriterion(nn.Module): 15 | """ This class computes the loss for ReferFormer. 16 | The process happens in two steps: 17 | 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 18 | 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) 19 | """ 20 | def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, focal_alpha=0.25): 21 | """ Create the criterion. 22 | Parameters: 23 | num_classes: number of object categories, omitting the special no-object category 24 | matcher: module able to compute a matching between targets and proposals 25 | weight_dict: dict containing as key the names of the losses and as values their relative weight. 26 | eos_coef: relative classification weight applied to the no-object category 27 | losses: list of all the losses to be applied. See get_loss for list of available losses. 
28 | """ 29 | super().__init__() 30 | self.num_classes = num_classes 31 | self.matcher = matcher 32 | self.weight_dict = weight_dict 33 | self.eos_coef = eos_coef 34 | self.losses = losses 35 | empty_weight = torch.ones(self.num_classes + 1) 36 | empty_weight[-1] = self.eos_coef 37 | self.register_buffer('empty_weight', empty_weight) 38 | self.focal_alpha = focal_alpha 39 | self.mask_out_stride = 4 40 | 41 | def loss_labels(self, outputs, targets, indices, num_boxes, log=True): 42 | """Classification loss (NLL) 43 | targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] 44 | """ 45 | assert 'pred_logits' in outputs 46 | src_logits = outputs['pred_logits'] 47 | _, nf, nq = src_logits.shape[:3] 48 | src_logits = rearrange(src_logits, 'b t q k -> b (t q) k') 49 | 50 | # judge the valid frames 51 | valid_indices = [] 52 | valids = [target['valid'] for target in targets] 53 | for valid, (indice_i, indice_j) in zip(valids, indices): 54 | valid_ind = valid.nonzero().flatten() 55 | valid_i = valid_ind * nq + indice_i 56 | valid_j = valid_ind + indice_j * nf 57 | valid_indices.append((valid_i, valid_j)) 58 | 59 | idx = self._get_src_permutation_idx(valid_indices) # NOTE: use valid indices 60 | target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, valid_indices)]) 61 | target_classes = torch.full(src_logits.shape[:2], self.num_classes, 62 | dtype=torch.int64, device=src_logits.device) 63 | if self.num_classes == 1: # binary referred, positive 64 | target_classes[idx] = 0 65 | else: 66 | target_classes[idx] = target_classes_o 67 | 68 | target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1], 69 | dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) 70 | target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) 71 | 72 | target_classes_onehot = target_classes_onehot[:,:,:-1] 73 | #print('target_loss_label',target_classes_onehot) 74 | loss_ce = 
sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] 75 | losses = {'loss_ce': loss_ce} 76 | 77 | if log: 78 | # TODO this should probably be a separate loss, not hacked in this one here 79 | pass 80 | return losses 81 | 82 | def loss_positive_labels(self, outputs, targets, indices, num_boxes, log=True): 83 | """ 84 | edited label loss to classify pos/neg 85 | """ 86 | assert 'pred_positives' in outputs 87 | src_logits = outputs['pred_positives'] 88 | target_classes = torch.full(src_logits.shape[:2], 0, 89 | dtype=torch.float, device=src_logits.device) 90 | for id,t in enumerate(targets): 91 | if t['positive'].any(): 92 | target_classes[id] = 1 93 | loss_ce = sigmoid_focal_loss(src_logits, target_classes, num_boxes=src_logits.shape[0], alpha=self.focal_alpha, gamma=2) 94 | losses = {'loss_positive_labels': loss_ce} 95 | 96 | if log: 97 | # TODO this should probably be a separate loss, not hacked in this one here 98 | pass 99 | return losses 100 | 101 | def loss_boxes(self, outputs, targets, indices, num_boxes): 102 | """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss 103 | targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] 104 | The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. 
105 | """ 106 | assert 'pred_boxes' in outputs 107 | src_boxes = outputs['pred_boxes'] 108 | bs, nf, nq = src_boxes.shape[:3] 109 | src_boxes = src_boxes.transpose(1, 2) 110 | 111 | idx = self._get_src_permutation_idx(indices) 112 | src_boxes = src_boxes[idx] 113 | src_boxes = src_boxes.flatten(0, 1) # [b*t, 4] 114 | 115 | target_boxes = torch.cat([t['boxes'] for t in targets], dim=0) # [b*t, 4] 116 | 117 | loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') 118 | 119 | losses = {} 120 | losses['loss_bbox'] = loss_bbox.sum() / num_boxes 121 | 122 | loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( 123 | box_ops.box_cxcywh_to_xyxy(src_boxes), 124 | box_ops.box_cxcywh_to_xyxy(target_boxes))) 125 | losses['loss_giou'] = loss_giou.sum() / num_boxes 126 | return losses 127 | 128 | 129 | def loss_weighted_masks(self, outputs, targets, indices, num_boxes): 130 | """Compute the losses related to the masks: the focal loss and the dice loss. 131 | targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] 132 | targets dicts must contain the key "weights" containing a tensor of dim [nb_target_boxes, h, w] 133 | weights are for focal loss, should be the same shape as target['masks'] 134 | """ 135 | assert "pred_masks" in outputs 136 | 137 | src_idx = self._get_src_permutation_idx(indices) 138 | # tgt_idx = self._get_tgt_permutation_idx(indices) 139 | 140 | src_masks = outputs["pred_masks"] 141 | src_masks = src_masks.transpose(1, 2) 142 | 143 | # TODO use valid to mask invalid areas due to padding in loss 144 | target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets], 145 | size_divisibility=32, split=False).decompose() 146 | target_masks = target_masks.to(src_masks) 147 | 148 | # what is this step ? 
149 | #print('before:') 150 | #print([torch.unique(t['weights']) for t in targets]) 151 | weights, valid = nested_tensor_from_tensor_list([t["weights"] for t in targets], 152 | size_divisibility=32, split=False).decompose() 153 | weights = weights.to(src_masks) 154 | #print('after:') 155 | #print(torch.unique(weights)) 156 | 157 | # downsample ground truth masks with ratio mask_out_stride 158 | start = int(self.mask_out_stride // 2) 159 | im_h, im_w = target_masks.shape[-2:] 160 | 161 | target_masks = target_masks[:, :, start::self.mask_out_stride, start::self.mask_out_stride] 162 | weights = weights[:, :, start::self.mask_out_stride, start::self.mask_out_stride] 163 | assert target_masks.size(2) * self.mask_out_stride == im_h 164 | assert target_masks.size(3) * self.mask_out_stride == im_w 165 | 166 | src_masks = src_masks[src_idx] 167 | # upsample predictions to the target size 168 | # src_masks = interpolate(src_masks, size=target_masks.shape[-2:], mode="bilinear", align_corners=False) 169 | src_masks = src_masks.flatten(1) # [b, thw] 170 | 171 | target_masks = target_masks.flatten(1) # [b, thw] 172 | weights = weights.flatten(1) 173 | 174 | losses = { 175 | "loss_mask": sigmoid_focal_loss_weighted(src_masks, target_masks, num_boxes, weights), 176 | "loss_dice": dice_loss(src_masks, target_masks, num_boxes), 177 | } 178 | return losses 179 | 180 | def loss_masks(self, outputs, targets, indices, num_boxes): 181 | """Compute the losses related to the masks: the focal loss and the dice loss. 
182 | targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] 183 | """ 184 | assert "pred_masks" in outputs 185 | src_idx = self._get_src_permutation_idx(indices) 186 | # tgt_idx = self._get_tgt_permutation_idx(indices) 187 | 188 | src_masks = outputs["pred_masks"] 189 | src_masks = src_masks.transpose(1, 2) 190 | 191 | # TODO use valid to mask invalid areas due to padding in loss 192 | target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets], 193 | size_divisibility=32, split=False).decompose() 194 | target_masks = target_masks.to(src_masks) 195 | 196 | # downsample ground truth masks with ratio mask_out_stride 197 | start = int(self.mask_out_stride // 2) 198 | im_h, im_w = target_masks.shape[-2:] 199 | 200 | target_masks = target_masks[:, :, start::self.mask_out_stride, start::self.mask_out_stride] 201 | assert target_masks.size(2) * self.mask_out_stride == im_h 202 | assert target_masks.size(3) * self.mask_out_stride == im_w 203 | 204 | src_masks = src_masks[src_idx] 205 | # upsample predictions to the target size 206 | # src_masks = interpolate(src_masks, size=target_masks.shape[-2:], mode="bilinear", align_corners=False) 207 | src_masks = src_masks.flatten(1) # [b, thw] 208 | 209 | target_masks = target_masks.flatten(1) # [b, thw] 210 | 211 | losses = { 212 | "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes), 213 | "loss_dice": dice_loss(src_masks, target_masks, num_boxes), 214 | } 215 | return losses 216 | 217 | def _get_src_permutation_idx(self, indices): 218 | # permute predictions following indices 219 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) 220 | src_idx = torch.cat([src for (src, _) in indices]) 221 | return batch_idx, src_idx 222 | 223 | def _get_tgt_permutation_idx(self, indices): 224 | # permute targets following indices 225 | batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) 226 | 
tgt_idx = torch.cat([tgt for (_, tgt) in indices]) 227 | return batch_idx, tgt_idx 228 | 229 | def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): 230 | loss_map = { 231 | 'labels': self.loss_labels, 232 | 'boxes': self.loss_boxes, 233 | 'masks': self.loss_masks, 234 | 'positive_labels': self.loss_positive_labels, 235 | 'weighted_masks': self.loss_weighted_masks, 236 | } 237 | assert loss in loss_map, f'do you really want to compute {loss} loss?' 238 | return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) 239 | 240 | def forward(self, outputs, targets): 241 | """ This performs the loss computation. 242 | Parameters: 243 | outputs: dict of tensors, see the output specification of the model for the format 244 | targets: list of dicts, such that len(targets) == batch_size. 245 | The expected keys in each dict depend on the losses applied, see each loss' doc 246 | """ 247 | outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'} 248 | # Retrieve the matching between the outputs of the last layer and the targets 249 | indices = self.matcher(outputs_without_aux, targets) 250 | 251 | # Compute the average number of target boxes across all nodes, for normalization purposes 252 | target_valid = torch.stack([t["valid"] for t in targets], dim=0).reshape(-1) # [B, T] -> [B*T] 253 | num_boxes = target_valid.sum().item() 254 | num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) 255 | if is_dist_avail_and_initialized(): 256 | torch.distributed.all_reduce(num_boxes) 257 | num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() 258 | 259 | # Compute all the requested losses 260 | losses = {} 261 | for loss in self.losses: 262 | losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) 263 | 264 | # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
265 | if 'aux_outputs' in outputs: 266 | for i, aux_outputs in enumerate(outputs['aux_outputs']): 267 | indices = self.matcher(aux_outputs, targets) 268 | for loss in self.losses: 269 | kwargs = {} 270 | if loss == 'labels': 271 | # Logging is enabled only for the last layer 272 | kwargs = {'log': False} 273 | l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs) 274 | l_dict = {k + f'_{i}': v for k, v in l_dict.items()} 275 | losses.update(l_dict) 276 | 277 | return losses 278 | 279 | 280 | -------------------------------------------------------------------------------- /RF_ActionVOS/inference_actionvos.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Inference code for ReferFormer, on Ref-Youtube-VOS 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | ''' 5 | import argparse 6 | import json 7 | import random 8 | import time 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | import torch 13 | 14 | import util.misc as utils 15 | from models import build_model 16 | import torchvision.transforms as T 17 | import matplotlib.pyplot as plt 18 | import os 19 | import cv2 20 | from PIL import Image, ImageDraw 21 | import math 22 | import torch.nn.functional as F 23 | import json 24 | 25 | import opts 26 | from tqdm import tqdm 27 | 28 | import multiprocessing as mp 29 | import threading 30 | 31 | from tools.colormap import colormap 32 | 33 | 34 | # colormap 35 | color_list = colormap() 36 | color_list = color_list.astype('uint8').tolist() 37 | 38 | # build transform 39 | transform = T.Compose([ 40 | T.Resize(360), 41 | T.ToTensor(), 42 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 43 | ]) 44 | 45 | 46 | def main(args): 47 | args.masks = True 48 | args.batch_size = 1 # force batch size 1 (the original `==` was a no-op comparison) 49 | print("Inference only supports batch size = 1") 50 | 51 | # fix the seed for reproducibility 52 | seed = args.seed + utils.get_rank() 53 | torch.manual_seed(seed) 54 | 
np.random.seed(seed) 55 | random.seed(seed) 56 | 57 | split = args.split 58 | # save path 59 | output_dir = args.output_dir 60 | save_path_prefix = os.path.join(output_dir, split) 61 | if not os.path.exists(save_path_prefix): 62 | os.makedirs(save_path_prefix) 63 | 64 | save_visualize_path_prefix = os.path.join(output_dir, split + '_images') 65 | if args.visualize: 66 | if not os.path.exists(save_visualize_path_prefix): 67 | os.makedirs(save_visualize_path_prefix) 68 | 69 | # load data 70 | root = Path(args.actionvos_path) 71 | img_folder = os.path.join(root, "JPEGImages_Sparse", split) 72 | meta_file = os.path.join(root, "ImageSets", args.expression_file) 73 | with open(meta_file, "r") as f: 74 | data = json.load(f)["videos"] 75 | video_list = list(data.keys()) 76 | #assert len(video_list) == 202, 'error: incorrect number of validation videos' 77 | 78 | # create subprocess 79 | thread_num = args.ngpu 80 | global result_dict 81 | result_dict = mp.Manager().dict() 82 | 83 | processes = [] 84 | lock = threading.Lock() 85 | 86 | video_num = len(video_list) 87 | per_thread_video_num = video_num // thread_num 88 | 89 | start_time = time.time() 90 | print('Start inference') 91 | for i in range(thread_num): 92 | if i == thread_num - 1: 93 | sub_video_list = video_list[i * per_thread_video_num:] 94 | else: 95 | sub_video_list = video_list[i * per_thread_video_num: (i + 1) * per_thread_video_num] 96 | p = mp.Process(target=sub_processor, args=(lock, i, args, data, 97 | save_path_prefix, save_visualize_path_prefix, 98 | img_folder, sub_video_list)) 99 | p.start() 100 | processes.append(p) 101 | 102 | for p in processes: 103 | p.join() 104 | 105 | end_time = time.time() 106 | total_time = end_time - start_time 107 | 108 | result_dict = dict(result_dict) 109 | num_all_frames_gpus = 0 110 | for pid, num_all_frames in result_dict.items(): 111 | num_all_frames_gpus += num_all_frames 112 | 113 | print("Total inference time: %.4f s" %(total_time)) 114 | 115 | def 
sub_processor(lock, pid, args, data, save_path_prefix, save_visualize_path_prefix, img_folder, video_list): 116 | text = 'processor %d' % pid 117 | with lock: 118 | progress = tqdm( 119 | total=len(video_list), 120 | position=pid, 121 | desc=text, 122 | ncols=0 123 | ) 124 | torch.cuda.set_device(pid) 125 | 126 | # model 127 | model, criterion, _ = build_model(args) 128 | device = args.device 129 | model.to(device) 130 | 131 | model_without_ddp = model 132 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 133 | 134 | if pid == 0: 135 | print('number of params:', n_parameters) 136 | 137 | if args.resume: 138 | checkpoint = torch.load(args.resume, map_location='cpu') 139 | missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 140 | unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))] 141 | if len(missing_keys) > 0: 142 | print('Missing Keys: {}'.format(missing_keys)) 143 | if len(unexpected_keys) > 0: 144 | print('Unexpected Keys: {}'.format(unexpected_keys)) 145 | else: 146 | raise ValueError('Please specify the checkpoint for inference.') 147 | 148 | # get palette 149 | palette_img = '../annotations/00000.png' 150 | palette = Image.open(palette_img).getpalette() 151 | 152 | # start inference 153 | num_all_frames = 0 154 | model.eval() 155 | 156 | # 1. 
For each video 157 | for video in video_list: 158 | metas = [] # list[dict], length is number of expressions 159 | 160 | expressions = data[video]["expressions"] 161 | expression_list = list(expressions.keys()) 162 | num_expressions = len(expression_list) 163 | video_len = len(data[video]["frames"]) 164 | 165 | # read all the anno meta 166 | for i in range(num_expressions): 167 | meta = {} 168 | meta["video"] = video 169 | meta["exp"] = expressions[expression_list[i]]["exp"] 170 | meta["exp_id"] = expression_list[i] 171 | meta["frames"] = data[video]["frames"] 172 | metas.append(meta) 173 | meta = metas 174 | 175 | # 2. For each expression, add by oyly, join results 176 | anno_logits = [] 177 | anno_masks = [] 178 | for i in range(num_expressions): 179 | video_name = meta[i]["video"] 180 | exp = meta[i]["exp"] 181 | exp_id = meta[i]["exp_id"] 182 | frames = meta[i]["frames"] 183 | 184 | video_len = len(frames) 185 | # NOTE: the im2col_step for MSDeformAttention is set as 64 186 | # so the max length for a clip is 64 187 | # following inference_davis.py, we set max length to 36 188 | # for each clip 189 | all_pred_logits = [] 190 | all_pred_masks = [] 191 | for clip_id in range(0, video_len, 36): 192 | frames_ids = [x for x in range(video_len)] 193 | clip_frames_ids = frames_ids[clip_id : clip_id + 36] 194 | clip_len = len(clip_frames_ids) 195 | imgs = [] 196 | for t in clip_frames_ids: 197 | frame = frames[t] 198 | img_path = os.path.join(img_folder, video_name, frame + ".jpg") 199 | img = Image.open(img_path).convert('RGB') 200 | origin_w, origin_h = img.size 201 | imgs.append(transform(img)) # list[Img] 202 | 203 | imgs = torch.stack(imgs, dim=0).to(args.device) # [clip_len, 3, h, w] 204 | img_h, img_w = imgs.shape[-2:] 205 | size = torch.as_tensor([int(img_h), int(img_w)]).to(args.device) 206 | target = {"size": size} 207 | 208 | with torch.no_grad(): 209 | outputs = model([imgs], [exp], [target]) 210 | 211 | pred_logits = outputs["pred_logits"][0] 212 | 
pred_boxes = outputs["pred_boxes"][0] 213 | pred_masks = outputs["pred_masks"][0] 214 | pred_ref_points = outputs["reference_points"][0] 215 | 216 | # according to pred_logits, select the query index 217 | pred_scores = pred_logits.sigmoid() # [t, q, k] 218 | pred_scores = pred_scores.mean(0) # [q, k] 219 | max_scores, _ = pred_scores.max(-1) # [q,] 220 | max_scores, max_ind = max_scores.max(-1) # [1,] 221 | max_inds = max_ind.repeat(clip_len) 222 | pred_masks = pred_masks[range(clip_len), max_inds, ...] # [t, h, w] 223 | pred_masks = pred_masks.unsqueeze(0) 224 | 225 | pred_masks = F.interpolate(pred_masks, size=(origin_h, origin_w), mode='bilinear', align_corners=False) 226 | pred_masks = pred_masks.sigmoid().squeeze(0) 227 | 228 | 229 | # store the clip results 230 | pred_logits = pred_logits[range(clip_len), max_inds] 231 | 232 | # use positive label classification to filter negative objects 233 | if args.use_positive_cls: 234 | if outputs['pred_positives'][0].sigmoid() < args.pos_cls_thres: 235 | pred_masks.zero_() 236 | 237 | #all_pred_boxes = pred_boxes[range(clip_len), max_inds] 238 | #all_pred_ref_points = pred_ref_points[range(clip_len), max_inds] 239 | all_pred_logits.append(pred_logits) 240 | all_pred_masks.append(pred_masks) 241 | 242 | all_pred_logits = torch.cat(all_pred_logits, dim=0) # (video_len, K) 243 | all_pred_masks = torch.cat(all_pred_masks, dim=0) # (video_len, h, w) 244 | anno_logits.append(all_pred_logits) 245 | anno_masks.append(all_pred_masks) 246 | # join all results into one 247 | # handle a complete image (all objects of an annotator) 248 | anno_logits = torch.stack(anno_logits) # [num_obj, video_len, k] 249 | anno_masks = torch.stack(anno_masks) # [num_obj, video_len, h, w] 250 | t, h, w = anno_masks.shape[-3:] 251 | anno_masks[anno_masks < 0.5] = 0.0 252 | background = 0.1 * torch.ones(1, t, h, w).to(args.device) 253 | anno_masks = torch.cat([background, anno_masks], dim=0) # [num_obj+1, video_len, h, w] 254 | out_masks = 
torch.argmax(anno_masks, dim=0) # int, each value indicates which object, [video_len, h, w] 255 | 256 | out_masks = out_masks.detach().cpu().numpy().astype(np.uint8) # [video_len, h, w] 257 | 258 | # save results 259 | anno_save_path = os.path.join(save_path_prefix, video) 260 | if not os.path.exists(anno_save_path): 261 | os.makedirs(anno_save_path) 262 | for f in range(out_masks.shape[0]): 263 | img_E = Image.fromarray(out_masks[f]) 264 | img_E.putpalette(palette) 265 | img_E.save(os.path.join(anno_save_path, '{}.png'.format(data[video]["frames"][f]))) 266 | 267 | with lock: 268 | progress.update(1) 269 | result_dict[str(pid)] = num_all_frames 270 | with lock: 271 | progress.close() 272 | 273 | 274 | # box conversion helpers for visualization 275 | def box_cxcywh_to_xyxy(x): 276 | x_c, y_c, w, h = x.unbind(1) 277 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 278 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 279 | return torch.stack(b, dim=1) 280 | 281 | def rescale_bboxes(out_bbox, size): 282 | img_w, img_h = size 283 | b = box_cxcywh_to_xyxy(out_bbox) 284 | b = b.cpu() * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32) 285 | return b 286 | 287 | 288 | # Visualization functions 289 | def draw_reference_points(draw, reference_points, img_size, color): 290 | W, H = img_size 291 | for i, ref_point in enumerate(reference_points): 292 | init_x, init_y = ref_point 293 | x, y = W * init_x, H * init_y 294 | cur_color = color 295 | draw.line((x-10, y, x+10, y), tuple(cur_color), width=4) 296 | draw.line((x, y-10, x, y+10), tuple(cur_color), width=4) 297 | 298 | def draw_sample_points(draw, sample_points, img_size, color_list): 299 | alpha = 255 300 | for i, samples in enumerate(sample_points): 301 | for sample in samples: 302 | x, y = sample 303 | cur_color = color_list[i % len(color_list)][::-1] 304 | cur_color += [alpha] 305 | draw.ellipse((x-2, y-2, x+2, y+2), 306 | fill=tuple(cur_color), outline=tuple(cur_color), width=1) 307 | 308 | def vis_add_mask(img, mask, color): 309 | origin_img = 
np.asarray(img.convert('RGB')).copy() 310 | color = np.array(color) 311 | 312 | mask = mask.reshape(mask.shape[0], mask.shape[1]).astype('uint8') # np 313 | mask = mask > 0.5 314 | 315 | origin_img[mask] = origin_img[mask] * 0.5 + color * 0.5 316 | origin_img = Image.fromarray(origin_img) 317 | return origin_img 318 | 319 | 320 | 321 | if __name__ == '__main__': 322 | parser = argparse.ArgumentParser('ReferFormer inference script', parents=[opts.get_args_parser()]) 323 | args = parser.parse_args() 324 | args.ngpu = 1 325 | main(args) 326 | -------------------------------------------------------------------------------- /RF_ActionVOS/main_actionvos.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training script of ReferFormer-ActionVOS 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | import argparse 6 | import datetime 7 | import json 8 | import random 9 | import time 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | import torch 14 | from torch.utils.data import DataLoader, DistributedSampler 15 | 16 | import util.misc as utils 17 | import datasets.samplers as samplers 18 | from datasets import build_dataset, get_coco_api_from_dataset 19 | from engine import train_one_epoch, evaluate, evaluate_a2d 20 | from models import build_model 21 | 22 | from tools.load_pretrained_weights import pre_trained_model_to_finetune 23 | 24 | import opts 25 | 26 | 27 | 28 | def main(args): 29 | args.masks = True 30 | 31 | utils.init_distributed_mode(args) 32 | print("git:\n {}\n".format(utils.get_sha())) 33 | #print(args) 34 | 35 | print(f'\n Run on {args.dataset_file} dataset.') 36 | print('\n') 37 | 38 | device = torch.device(args.device) 39 | 40 | # fix the seed for reproducibility 41 | seed = args.seed + utils.get_rank() 42 | print('seed=',seed) 43 | torch.manual_seed(seed) 44 | np.random.seed(seed) 45 | random.seed(seed) 46 | 47 | model, criterion, postprocessor = build_model(args) 48 | 
model.to(device) 49 | 50 | model_without_ddp = model 51 | if args.distributed: 52 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 53 | model_without_ddp = model.module 54 | 55 | # for n, p in model_without_ddp.named_parameters(): 56 | # print(n) 57 | 58 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 59 | print('number of params:', n_parameters) 60 | 61 | def match_name_keywords(n, name_keywords): 62 | out = False 63 | for b in name_keywords: 64 | if b in n: 65 | out = True 66 | break 67 | return out 68 | 69 | 70 | param_dicts = [ 71 | { 72 | "params": 73 | [p for n, p in model_without_ddp.named_parameters() 74 | if not match_name_keywords(n, args.lr_backbone_names) and not match_name_keywords(n, args.lr_text_encoder_names) 75 | and not match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad], 76 | "lr": args.lr, 77 | }, 78 | { 79 | "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_backbone_names) and p.requires_grad], 80 | "lr": args.lr_backbone, 81 | }, 82 | { 83 | "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_text_encoder_names) and p.requires_grad], 84 | "lr": args.lr_text_encoder, 85 | }, 86 | { 87 | "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad], 88 | "lr": args.lr * args.lr_linear_proj_mult, 89 | } 90 | ] 91 | optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, 92 | weight_decay=args.weight_decay) 93 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_drop) 94 | 95 | # build actionvos train-val 96 | dataset_train = build_dataset(args.dataset_file, image_set='train', args=args) 97 | 98 | if args.distributed: 99 | if args.cache_mode: 100 | sampler_train = samplers.NodeDistributedSampler(dataset_train) 101 | else: 102 | sampler_train = samplers.DistributedSampler(dataset_train) 
103 | else: 104 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 105 | 106 | batch_sampler_train = torch.utils.data.BatchSampler( 107 | sampler_train, args.batch_size, drop_last=True) 108 | 109 | data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, 110 | collate_fn=utils.collate_fn, num_workers=args.num_workers) 111 | 112 | # for Ref-Youtube-VOS and A2D-Sentences 113 | # finetune using the pretrained weights on Ref-COCO 114 | if args.dataset_file != "davis" and args.dataset_file != "jhmdb" and args.pretrained_weights is not None: 115 | print("============================================>") 116 | print("Load pretrained weights from {} ...".format(args.pretrained_weights)) 117 | checkpoint = torch.load(args.pretrained_weights, map_location="cpu") 118 | checkpoint_dict = pre_trained_model_to_finetune(checkpoint, args) 119 | model_without_ddp.load_state_dict(checkpoint_dict, strict=False) 120 | print("============================================>") 121 | 122 | 123 | output_dir = Path(args.output_dir) 124 | if args.resume: 125 | if args.resume.startswith('https'): 126 | checkpoint = torch.hub.load_state_dict_from_url( 127 | args.resume, map_location='cpu', check_hash=True) 128 | else: 129 | checkpoint = torch.load(args.resume, map_location='cpu') 130 | missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 131 | unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))] 132 | if len(missing_keys) > 0: 133 | print('Missing Keys: {}'.format(missing_keys)) 134 | if len(unexpected_keys) > 0: 135 | print('Unexpected Keys: {}'.format(unexpected_keys)) 136 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 137 | import copy 138 | p_groups = copy.deepcopy(optimizer.param_groups) 139 | optimizer.load_state_dict(checkpoint['optimizer']) 140 | for pg, pg_old in 
zip(optimizer.param_groups, p_groups): 141 | pg['lr'] = pg_old['lr'] 142 | pg['initial_lr'] = pg_old['initial_lr'] 143 | print(optimizer.param_groups) 144 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 145 | # todo: this is a hack for doing experiment that resume from checkpoint and also modify lr scheduler (e.g., decrease lr in advance). 146 | args.override_resumed_lr_drop = True 147 | if args.override_resumed_lr_drop: 148 | print('Warning: (hack) args.override_resumed_lr_drop is set to True, so args.lr_drop would override lr_drop in resumed lr_scheduler.') 149 | lr_scheduler.step_size = args.lr_drop 150 | lr_scheduler.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 151 | lr_scheduler.step(lr_scheduler.last_epoch) 152 | args.start_epoch = checkpoint['epoch'] + 1 153 | 154 | if args.eval: 155 | assert args.dataset_file == 'a2d' or args.dataset_file == 'jhmdb', \ 156 | 'Only A2D-Sentences and JHMDB-Sentences datasets support evaluation' 157 | test_stats = evaluate_a2d(model, data_loader_val, postprocessor, device, args) 158 | return 159 | 160 | 161 | print("Start training") 162 | start_time = time.time() 163 | for epoch in range(args.start_epoch, args.epochs): 164 | if args.distributed: 165 | sampler_train.set_epoch(epoch) 166 | train_stats = train_one_epoch( 167 | model, criterion, data_loader_train, optimizer, device, epoch, 168 | args.clip_max_norm) 169 | lr_scheduler.step() 170 | if args.output_dir: 171 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 172 | # extra checkpoint before LR drop and every epochs 173 | #if (epoch + 1) in args.lr_drop: 174 | if (epoch + 1) % args.save_interval == 0: 175 | checkpoint_paths.append(output_dir / f'checkpoint{(epoch+1):04}.pth') 176 | for checkpoint_path in checkpoint_paths: 177 | utils.save_on_master({ 178 | 'model': model_without_ddp.state_dict(), 179 | 'optimizer': optimizer.state_dict(), 180 | 'lr_scheduler': lr_scheduler.state_dict(), 181 | 'epoch': epoch, 182 | 'args': 
args, 183 | }, checkpoint_path) 184 | 185 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 186 | 'epoch': epoch, 187 | 'n_parameters': n_parameters} 188 | 189 | if args.dataset_file == 'a2d': 190 | test_stats = evaluate_a2d(model, data_loader_val, postprocessor, device, args) 191 | log_stats.update({**{f'{k}': v for k, v in test_stats.items()}}) 192 | 193 | if args.output_dir and utils.is_main_process(): 194 | with (output_dir / "log.txt").open("a") as f: 195 | f.write(json.dumps(log_stats) + "\n") 196 | 197 | 198 | total_time = time.time() - start_time 199 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 200 | print('Training time {}'.format(total_time_str)) 201 | 202 | 203 | if __name__ == '__main__': 204 | parser = argparse.ArgumentParser('ReferFormer training and evaluation script', parents=[opts.get_args_parser()]) 205 | args = parser.parse_args() 206 | args.dataset_file = 'actionvos' 207 | if args.all_pos: 208 | args.dataset_file = 'actionvos_allpos' 209 | if args.output_dir: 210 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 211 | main(args) 212 | 213 | 214 | 215 | -------------------------------------------------------------------------------- /RF_ActionVOS/opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args_parser(): 4 | parser = argparse.ArgumentParser('ReferFormer training and inference scripts.', add_help=False) 5 | parser.add_argument('--lr', default=1e-4, type=float) 6 | parser.add_argument('--lr_backbone', default=5e-5, type=float) 7 | parser.add_argument('--lr_backbone_names', default=['backbone.0'], type=str, nargs='+') 8 | parser.add_argument('--lr_text_encoder', default=1e-5, type=float) 9 | parser.add_argument('--lr_text_encoder_names', default=['text_encoder'], type=str, nargs='+') 10 | parser.add_argument('--lr_linear_proj_names', default=['reference_points', 'sampling_offsets'], type=str, nargs='+') 11 | 
parser.add_argument('--lr_linear_proj_mult', default=1.0, type=float) 12 | parser.add_argument('--batch_size', default=1, type=int) 13 | parser.add_argument('--weight_decay', default=5e-4, type=float) 14 | parser.add_argument('--epochs', default=10, type=int) 15 | parser.add_argument('--save_interval', default=5, type=int) 16 | parser.add_argument('--lr_drop', default=[3], type=int, nargs='+') # list of milestone epochs, as MultiStepLR expects 17 | parser.add_argument('--clip_max_norm', default=0.1, type=float, 18 | help='gradient clipping max norm') 19 | 20 | # Model parameters 21 | # load the pretrained weights 22 | parser.add_argument('--pretrained_weights', type=str, default=None, 23 | help="Path to the pretrained model.") 24 | 25 | # Variants of Deformable DETR 26 | parser.add_argument('--with_box_refine', default=False, action='store_true') 27 | parser.add_argument('--two_stage', default=False, action='store_true') # NOTE: must be false 28 | 29 | # * Backbone 30 | # ["resnet50", "resnet101", "swin_t_p4w7", "swin_s_p4w7", "swin_b_p4w7", "swin_l_p4w7"] 31 | # ["video_swin_t_p4w7", "video_swin_s_p4w7", "video_swin_b_p4w7"] 32 | parser.add_argument('--backbone', default='resnet50', type=str, 33 | help="Name of the convolutional backbone to use") 34 | parser.add_argument('--backbone_pretrained', default=None, type=str, 35 | help="Path to the pretrained backbone weights, used when training a swin backbone from scratch") 36 | parser.add_argument('--use_checkpoint', action='store_true', help='whether to use checkpointing for the swin/video swin backbone') 37 | parser.add_argument('--dilation', action='store_true', # DC5 38 | help="If true, we replace stride with dilation in the last convolutional block (DC5)") 39 | parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), 40 | help="Type of positional embedding to use on top of the image features") 41 | parser.add_argument('--num_feature_levels', default=4, type=int, help='number of feature levels') 42 | 43 | # * Transformer 44 | 
parser.add_argument('--enc_layers', default=4, type=int, 45 | help="Number of encoding layers in the transformer") 46 | parser.add_argument('--dec_layers', default=4, type=int, 47 | help="Number of decoding layers in the transformer") 48 | parser.add_argument('--dim_feedforward', default=2048, type=int, 49 | help="Intermediate size of the feedforward layers in the transformer blocks") 50 | parser.add_argument('--hidden_dim', default=256, type=int, 51 | help="Size of the embeddings (dimension of the transformer)") 52 | parser.add_argument('--dropout', default=0.1, type=float, 53 | help="Dropout applied in the transformer") 54 | parser.add_argument('--nheads', default=8, type=int, 55 | help="Number of attention heads inside the transformer's attentions") 56 | parser.add_argument('--num_frames', default=5, type=int, 57 | help="Number of clip frames for training") 58 | parser.add_argument('--num_queries', default=5, type=int, 59 | help="Number of query slots, all frames share the same queries") 60 | parser.add_argument('--dec_n_points', default=4, type=int) 61 | parser.add_argument('--enc_n_points', default=4, type=int) 62 | parser.add_argument('--pre_norm', action='store_true') 63 | # for text 64 | parser.add_argument('--freeze_text_encoder', action='store_true') # default: False 65 | 66 | # * Segmentation 67 | parser.add_argument('--use_weights', action='store_true', help="use action-guided weights for focal loss") 68 | parser.add_argument('--all_pos', action='store_true', 69 | help="in training, keep all masks as positive") 70 | parser.add_argument('--use_positive_cls', action='store_true', 71 | help="use an extra positive classification head") 72 | parser.add_argument('--pos_cls_thres', default=0.75, type=float, 73 | help="in inference, use positive classification results and the classification threshold") 74 | parser.add_argument('--masks', action='store_true', 75 | help="Train segmentation head if the flag is provided") 76 | parser.add_argument('--mask_dim', 
default=256, type=int, 77 | help="Size of the mask embeddings (dimension of the dynamic mask conv)") 78 | parser.add_argument('--controller_layers', default=3, type=int, 79 | help="Dynamic conv layer number") 80 | parser.add_argument('--dynamic_mask_channels', default=8, type=int, 81 | help="Dynamic conv final channel number") 82 | parser.add_argument('--no_rel_coord', dest='rel_coord', action='store_false', 83 | help="Disables relative coordinates") 84 | 85 | # Loss 86 | parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false', 87 | help="Disables auxiliary decoding losses (loss at each layer)") 88 | # * Matcher 89 | parser.add_argument('--set_cost_class', default=2, type=float, 90 | help="Class coefficient in the matching cost") 91 | parser.add_argument('--set_cost_bbox', default=5, type=float, 92 | help="L1 box coefficient in the matching cost") 93 | parser.add_argument('--set_cost_giou', default=2, type=float, 94 | help="giou box coefficient in the matching cost") 95 | parser.add_argument('--set_cost_mask', default=2, type=float, 96 | help="mask coefficient in the matching cost") 97 | parser.add_argument('--set_cost_dice', default=5, type=float, 98 | help="mask coefficient in the matching cost") 99 | # * Loss coefficients 100 | parser.add_argument('--mask_loss_coef', default=2, type=float) 101 | parser.add_argument('--dice_loss_coef', default=5, type=float) 102 | parser.add_argument('--cls_loss_coef', default=2, type=float) 103 | parser.add_argument('--bbox_loss_coef', default=5, type=float) 104 | parser.add_argument('--giou_loss_coef', default=2, type=float) 105 | parser.add_argument('--eos_coef', default=0.1, type=float, 106 | help="Relative classification weight of the no-object class") 107 | parser.add_argument('--focal_alpha', default=0.25, type=float) 108 | 109 | # dataset parameters 110 | # ['ytvos', 'davis', 'a2d', 'jhmdb', 'refcoco', 'refcoco+', 'refcocog', 'all'] 111 | # 'all': using the three ref datasets for pretraining 112 | 
parser.add_argument('--dataset_file', default='ytvos', help='Dataset name') 113 | parser.add_argument('--expression_file', default='meta_expressions.json', help='Annotation exp name') 114 | parser.add_argument('--actionvos_path', type=str, default='../dataset_visor') 115 | parser.add_argument('--coco_path', type=str, default='data/coco') 116 | parser.add_argument('--ytvos_path', type=str, default='data/ref-youtube-vos') 117 | parser.add_argument('--davis_path', type=str, default='data/ref-davis') 118 | parser.add_argument('--a2d_path', type=str, default='data/a2d_sentences') 119 | parser.add_argument('--jhmdb_path', type=str, default='data/jhmdb_sentences') 120 | parser.add_argument('--max_skip', default=3, type=int, help="max skip frame number") 121 | parser.add_argument('--max_size', default=512, type=int, help="max size for the frame") 122 | # changed max_size from 640 to 512 due to cuda OOM with video-swin-b backbone 123 | parser.add_argument('--binary', action='store_true') 124 | parser.add_argument('--remove_difficult', action='store_true') 125 | 126 | parser.add_argument('--output_dir', default='output', 127 | help='path where to save, empty for no saving') 128 | parser.add_argument('--device', default='cuda', 129 | help='device to use for training / testing') 130 | parser.add_argument('--seed', default=42, type=int) 131 | parser.add_argument('--resume', default='', help='resume from checkpoint') 132 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 133 | help='start epoch') 134 | parser.add_argument('--eval', action='store_true') 135 | parser.add_argument('--num_workers', default=16, type=int) 136 | 137 | # test setting 138 | parser.add_argument('--threshold', default=0.5, type=float) # binary threshold for mask 139 | parser.add_argument('--ngpu', default=8, type=int, help='gpu number when inference for ref-ytvos and ref-davis') 140 | parser.add_argument('--split', default='val', type=str, choices=['val', 'test']) 141 | 
parser.add_argument('--visualize', action='store_true', help='whether to visualize the masks during inference') 142 | 143 | # distributed training parameters 144 | parser.add_argument('--world_size', default=1, type=int, 145 | help='number of distributed processes') 146 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 147 | parser.add_argument('--cache_mode', default=False, action='store_true', help='whether to cache images in memory') 148 | return parser 149 | 150 | 151 | -------------------------------------------------------------------------------- /RF_ActionVOS/referformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | ReferFormer model class. 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | import os 10 | import math 11 | from util import box_ops 12 | from util.misc import (NestedTensor, nested_tensor_from_tensor_list, 13 | nested_tensor_from_videos_list, 14 | accuracy, get_world_size, interpolate, 15 | is_dist_avail_and_initialized, inverse_sigmoid) 16 | 17 | from .position_encoding import PositionEmbeddingSine1D 18 | from .backbone import build_backbone 19 | from .deformable_transformer import build_deforamble_transformer 20 | from .segmentation import CrossModalFPNDecoder, VisionLanguageFusionModule 21 | from .matcher import build_matcher 22 | from .criterion import SetCriterion 23 | from .postprocessors import build_postprocessors 24 | 25 | from transformers import BertTokenizer, BertModel, RobertaModel, RobertaTokenizerFast 26 | 27 | import copy 28 | from einops import rearrange, repeat 29 | 30 | def _get_clones(module, N): 31 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 32 | 33 | os.environ["TOKENIZERS_PARALLELISM"] = "false" # this disables a huggingface tokenizer warning (printed every epoch) 34 | 35 | class 
ReferFormer(nn.Module): 36 | """ This is the ReferFormer module that performs referring video object detection """ 37 | def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels, 38 | num_frames, mask_dim, dim_feedforward, 39 | controller_layers, dynamic_mask_channels, 40 | aux_loss=False, with_box_refine=False, two_stage=False, 41 | freeze_text_encoder=False, rel_coord=True, 42 | positive_cls=False): 43 | """ Initializes the model. 44 | Parameters: 45 | backbone: torch module of the backbone to be used. See backbone.py 46 | transformer: torch module of the transformer architecture. See deformable_transformer.py 47 | num_classes: number of object classes 48 | num_queries: number of object queries, i.e., detection slots. This is the maximal number of objects 49 | ReferFormer can detect in a video. For ytvos, we recommend 5 queries for each frame. 50 | num_frames: number of clip frames 51 | mask_dim: the mask feature output channel number (input channels of the dynamic conv). 52 | dim_feedforward: vision-language fusion module ffn channel number. 53 | dynamic_mask_channels: dynamic conv inter layer channel number. 54 | aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 
55 | """ 56 | super().__init__() 57 | self.num_queries = num_queries 58 | self.transformer = transformer 59 | hidden_dim = transformer.d_model 60 | self.hidden_dim = hidden_dim 61 | self.class_embed = nn.Linear(hidden_dim, num_classes) 62 | self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) 63 | self.num_feature_levels = num_feature_levels 64 | 65 | # add an extra cls head 66 | self.positive_cls = positive_cls 67 | if self.positive_cls: 68 | self.positive_cls_embed = nn.Linear(hidden_dim, 1) 69 | 70 | # Build Transformer 71 | # NOTE: different deformable detr, the query_embed out channels is 72 | # hidden_dim instead of hidden_dim * 2 73 | # This is because, the input to the decoder is text embedding feature 74 | self.query_embed = nn.Embedding(num_queries, hidden_dim) 75 | 76 | # follow deformable-detr, we use the last three stages of backbone 77 | if num_feature_levels > 1: 78 | num_backbone_outs = len(backbone.strides[-3:]) 79 | input_proj_list = [] 80 | for _ in range(num_backbone_outs): 81 | in_channels = backbone.num_channels[-3:][_] 82 | input_proj_list.append(nn.Sequential( 83 | nn.Conv2d(in_channels, hidden_dim, kernel_size=1), 84 | nn.GroupNorm(32, hidden_dim), 85 | )) 86 | for _ in range(num_feature_levels - num_backbone_outs): # downsample 2x 87 | input_proj_list.append(nn.Sequential( 88 | nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1), 89 | nn.GroupNorm(32, hidden_dim), 90 | )) 91 | in_channels = hidden_dim 92 | self.input_proj = nn.ModuleList(input_proj_list) 93 | else: 94 | self.input_proj = nn.ModuleList([ 95 | nn.Sequential( 96 | nn.Conv2d(backbone.num_channels[-3:][0], hidden_dim, kernel_size=1), 97 | nn.GroupNorm(32, hidden_dim), 98 | )]) 99 | 100 | self.num_frames = num_frames 101 | self.mask_dim = mask_dim 102 | self.backbone = backbone 103 | self.aux_loss = aux_loss 104 | self.with_box_refine = with_box_refine 105 | assert two_stage == False, "args.two_stage must be false!" 
106 | 107 | # initialization 108 | prior_prob = 0.01 109 | bias_value = -math.log((1 - prior_prob) / prior_prob) 110 | self.class_embed.bias.data = torch.ones(num_classes) * bias_value 111 | nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) 112 | nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) 113 | for proj in self.input_proj: 114 | nn.init.xavier_uniform_(proj[0].weight, gain=1) 115 | nn.init.constant_(proj[0].bias, 0) 116 | 117 | num_pred = transformer.decoder.num_layers 118 | if with_box_refine: 119 | self.class_embed = _get_clones(self.class_embed, num_pred) 120 | self.bbox_embed = _get_clones(self.bbox_embed, num_pred) 121 | nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0) 122 | # hack implementation for iterative bounding box refinement 123 | self.transformer.decoder.bbox_embed = self.bbox_embed 124 | else: 125 | nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0) 126 | self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)]) 127 | self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)]) 128 | self.transformer.decoder.bbox_embed = None 129 | 130 | if self.positive_cls: 131 | self.positive_cls_embed.bias.data = torch.ones(1) * bias_value 132 | self.positive_cls_embed = _get_clones(self.positive_cls_embed, num_pred) 133 | # Build Text Encoder 134 | # self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased') 135 | # self.text_encoder = BertModel.from_pretrained('bert-base-cased') 136 | self.tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base') 137 | self.text_encoder = RobertaModel.from_pretrained('roberta-base') 138 | 139 | if freeze_text_encoder: 140 | for p in self.text_encoder.parameters(): 141 | p.requires_grad_(False) 142 | 143 | # resize the bert output channel to transformer d_model 144 | self.resizer = FeatureResizer( 145 | input_feat_size=768, 146 | output_feat_size=hidden_dim, 147 | dropout=0.1, 148 | ) 149 | 150 | self.fusion_module = 
VisionLanguageFusionModule(d_model=hidden_dim, nhead=8) 151 | self.text_pos = PositionEmbeddingSine1D(hidden_dim, normalize=True) 152 | 153 | # Build FPN Decoder 154 | self.rel_coord = rel_coord 155 | feature_channels = [self.backbone.num_channels[0]] + 3 * [hidden_dim] 156 | self.pixel_decoder = CrossModalFPNDecoder(feature_channels=feature_channels, conv_dim=hidden_dim, 157 | mask_dim=mask_dim, dim_feedforward=dim_feedforward, norm="GN") 158 | 159 | # Build Dynamic Conv 160 | self.controller_layers = controller_layers 161 | self.in_channels = mask_dim 162 | self.dynamic_mask_channels = dynamic_mask_channels 163 | self.mask_out_stride = 4 164 | self.mask_feat_stride = 4 165 | 166 | weight_nums, bias_nums = [], [] 167 | for l in range(self.controller_layers): 168 | if l == 0: 169 | if self.rel_coord: 170 | weight_nums.append((self.in_channels + 2) * self.dynamic_mask_channels) 171 | else: 172 | weight_nums.append(self.in_channels * self.dynamic_mask_channels) 173 | bias_nums.append(self.dynamic_mask_channels) 174 | elif l == self.controller_layers - 1: 175 | weight_nums.append(self.dynamic_mask_channels * 1) # output layer c -> 1 176 | bias_nums.append(1) 177 | else: 178 | weight_nums.append(self.dynamic_mask_channels * self.dynamic_mask_channels) 179 | bias_nums.append(self.dynamic_mask_channels) 180 | 181 | self.weight_nums = weight_nums 182 | self.bias_nums = bias_nums 183 | self.num_gen_params = sum(weight_nums) + sum(bias_nums) 184 | 185 | self.controller = MLP(hidden_dim, hidden_dim, self.num_gen_params, 3) 186 | for layer in self.controller.layers: 187 | nn.init.zeros_(layer.bias) 188 | nn.init.xavier_uniform_(layer.weight) 189 | 190 | 191 | def forward(self, samples: NestedTensor, captions, targets): 192 | """ The forward expects a NestedTensor, which consists of: 193 | - samples.tensors: image sequences, of shape [num_frames x 3 x H x W] 194 | - samples.mask: a binary mask of shape [num_frames x H x W], containing 1 on padded pixels 195 | - captions: 
list[str] 196 | - targets: list[dict] 197 | 198 | It returns a dict with the following elements: 199 | - "pred_masks": Shape = [batch_size x time x num_queries_per_frame x out_h x out_w] 200 | 201 | - "pred_logits": the classification logits (including no-object) for all queries. 202 | Shape = [batch_size x time x num_queries_per_frame x num_classes] 203 | - "pred_boxes": The normalized box coordinates for all queries, represented as 204 | (center_x, center_y, width, height). These values are normalized in [0, 1], 205 | relative to the size of each individual image (disregarding possible padding). 206 | See PostProcess for information on how to retrieve the unnormalized bounding box. 207 | - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of 208 | dictionaries containing the above keys for each decoder layer. 209 | """ 210 | # Backbone 211 | if not isinstance(samples, NestedTensor): 212 | samples = nested_tensor_from_videos_list(samples) 213 | 214 | # features (list[NestedTensor]): res2 -> res5, shape of tensors is [B*T, Ci, Hi, Wi] 215 | # pos (list[Tensor]): shape of [B*T, C, Hi, Wi] 216 | features, pos = self.backbone(samples) 217 | 218 | b = len(captions) 219 | t = pos[0].shape[0] // b 220 | #print(captions) 221 | 222 | # For the A2D-Sentences and JHMDB-Sentences datasets, only one frame is annotated for a clip 223 | if 'valid_indices' in targets[0]: 224 | valid_indices = torch.tensor([i * t + target['valid_indices'] for i, target in enumerate(targets)]).to(pos[0].device) 225 | for feature in features: 226 | feature.tensors = feature.tensors.index_select(0, valid_indices) 227 | feature.mask = feature.mask.index_select(0, valid_indices) 228 | for i, p in enumerate(pos): 229 | pos[i] = p.index_select(0, valid_indices) 230 | samples.mask = samples.mask.index_select(0, valid_indices) 231 | # t: num_frames -> 1 232 | t = 1 233 | 234 | text_features, text_sentence_features = self.forward_text(captions, device=pos[0].device) 235 | 236 | # prepare vision and text 
features for transformer 237 | srcs = [] 238 | masks = [] 239 | poses = [] 240 | 241 | text_pos = self.text_pos(text_features).permute(2, 0, 1) # [length, batch_size, c] 242 | text_word_features, text_word_masks = text_features.decompose() 243 | text_word_features = text_word_features.permute(1, 0, 2) # [length, batch_size, c] 244 | 245 | # Follow Deformable-DETR, we use the last three stages outputs from backbone 246 | for l, (feat, pos_l) in enumerate(zip(features[-3:], pos[-3:])): 247 | src, mask = feat.decompose() 248 | src_proj_l = self.input_proj[l](src) 249 | n, c, h, w = src_proj_l.shape 250 | 251 | # vision language early-fusion 252 | src_proj_l = rearrange(src_proj_l, '(b t) c h w -> (t h w) b c', b=b, t=t) 253 | src_proj_l = self.fusion_module(tgt=src_proj_l, 254 | memory=text_word_features, 255 | memory_key_padding_mask=text_word_masks, 256 | pos=text_pos, 257 | query_pos=None 258 | ) 259 | src_proj_l = rearrange(src_proj_l, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w) 260 | 261 | srcs.append(src_proj_l) 262 | masks.append(mask) 263 | poses.append(pos_l) 264 | assert mask is not None 265 | 266 | if self.num_feature_levels > (len(features) - 1): 267 | _len_srcs = len(features) - 1 # fpn level 268 | for l in range(_len_srcs, self.num_feature_levels): 269 | if l == _len_srcs: 270 | src = self.input_proj[l](features[-1].tensors) 271 | else: 272 | src = self.input_proj[l](srcs[-1]) 273 | m = samples.mask 274 | mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] 275 | pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) 276 | n, c, h, w = src.shape 277 | 278 | # vision language early-fusion 279 | src = rearrange(src, '(b t) c h w -> (t h w) b c', b=b, t=t) 280 | src = self.fusion_module(tgt=src, 281 | memory=text_word_features, 282 | memory_key_padding_mask=text_word_masks, 283 | pos=text_pos, 284 | query_pos=None 285 | ) 286 | src = rearrange(src, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w) 287 | 288 | srcs.append(src) 
289 | masks.append(mask) 290 | poses.append(pos_l) 291 | 292 | # Transformer 293 | query_embeds = self.query_embed.weight # [num_queries, c] 294 | text_embed = repeat(text_sentence_features, 'b c -> b t q c', t=t, q=self.num_queries) 295 | hs, memory, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact, inter_samples = \ 296 | self.transformer(srcs, text_embed, masks, poses, query_embeds) 297 | # hs: [l, batch_size*time, num_queries_per_frame, c] 298 | # memory: list[Tensor], shape of tensor is [batch_size*time, c, hi, wi] 299 | # init_reference: [batch_size*time, num_queries_per_frame, 2] 300 | # inter_references: [l, batch_size*time, num_queries_per_frame, 4] 301 | 302 | out = {} 303 | # prediction 304 | outputs_classes = [] 305 | outputs_coords = [] 306 | for lvl in range(hs.shape[0]): 307 | if lvl == 0: 308 | reference = init_reference 309 | else: 310 | reference = inter_references[lvl - 1] 311 | reference = inverse_sigmoid(reference) 312 | outputs_class = self.class_embed[lvl](hs[lvl]) 313 | tmp = self.bbox_embed[lvl](hs[lvl]) 314 | if reference.shape[-1] == 4: 315 | tmp += reference 316 | else: 317 | assert reference.shape[-1] == 2 318 | tmp[..., :2] += reference 319 | outputs_coord = tmp.sigmoid() # cxcywh, range in [0,1] 320 | outputs_classes.append(outputs_class) 321 | outputs_coords.append(outputs_coord) 322 | outputs_class = torch.stack(outputs_classes) 323 | outputs_coord = torch.stack(outputs_coords) 324 | # rearrange 325 | outputs_class = rearrange(outputs_class, 'l (b t) q k -> l b t q k', b=b, t=t) 326 | outputs_coord = rearrange(outputs_coord, 'l (b t) q n -> l b t q n', b=b, t=t) 327 | out['pred_logits'] = outputs_class[-1] # [batch_size, time, num_queries_per_frame, num_classes] 328 | out['pred_boxes'] = outputs_coord[-1] # [batch_size, time, num_queries_per_frame, 4] 329 | 330 | # Segmentation 331 | mask_features = self.pixel_decoder(features, text_features, pos, memory, nf=t) # [batch_size*time, c, out_h, out_w] 332 | 
mask_features = rearrange(mask_features, '(b t) c h w -> b t c h w', b=b, t=t) 333 | 334 | # extra cls head 335 | if self.positive_cls: 336 | outputs_positives = [] 337 | for lvl in range(hs.shape[0]): 338 | outputs_positive = rearrange(hs[lvl], '(b t) q c -> b t q c', b=b, t=t) 339 | outputs_positive = torch.mean(outputs_positive,dim=(1,2)) 340 | outputs_positive = self.positive_cls_embed[lvl](outputs_positive) 341 | outputs_positives.append(outputs_positive) 342 | outputs_positives = torch.stack(outputs_positives) #l b 1 343 | #print(outputs_positives.shape) 344 | out['pred_positives'] = outputs_positives[-1] 345 | 346 | # dynamic conv 347 | outputs_seg_masks = [] 348 | for lvl in range(hs.shape[0]): 349 | dynamic_mask_head_params = self.controller(hs[lvl]) # [batch_size*time, num_queries_per_frame, num_params] 350 | dynamic_mask_head_params = rearrange(dynamic_mask_head_params, '(b t) q n -> b (t q) n', b=b, t=t) 351 | lvl_references = inter_references[lvl, ..., :2] 352 | lvl_references = rearrange(lvl_references, '(b t) q n -> b (t q) n', b=b, t=t) 353 | outputs_seg_mask = self.dynamic_mask_with_coords(mask_features, dynamic_mask_head_params, lvl_references, targets) 354 | outputs_seg_mask = rearrange(outputs_seg_mask, 'b (t q) h w -> b t q h w', t=t) 355 | outputs_seg_masks.append(outputs_seg_mask) 356 | out['pred_masks'] = outputs_seg_masks[-1] # [batch_size, time, num_queries_per_frame, out_h, out_w] 357 | 358 | if self.aux_loss: 359 | if self.positive_cls: 360 | out['aux_outputs'] = self._set_aux_loss_pos(outputs_class, outputs_coord, outputs_seg_masks, outputs_positives) 361 | else: 362 | out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord, outputs_seg_masks) 363 | 364 | if not self.training: 365 | # for visualization 366 | inter_references = inter_references[-2, :, :, :2] # [batch_size*time, num_queries_per_frame, 2] 367 | inter_references = rearrange(inter_references, '(b t) q n -> b t q n', b=b, t=t) 368 | out['reference_points'] = 
inter_references # the reference points of last layer input 369 | return out 370 | 371 | @torch.jit.unused 372 | def _set_aux_loss(self, outputs_class, outputs_coord, outputs_seg_masks): 373 | # this is a workaround to make torchscript happy, as torchscript 374 | # doesn't support dictionary with non-homogeneous values, such 375 | # as a dict having both a Tensor and a list. 376 | return [{"pred_logits": a, "pred_boxes": b, "pred_masks": c} 377 | for a, b, c in zip(outputs_class[:-1], outputs_coord[:-1], outputs_seg_masks[:-1])] 378 | 379 | @torch.jit.unused 380 | def _set_aux_loss_pos(self, outputs_class, outputs_coord, outputs_seg_masks, outputs_positives): 381 | # this is a workaround to make torchscript happy, as torchscript 382 | # doesn't support dictionary with non-homogeneous values, such 383 | # as a dict having both a Tensor and a list. 384 | return [{"pred_logits": a, "pred_boxes": b, "pred_masks": c, "pred_positives": d} 385 | for a, b, c, d in zip(outputs_class[:-1], outputs_coord[:-1], outputs_seg_masks[:-1], outputs_positives[:-1])] 386 | 387 | def forward_text(self, captions, device): 388 | if isinstance(captions[0], str): 389 | tokenized = self.tokenizer.batch_encode_plus(captions, padding="longest", return_tensors="pt").to(device) 390 | encoded_text = self.text_encoder(**tokenized) 391 | # encoded_text.last_hidden_state: [batch_size, length, 768] 392 | # encoded_text.pooler_output: [batch_size, 768] 393 | text_attention_mask = tokenized.attention_mask.ne(1).bool() 394 | # text_attention_mask: [batch_size, length] 395 | 396 | text_features = encoded_text.last_hidden_state 397 | text_features = self.resizer(text_features) 398 | text_masks = text_attention_mask 399 | text_features = NestedTensor(text_features, text_masks) # NestedTensor 400 | 401 | text_sentence_features = encoded_text.pooler_output 402 | text_sentence_features = self.resizer(text_sentence_features) 403 | else: 404 | raise ValueError("Please make sure the caption is a list of 
string") 405 | return text_features, text_sentence_features 406 | 407 | def dynamic_mask_with_coords(self, mask_features, mask_head_params, reference_points, targets): 408 | """ 409 | Add the relative coordinates to the mask_features channel dimension, 410 | and perform dynamic mask conv. 411 | 412 | Args: 413 | mask_features: [batch_size, time, c, h, w] 414 | mask_head_params: [batch_size, time * num_queries_per_frame, num_params] 415 | reference_points: [batch_size, time * num_queries_per_frame, 2], cxcy 416 | targets (list[dict]): length is batch size 417 | we need the key 'size' for computing location. 418 | Return: 419 | outputs_seg_mask: [batch_size, time * num_queries_per_frame, h, w] 420 | """ 421 | device = mask_features.device 422 | b, t, c, h, w = mask_features.shape 423 | # this is the total query number in all frames 424 | _, num_queries = reference_points.shape[:2] 425 | q = num_queries // t # num_queries_per_frame 426 | 427 | # prepare reference points in image size (the size is input size to the model) 428 | new_reference_points = [] 429 | for i in range(b): 430 | img_h, img_w = targets[i]['size'] 431 | scale_f = torch.stack([img_w, img_h], dim=0) 432 | tmp_reference_points = reference_points[i] * scale_f[None, :] 433 | new_reference_points.append(tmp_reference_points) 434 | new_reference_points = torch.stack(new_reference_points, dim=0) 435 | # [batch_size, time * num_queries_per_frame, 2], in image size 436 | reference_points = new_reference_points 437 | 438 | # prepare the mask features 439 | if self.rel_coord: 440 | reference_points = rearrange(reference_points, 'b (t q) n -> b t q n', t=t, q=q) 441 | locations = compute_locations(h, w, device=device, stride=self.mask_feat_stride) 442 | relative_coords = reference_points.reshape(b, t, q, 1, 1, 2) - \ 443 | locations.reshape(1, 1, 1, h, w, 2) # [batch_size, time, num_queries_per_frame, h, w, 2] 444 | relative_coords = relative_coords.permute(0, 1, 2, 5, 3, 4) # [batch_size, time, 
num_queries_per_frame, 2, h, w] 445 | 446 | # concat features 447 | mask_features = repeat(mask_features, 'b t c h w -> b t q c h w', q=q) # [batch_size, time, num_queries_per_frame, c, h, w] 448 | mask_features = torch.cat([mask_features, relative_coords], dim=3) 449 | else: 450 | mask_features = repeat(mask_features, 'b t c h w -> b t q c h w', q=q) # [batch_size, time, num_queries_per_frame, c, h, w] 451 | mask_features = mask_features.reshape(1, -1, h, w) 452 | 453 | # parse dynamic params 454 | mask_head_params = mask_head_params.flatten(0, 1) 455 | weights, biases = parse_dynamic_params( 456 | mask_head_params, self.dynamic_mask_channels, 457 | self.weight_nums, self.bias_nums 458 | ) 459 | 460 | # dynamic mask conv 461 | mask_logits = self.mask_heads_forward(mask_features, weights, biases, mask_head_params.shape[0]) 462 | mask_logits = mask_logits.reshape(-1, 1, h, w) 463 | 464 | # upsample predicted masks 465 | assert self.mask_feat_stride >= self.mask_out_stride 466 | assert self.mask_feat_stride % self.mask_out_stride == 0 467 | 468 | mask_logits = aligned_bilinear(mask_logits, int(self.mask_feat_stride / self.mask_out_stride)) 469 | mask_logits = mask_logits.reshape(b, num_queries, mask_logits.shape[-2], mask_logits.shape[-1]) 470 | 471 | return mask_logits # [batch_size, time * num_queries_per_frame, h, w] 472 | 473 | def mask_heads_forward(self, features, weights, biases, num_insts): 474 | ''' 475 | :param features: 476 | :param weights: [w0, w1, ...] 477 | :param biases: [b0, b1, ...] 
478 | :return: 479 | ''' 480 | assert features.dim() == 4 481 | n_layers = len(weights) 482 | x = features 483 | for i, (w, b) in enumerate(zip(weights, biases)): 484 | x = F.conv2d( 485 | x, w, bias=b, 486 | stride=1, padding=0, 487 | groups=num_insts 488 | ) 489 | if i < n_layers - 1: 490 | x = F.relu(x) 491 | return x 492 | 493 | 494 | def parse_dynamic_params(params, channels, weight_nums, bias_nums): 495 | assert params.dim() == 2 496 | assert len(weight_nums) == len(bias_nums) 497 | assert params.size(1) == sum(weight_nums) + sum(bias_nums) 498 | 499 | num_insts = params.size(0) 500 | num_layers = len(weight_nums) 501 | 502 | params_splits = list(torch.split_with_sizes(params, weight_nums + bias_nums, dim=1)) 503 | 504 | weight_splits = params_splits[:num_layers] 505 | bias_splits = params_splits[num_layers:] 506 | 507 | for l in range(num_layers): 508 | if l < num_layers - 1: 509 | # out_channels x in_channels x 1 x 1 510 | weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1) 511 | bias_splits[l] = bias_splits[l].reshape(num_insts * channels) 512 | else: 513 | # out_channels x in_channels x 1 x 1 514 | weight_splits[l] = weight_splits[l].reshape(num_insts * 1, -1, 1, 1) 515 | bias_splits[l] = bias_splits[l].reshape(num_insts) 516 | 517 | return weight_splits, bias_splits 518 | 519 | def aligned_bilinear(tensor, factor): 520 | assert tensor.dim() == 4 521 | assert factor >= 1 522 | assert int(factor) == factor 523 | 524 | if factor == 1: 525 | return tensor 526 | 527 | h, w = tensor.size()[2:] 528 | tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode="replicate") 529 | oh = factor * h + 1 530 | ow = factor * w + 1 531 | tensor = F.interpolate( 532 | tensor, size=(oh, ow), 533 | mode='bilinear', 534 | align_corners=True 535 | ) 536 | tensor = F.pad( 537 | tensor, pad=(factor // 2, 0, factor // 2, 0), 538 | mode="replicate" 539 | ) 540 | 541 | return tensor[:, :, :oh - 1, :ow - 1] 542 | 543 | 544 | def compute_locations(h, w, device, stride=1): 
545 | shifts_x = torch.arange( 546 | 0, w * stride, step=stride, 547 | dtype=torch.float32, device=device) 548 | 549 | shifts_y = torch.arange( 550 | 0, h * stride, step=stride, 551 | dtype=torch.float32, device=device) 552 | 553 | shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) 554 | shift_x = shift_x.reshape(-1) 555 | shift_y = shift_y.reshape(-1) 556 | locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2 557 | return locations 558 | 559 | 560 | 561 | class MLP(nn.Module): 562 | """ Very simple multi-layer perceptron (also called FFN)""" 563 | 564 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 565 | super().__init__() 566 | self.num_layers = num_layers 567 | h = [hidden_dim] * (num_layers - 1) 568 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) 569 | 570 | def forward(self, x): 571 | for i, layer in enumerate(self.layers): 572 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 573 | return x 574 | 575 | class FeatureResizer(nn.Module): 576 | """ 577 | This class takes as input a set of embeddings of dimension C1 and outputs a set of 578 | embedding of dimension C2, after a linear transformation, dropout and normalization (LN). 
579 | """ 580 | 581 | def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True): 582 | super().__init__() 583 | self.do_ln = do_ln 584 | # Object feature encoding 585 | self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True) 586 | self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12) 587 | self.dropout = nn.Dropout(dropout) 588 | 589 | def forward(self, encoder_features): 590 | x = self.fc(encoder_features) 591 | if self.do_ln: 592 | x = self.layer_norm(x) 593 | output = self.dropout(x) 594 | return output 595 | 596 | 597 | def build(args): 598 | if args.binary: 599 | num_classes = 1 600 | else: 601 | if args.dataset_file == 'ytvos': 602 | num_classes = 65 603 | elif args.dataset_file == 'davis': 604 | num_classes = 78 605 | elif args.dataset_file == 'a2d' or args.dataset_file == 'jhmdb': 606 | num_classes = 1 607 | else: 608 | num_classes = 91 # for coco 609 | device = torch.device(args.device) 610 | 611 | # backbone 612 | if 'video_swin' in args.backbone: 613 | from .video_swin_transformer import build_video_swin_backbone 614 | backbone = build_video_swin_backbone(args) 615 | elif 'swin' in args.backbone: 616 | from .swin_transformer import build_swin_backbone 617 | backbone = build_swin_backbone(args) 618 | else: 619 | backbone = build_backbone(args) 620 | 621 | transformer = build_deforamble_transformer(args) 622 | 623 | model = ReferFormer( 624 | backbone, 625 | transformer, 626 | num_classes=num_classes, 627 | num_queries=args.num_queries, 628 | num_feature_levels=args.num_feature_levels, 629 | num_frames=args.num_frames, 630 | mask_dim=args.mask_dim, 631 | dim_feedforward=args.dim_feedforward, 632 | controller_layers=args.controller_layers, 633 | dynamic_mask_channels=args.dynamic_mask_channels, 634 | aux_loss=args.aux_loss, 635 | with_box_refine=args.with_box_refine, 636 | two_stage=args.two_stage, 637 | freeze_text_encoder=args.freeze_text_encoder, 638 | rel_coord=args.rel_coord, 639 | positive_cls=args.use_positive_cls 640 
| ) 641 | matcher = build_matcher(args) 642 | weight_dict = {} 643 | weight_dict['loss_ce'] = args.cls_loss_coef 644 | weight_dict['loss_bbox'] = args.bbox_loss_coef 645 | weight_dict['loss_giou'] = args.giou_loss_coef 646 | if args.masks: # always true 647 | weight_dict['loss_mask'] = args.mask_loss_coef 648 | weight_dict['loss_dice'] = args.dice_loss_coef 649 | if args.use_positive_cls: # an extra cls head for positive classification 650 | weight_dict['loss_positive_labels'] = args.cls_loss_coef 651 | # TODO this is a hack 652 | if args.aux_loss: 653 | aux_weight_dict = {} 654 | for i in range(args.dec_layers - 1): 655 | aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) 656 | weight_dict.update(aux_weight_dict) 657 | 658 | losses = ['labels', 'boxes'] 659 | if args.use_positive_cls: 660 | losses += ['positive_labels'] 661 | if args.masks: 662 | if args.use_weights: 663 | losses += ['weighted_masks'] 664 | else: 665 | losses += ['masks'] 666 | criterion = SetCriterion( 667 | num_classes, 668 | matcher=matcher, 669 | weight_dict=weight_dict, 670 | eos_coef=args.eos_coef, 671 | losses=losses, 672 | focal_alpha=args.focal_alpha) 673 | criterion.to(device) 674 | 675 | # postprocessors, this is used for coco pretrain but not for rvos 676 | postprocessors = build_postprocessors(args, args.dataset_file) 677 | return model, criterion, postprocessors 678 | 679 | 680 | 681 | -------------------------------------------------------------------------------- /RF_ActionVOS/segmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Segmentation Part 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | from collections import defaultdict 6 | from typing import List, Optional 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch import Tensor 12 | from PIL import Image 13 | 14 | from einops import rearrange, repeat 15 | 16 | try: 17 | from 
panopticapi.utils import id2rgb, rgb2id 18 | except ImportError: 19 | pass 20 | 21 | import fvcore.nn.weight_init as weight_init 22 | 23 | from .position_encoding import PositionEmbeddingSine1D 24 | 25 | BN_MOMENTUM = 0.1 26 | 27 | def get_norm(norm, out_channels): # only supports GN or LN 28 | """ 29 | Args: 30 | norm (str or callable): either one of GN, LN; 31 | or a callable that takes a channel number and returns 32 | the normalization layer as a nn.Module. 33 | 34 | Returns: 35 | nn.Module or None: the normalization layer 36 | """ 37 | if norm is None: 38 | return None 39 | if isinstance(norm, str): 40 | if len(norm) == 0: 41 | return None 42 | norm = { 43 | "GN": lambda channels: nn.GroupNorm(8, channels), 44 | "LN": lambda channels: nn.LayerNorm(channels) 45 | }[norm] 46 | return norm(out_channels) 47 | 48 | class Conv2d(torch.nn.Conv2d): 49 | """ 50 | A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features. 51 | """ 52 | 53 | def __init__(self, *args, **kwargs): 54 | """ 55 | Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`: 56 | 57 | Args: 58 | norm (nn.Module, optional): a normalization layer 59 | activation (callable(Tensor) -> Tensor): a callable activation function 60 | 61 | It assumes that norm layer is used before activation. 62 | """ 63 | norm = kwargs.pop("norm", None) 64 | activation = kwargs.pop("activation", None) 65 | super().__init__(*args, **kwargs) 66 | 67 | self.norm = norm 68 | self.activation = activation 69 | 70 | def forward(self, x): 71 | # torchscript does not support SyncBatchNorm yet 72 | # https://github.com/pytorch/pytorch/issues/40507 73 | # and we skip these codes in torchscript since: 74 | # 1. currently we only support torchscript in evaluation mode 75 | # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or 76 | # later version, `Conv2d` in these PyTorch versions has already supported empty inputs. 
77 | if not torch.jit.is_scripting(): 78 | if x.numel() == 0 and self.training: 79 | # https://github.com/pytorch/pytorch/issues/12013 80 | assert not isinstance( 81 | self.norm, torch.nn.SyncBatchNorm 82 | ), "SyncBatchNorm does not support empty inputs!" 83 | 84 | x = F.conv2d( 85 | x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups 86 | ) 87 | if self.norm is not None: 88 | x = self.norm(x) 89 | if self.activation is not None: 90 | x = self.activation(x) 91 | return x 92 | 93 | # FPN structure 94 | class CrossModalFPNDecoder(nn.Module): 95 | def __init__(self, feature_channels: List, conv_dim: int, mask_dim: int, dim_feedforward: int = 2048, norm=None): 96 | """ 97 | Args: 98 | feature_channels: list of fpn feature channel numbers. 99 | conv_dim: number of output channels for the intermediate conv layers. 100 | mask_dim: number of output channels for the final conv layer. 101 | dim_feedforward: number of vision-language fusion module ffn channel numbers. 
102 | norm (str or callable): normalization for all conv layers 103 | """ 104 | super().__init__() 105 | 106 | self.feature_channels = feature_channels 107 | 108 | lateral_convs = [] 109 | output_convs = [] 110 | 111 | use_bias = norm == "" 112 | for idx, in_channels in enumerate(feature_channels): 113 | # in_channels: 4x -> 32x 114 | lateral_norm = get_norm(norm, conv_dim) 115 | output_norm = get_norm(norm, conv_dim) 116 | 117 | lateral_conv = Conv2d( 118 | in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm 119 | ) 120 | output_conv = Conv2d( 121 | conv_dim, 122 | conv_dim, 123 | kernel_size=3, 124 | stride=1, 125 | padding=1, 126 | bias=use_bias, 127 | norm=output_norm, 128 | activation=F.relu, 129 | ) 130 | weight_init.c2_xavier_fill(lateral_conv) 131 | weight_init.c2_xavier_fill(output_conv) 132 | stage = idx+1 133 | self.add_module("adapter_{}".format(stage), lateral_conv) 134 | self.add_module("layer_{}".format(stage), output_conv) 135 | 136 | lateral_convs.append(lateral_conv) 137 | output_convs.append(output_conv) 138 | 139 | # Place convs into top-down order (from low to high resolution) 140 | # to make the top-down computation in forward clearer. 
141 | self.lateral_convs = lateral_convs[::-1] 142 | self.output_convs = output_convs[::-1] 143 | 144 | self.mask_dim = mask_dim 145 | self.mask_features = Conv2d( 146 | conv_dim, 147 | mask_dim, 148 | kernel_size=3, 149 | stride=1, 150 | padding=1, 151 | ) 152 | weight_init.c2_xavier_fill(self.mask_features) 153 | 154 | # vision-language cross-modal fusion 155 | self.text_pos = PositionEmbeddingSine1D(conv_dim, normalize=True) 156 | sr_ratios = [8, 4, 2, 1] 157 | cross_attns = [] 158 | for idx in range(len(feature_channels)): # res2 -> res5 159 | cross_attn = VisionLanguageBlock(conv_dim, dim_feedforward=dim_feedforward, 160 | nhead=8, sr_ratio=sr_ratios[idx]) 161 | for p in cross_attn.parameters(): 162 | if p.dim() > 1: 163 | nn.init.xavier_uniform_(p) 164 | stage = int(idx + 1) 165 | self.add_module("cross_attn_{}".format(stage), cross_attn) 166 | cross_attns.append(cross_attn) 167 | # place cross-attn in top-down order (from low to high resolution) 168 | self.cross_attns = cross_attns[::-1] 169 | 170 | 171 | def forward_features(self, features, text_features, poses, memory, nf): 172 | # nf: num_frames 173 | text_pos = self.text_pos(text_features).permute(2, 0, 1) # [length, batch_size, c] 174 | text_features, text_masks = text_features.decompose() 175 | text_features = text_features.permute(1, 0, 2) 176 | 177 | for idx, (mem, f, pos) in enumerate(zip(memory[::-1], features[1:][::-1], poses[1:][::-1])): # 32x -> 8x 178 | lateral_conv = self.lateral_convs[idx] 179 | output_conv = self.output_convs[idx] 180 | cross_attn = self.cross_attns[idx] 181 | 182 | _, x_mask = f.decompose() 183 | n, c, h, w = pos.shape 184 | b = n // nf 185 | t = nf 186 | 187 | # NOTE: here the (h, w) is the size for current fpn layer 188 | vision_features = lateral_conv(mem) # [b*t, c, h, w] 189 | vision_features = rearrange(vision_features, '(b t) c h w -> (t h w) b c', b=b, t=t) 190 | vision_pos = rearrange(pos, '(b t) c h w -> (t h w) b c', b=b, t=t) 191 | vision_masks = 
rearrange(x_mask, '(b t) h w -> b (t h w)', b=b, t=t) 192 | 193 | cur_fpn = cross_attn(tgt=vision_features, 194 | memory=text_features, 195 | t=t, h=h, w=w, 196 | tgt_key_padding_mask=vision_masks, 197 | memory_key_padding_mask=text_masks, 198 | pos=text_pos, 199 | query_pos=vision_pos 200 | ) # [t*h*w, b, c] 201 | cur_fpn = rearrange(cur_fpn, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w) 202 | 203 | # upsample 204 | if idx == 0: # top layer 205 | y = output_conv(cur_fpn) 206 | else: 207 | # Following FPN implementation, we use nearest upsampling here 208 | y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest") 209 | y = output_conv(y) 210 | 211 | # 4x level 212 | lateral_conv = self.lateral_convs[-1] 213 | output_conv = self.output_convs[-1] 214 | cross_attn = self.cross_attns[-1] 215 | 216 | x, x_mask = features[0].decompose() 217 | pos = poses[0] 218 | n, c, h, w = pos.shape 219 | b = n // nf 220 | t = nf 221 | 222 | vision_features = lateral_conv(x) # [b*t, c, h, w] 223 | vision_features = rearrange(vision_features, '(b t) c h w -> (t h w) b c', b=b, t=t) 224 | vision_pos = rearrange(pos, '(b t) c h w -> (t h w) b c', b=b, t=t) 225 | vision_masks = rearrange(x_mask, '(b t) h w -> b (t h w)', b=b, t=t) 226 | 227 | cur_fpn = cross_attn(tgt=vision_features, 228 | memory=text_features, 229 | t=t, h=h, w=w, 230 | tgt_key_padding_mask=vision_masks, 231 | memory_key_padding_mask=text_masks, 232 | pos=text_pos, 233 | query_pos=vision_pos 234 | ) # [t*h*w, b, c] 235 | cur_fpn = rearrange(cur_fpn, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w) 236 | # Following FPN implementation, we use nearest upsampling here 237 | y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest") 238 | y = output_conv(y) 239 | return y # [b*t, c, h, w], the spatial stride is 4x 240 | 241 | def forward(self, features, text_features, pos, memory, nf): 242 | """The forward function receives the vision and language features, 243 | and outputs the mask features with the 
spatial stride of 4x. 244 | 245 | Args: 246 | features (list[NestedTensor]): backbone features (vision), length is number of FPN layers 247 | tensors: [b*t, ci, hi, wi], mask: [b*t, hi, wi] 248 | text_features (NestedTensor): text features (language) 249 | tensors: [b, length, c], mask: [b, length] 250 | pos (list[Tensor]): position encoding of vision features, length is number of FPN layers 251 | tensors: [b*t, c, hi, wi] 252 | memory (list[Tensor]): features from encoder output. from 8x -> 32x 253 | NOTE: the layer orders of both features and pos are res2 -> res5 254 | 255 | Returns: 256 | mask_features (Tensor): [b*t, mask_dim, h, w], with the spatial stride of 4x. 257 | """ 258 | y = self.forward_features(features, text_features, pos, memory, nf) 259 | return self.mask_features(y) 260 | 261 | 262 | class VisionLanguageBlock(nn.Module): 263 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 264 | activation="relu", normalize_before=False, sr_ratio=1): 265 | super().__init__() 266 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 267 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 268 | # Implementation of Feedforward model 269 | self.linear1 = nn.Linear(d_model, dim_feedforward) 270 | self.dropout = nn.Dropout(dropout) 271 | self.linear2 = nn.Linear(dim_feedforward, d_model) 272 | 273 | self.norm1 = nn.LayerNorm(d_model) 274 | self.norm2 = nn.LayerNorm(d_model) 275 | self.norm3 = nn.LayerNorm(d_model) 276 | self.dropout1 = nn.Dropout(dropout) 277 | self.dropout2 = nn.Dropout(dropout) 278 | self.dropout3 = nn.Dropout(dropout) 279 | 280 | self.activation = _get_activation_fn(activation) 281 | self.normalize_before = normalize_before 282 | 283 | # for downsample 284 | self.sr_ratio = sr_ratio 285 | 286 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 287 | return tensor if pos is None else tensor + pos 288 | 289 | def forward_post(self, tgt, memory, t, h, w, 290 | 
tgt_key_padding_mask: Optional[Tensor] = None, 291 | memory_key_padding_mask: Optional[Tensor] = None, 292 | pos: Optional[Tensor] = None, 293 | query_pos: Optional[Tensor] = None): 294 | b = tgt.size(1) 295 | # self attn 296 | q = k = self.with_pos_embed(tgt, query_pos) 297 | if self.sr_ratio > 1: # downsample 298 | q = rearrange(q, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w) 299 | k = rearrange(k, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w) 300 | v = rearrange(tgt, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w) 301 | # downsample 302 | new_h = int(h * 1./self.sr_ratio) 303 | new_w = int(w * 1./self.sr_ratio) 304 | size = (new_h, new_w) 305 | q = F.interpolate(q, size=size, mode='nearest') 306 | k = F.interpolate(k, size=size, mode='nearest') 307 | v = F.interpolate(v, size=size, mode='nearest') 308 | # shape for transformer 309 | q = rearrange(q, '(b t) c h w -> (t h w) b c', t=t) 310 | k = rearrange(k, '(b t) c h w -> (t h w) b c', t=t) 311 | v = rearrange(v, '(b t) c h w -> (t h w) b c', t=t) 312 | # downsample mask 313 | tgt_key_padding_mask = tgt_key_padding_mask.reshape(b*t, h, w) 314 | tgt_key_padding_mask = F.interpolate(tgt_key_padding_mask[None].float(), size=(new_h, new_w), mode='nearest').bool()[0] 315 | tgt_key_padding_mask = tgt_key_padding_mask.reshape(b, t, new_h, new_w).flatten(1) 316 | else: 317 | v = tgt 318 | tgt2 = self.self_attn(q, k, value=v, attn_mask=None, 319 | key_padding_mask=tgt_key_padding_mask)[0] # [H*W, B*T, C] 320 | if self.sr_ratio > 1: 321 | tgt2 = rearrange(tgt2, '(t h w) b c -> (b t) c h w', t=t, h=new_h, w=new_w) 322 | size = (h, w) # recover to origin size 323 | tgt2 = F.interpolate(tgt2, size=size, mode='bilinear', align_corners=False) # [B*T, C, H, W] 324 | tgt2 = rearrange(tgt2, '(b t) c h w -> (t h w) b c', t=t) 325 | tgt = tgt + self.dropout1(tgt2) 326 | tgt = self.norm1(tgt) 327 | 328 | # cross attn 329 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 330 | key=self.with_pos_embed(memory, pos), 
331 | value=memory, attn_mask=None, 332 | key_padding_mask=memory_key_padding_mask)[0] 333 | tgt = tgt + self.dropout2(tgt2) 334 | tgt = self.norm2(tgt) 335 | 336 | # ffn 337 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 338 | tgt = tgt + self.dropout3(tgt2) 339 | tgt = self.norm3(tgt) 340 | return tgt 341 | 342 | def forward_pre(self, tgt, memory, t, h, w, 343 | tgt_key_padding_mask: Optional[Tensor] = None, 344 | memory_key_padding_mask: Optional[Tensor] = None, 345 | pos: Optional[Tensor] = None, 346 | query_pos: Optional[Tensor] = None): 347 | b = tgt.size(1) 348 | # self attn 349 | tgt2 = self.norm1(tgt) 350 | q = k = self.with_pos_embed(tgt2, query_pos) 351 | if self.sr_ratio > 1: # downsample 352 | q = rearrange(q, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w) 353 | k = rearrange(k, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w) 354 | v = rearrange(tgt, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w) 355 | # downsample 356 | new_h = int(h * 1./self.sr_ratio) 357 | new_w = int(w * 1./self.sr_ratio) 358 | size = (new_h, new_w) 359 | q = F.interpolate(q, size=size, mode='nearest') 360 | k = F.interpolate(k, size=size, mode='nearest') 361 | v = F.interpolate(v, size=size, mode='nearest') 362 | # shape for transformer 363 | q = rearrange(q, '(b t) c h w -> (t h w) b c', t=t) 364 | k = rearrange(k, '(b t) c h w -> (t h w) b c', t=t) 365 | v = rearrange(v, '(b t) c h w -> (t h w) b c', t=t) 366 | # downsample mask 367 | tgt_key_padding_mask = tgt_key_padding_mask.reshape(b*t, h, w) 368 | tgt_key_padding_mask = F.interpolate(tgt_key_padding_mask[None].float(), size=(new_h, new_w), mode='nearest').bool()[0] 369 | tgt_key_padding_mask = tgt_key_padding_mask.reshape(b, t, new_h, new_w).flatten(1) 370 | else: 371 | v = tgt2 372 | tgt2 = self.self_attn(q, k, value=v, attn_mask=None, 373 | key_padding_mask=tgt_key_padding_mask)[0] # [T*H*W, B, C] 374 | if self.sr_ratio > 1: 375 | tgt2 = rearrange(tgt2, '(t h w) b c -> (b t) c h w', t=t, h=new_h, 
w=new_w) 376 | size = (h, w) # recover to origin size 377 | tgt2 = F.interpolate(tgt2, size=size, mode='bilinear', align_corners=False) # [B*T, C, H, W] 378 | tgt2 = rearrange(tgt2, '(b t) c h w -> (t h w) b c', t=t) 379 | tgt = tgt + self.dropout1(tgt2) 380 | 381 | # cross attn 382 | tgt2 = self.norm2(tgt) 383 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 384 | key=self.with_pos_embed(memory, pos), 385 | value=memory, attn_mask=None, 386 | key_padding_mask=memory_key_padding_mask)[0] 387 | tgt = tgt + self.dropout2(tgt2) 388 | 389 | # ffn 390 | tgt2 = self.norm3(tgt) 391 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 392 | tgt = tgt + self.dropout3(tgt2) 393 | return tgt 394 | 395 | def forward(self, tgt, memory, t, h, w, 396 | tgt_key_padding_mask: Optional[Tensor] = None, 397 | memory_key_padding_mask: Optional[Tensor] = None, 398 | pos: Optional[Tensor] = None, 399 | query_pos: Optional[Tensor] = None): 400 | if self.normalize_before: 401 | return self.forward_pre(tgt, memory, t, h, w, 402 | tgt_key_padding_mask, memory_key_padding_mask, 403 | pos, query_pos) 404 | return self.forward_post(tgt, memory, t, h, w, 405 | tgt_key_padding_mask, memory_key_padding_mask, 406 | pos, query_pos) 407 | 408 | 409 | 410 | class VisionLanguageFusionModule(nn.Module): 411 | def __init__(self, d_model, nhead, dropout=0.0): 412 | super().__init__() 413 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 414 | 415 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 416 | return tensor if pos is None else tensor + pos 417 | 418 | def forward(self, tgt, memory, 419 | memory_key_padding_mask: Optional[Tensor] = None, 420 | pos: Optional[Tensor] = None, 421 | query_pos: Optional[Tensor] = None): 422 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 423 | key=self.with_pos_embed(memory, pos), 424 | value=memory, attn_mask=None, 425 | key_padding_mask=memory_key_padding_mask)[0] 
426 | tgt = tgt * tgt2 427 | return tgt 428 | 429 | 430 | def dice_loss(inputs, targets, num_boxes): 431 | """ 432 | Compute the DICE loss, similar to generalized IOU for masks 433 | Args: 434 | inputs: A float tensor of arbitrary shape. 435 | The predictions for each example. 436 | targets: A float tensor with the same shape as inputs. Stores the binary 437 | classification label for each element in inputs 438 | (0 for the negative class and 1 for the positive class). 439 | """ 440 | inputs = inputs.sigmoid() 441 | inputs = inputs.flatten(1) 442 | numerator = 2 * (inputs * targets).sum(1) 443 | denominator = inputs.sum(-1) + targets.sum(-1) 444 | loss = 1 - (numerator + 1) / (denominator + 1) 445 | return loss.sum() / num_boxes 446 | 447 | 448 | def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): 449 | """ 450 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 451 | Args: 452 | inputs: A float tensor of arbitrary shape. 453 | The predictions for each example. 454 | targets: A float tensor with the same shape as inputs. Stores the binary 455 | classification label for each element in inputs 456 | (0 for the negative class and 1 for the positive class). 457 | alpha: (optional) Weighting factor in range (0,1) to balance 458 | positive vs negative examples. Default = 0.25; a negative value disables weighting. 459 | gamma: Exponent of the modulating factor (1 - p_t) to 460 | balance easy vs hard examples. 
461 | Returns: 462 | Loss tensor 463 | """ 464 | prob = inputs.sigmoid() 465 | #print('w/o weight', inputs.shape) 466 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 467 | p_t = prob * targets + (1 - prob) * (1 - targets) 468 | loss = ce_loss * ((1 - p_t) ** gamma) 469 | 470 | if alpha >= 0: 471 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 472 | loss = alpha_t * loss 473 | 474 | return loss.mean(1).sum() / num_boxes 475 | 476 | def sigmoid_focal_loss_weighted(inputs, targets, num_boxes, weights, alpha: float = 0.25, gamma: float = 2): 477 | """ 478 | Weighted variant of the RetinaNet focal loss (https://arxiv.org/abs/1708.02002); `weights` rescales the per-element cross-entropy term. 479 | Args: 480 | inputs: A float tensor of arbitrary shape. 481 | The predictions for each example. 482 | targets: A float tensor with the same shape as inputs. Stores the binary 483 | classification label for each element in inputs 484 | (0 for the negative class and 1 for the positive class). 485 | alpha: (optional) Weighting factor in range (0,1) to balance 486 | positive vs negative examples. Default = 0.25; a negative value disables weighting. 487 | gamma: Exponent of the modulating factor (1 - p_t) to 488 | balance easy vs hard examples. 
489 | Returns: 490 | Loss tensor 491 | """ 492 | prob = inputs.sigmoid() 493 | #print('w/ weight', weights.shape, torch.unique(weights)) 494 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none", weight=weights) 495 | p_t = prob * targets + (1 - prob) * (1 - targets) 496 | loss = ce_loss * ((1 - p_t) ** gamma) 497 | 498 | if alpha >= 0: 499 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 500 | loss = alpha_t * loss 501 | 502 | return loss.mean(1).sum() / num_boxes 503 | 504 | def _get_activation_fn(activation): 505 | """Return an activation function given a string""" 506 | if activation == "relu": 507 | return F.relu 508 | if activation == "gelu": 509 | return F.gelu 510 | if activation == "glu": 511 | return F.glu 512 | raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") 513 | 514 | 515 | -------------------------------------------------------------------------------- /RF_ActionVOS/test_actionvos.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | free_gpu=$3 3 | export CUDA_VISIBLE_DEVICES=$free_gpu 4 | set -x 5 | 6 | GPUS=${GPUS:-1} 7 | PORT=${PORT:-$4} 8 | if [ $GPUS -lt 1 ]; then 9 | GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS} 10 | else 11 | GPUS_PER_NODE=${GPUS_PER_NODE:-1} 12 | fi 13 | CPUS_PER_TASK=${CPUS_PER_TASK:-1} 14 | 15 | OUTPUT_DIR=$1 16 | CHECKPOINT=$2 17 | PY_ARGS=${@:5} # Any arguments from the fifth one onward are captured by this 18 | 19 | echo "Load model weights from: ${CHECKPOINT}" 20 | 21 | python3 inference_actionvos.py --with_box_refine --binary --freeze_text_encoder \ 22 | --output_dir=${OUTPUT_DIR} --resume=${CHECKPOINT} ${PY_ARGS} 23 | 24 | echo "Working path is: ${OUTPUT_DIR}" 25 | -------------------------------------------------------------------------------- /RF_ActionVOS/train_actionvos.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | num_gpu=$3 3 | free_gpu=$4 4 | 
export CUDA_VISIBLE_DEVICES=$free_gpu 5 | set -x 6 | 7 | GPUS=${GPUS:-$num_gpu} 8 | PORT=${PORT:-$5} 9 | GPUS_PER_NODE=${GPUS_PER_NODE:-$num_gpu} 10 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 11 | 12 | OUTPUT_DIR=$1 13 | PRETRAINED_WEIGHTS=$2 14 | PY_ARGS=${@:6} # Any arguments from the sixth one onward are captured by this 15 | 16 | echo "Load pretrained weights from: ${PRETRAINED_WEIGHTS}" 17 | 18 | # train 19 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 20 | python3 -m torch.distributed.launch --nproc_per_node=${GPUS_PER_NODE} --master_port=${PORT} --use_env \ 21 | main_actionvos.py --with_box_refine --binary --freeze_text_encoder \ 22 | --output_dir=${OUTPUT_DIR} --pretrained_weights=${PRETRAINED_WEIGHTS} ${PY_ARGS} 23 | 24 | echo "Working path is: ${OUTPUT_DIR}" 25 | 26 | -------------------------------------------------------------------------------- /RF_ActionVOS/transforms_video_actionvos.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transforms and data augmentation for sequence-level images, bboxes and masks. 
3 | """ 4 | import random 5 | 6 | import PIL 7 | import torch 8 | import torchvision.transforms as T 9 | import torchvision.transforms.functional as F 10 | 11 | from util.box_ops import box_xyxy_to_cxcywh, box_iou 12 | from util.misc import interpolate 13 | import numpy as np 14 | from numpy import random as rand 15 | from PIL import Image 16 | import cv2 17 | 18 | 19 | 20 | class Check(object): 21 | def __init__(self,): 22 | pass 23 | def __call__(self, img, target): 24 | fields = ["labels"] 25 | if "boxes" in target: 26 | fields.append("boxes") 27 | if "masks" in target: 28 | fields.append("masks") 29 | 30 | ### check if box or mask still exist after transforms 31 | if "boxes" in target or "masks" in target: 32 | if "boxes" in target: 33 | cropped_boxes = target['boxes'].reshape(-1, 2, 2) 34 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) 35 | else: 36 | keep = target['masks'].flatten(1).any(1) 37 | 38 | if False in keep: 39 | for k in range(len(keep)): 40 | if not keep[k] and "boxes" in target: 41 | target['boxes'][k] = target['boxes'][k]//1000.0 # [0, 0, 0, 0] 42 | target['valid'] = keep.to(torch.int32) 43 | 44 | return img, target 45 | 46 | 47 | 48 | def bbox_overlaps(bboxes1, bboxes2, mode='iou', eps=1e-6): 49 | assert mode in ['iou', 'iof'] 50 | bboxes1 = bboxes1.astype(np.float32) 51 | bboxes2 = bboxes2.astype(np.float32) 52 | rows = bboxes1.shape[0] 53 | cols = bboxes2.shape[0] 54 | ious = np.zeros((rows, cols), dtype=np.float32) 55 | if rows * cols == 0: 56 | return ious 57 | exchange = False 58 | if bboxes1.shape[0] > bboxes2.shape[0]: 59 | bboxes1, bboxes2 = bboxes2, bboxes1 60 | ious = np.zeros((cols, rows), dtype=np.float32) 61 | exchange = True 62 | area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1]) 63 | area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1]) 64 | for i in range(bboxes1.shape[0]): 65 | x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0]) 66 | y_start = 
np.maximum(bboxes1[i, 1], bboxes2[:, 1]) 67 | x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2]) 68 | y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3]) 69 | overlap = np.maximum(x_end - x_start, 0) * np.maximum(y_end - y_start, 0) 70 | if mode == 'iou': 71 | union = area1[i] + area2 - overlap 72 | else: 73 | union = area1[i] if not exchange else area2 74 | union = np.maximum(union, eps) 75 | ious[i, :] = overlap / union 76 | if exchange: 77 | ious = ious.T 78 | return ious 79 | 80 | 81 | def crop(clip, target, region): 82 | cropped_image = [] 83 | for image in clip: 84 | cropped_image.append(F.crop(image, *region)) 85 | 86 | target = target.copy() 87 | i, j, h, w = region 88 | 89 | # should we do something wrt the original size? 90 | target["size"] = torch.tensor([h, w]) 91 | 92 | fields = ["labels", "area", "iscrowd"] 93 | 94 | if "boxes" in target: 95 | boxes = target["boxes"] 96 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 97 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 98 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 99 | cropped_boxes = cropped_boxes.clamp(min=0) 100 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 101 | target["boxes"] = cropped_boxes.reshape(-1, 4) 102 | target["area"] = area 103 | fields.append("boxes") 104 | 105 | if "masks" in target: 106 | # FIXME should we update the area here if there are no boxes? 107 | target['masks'] = target['masks'][:, i:i + h, j:j + w] 108 | fields.append("masks") 109 | 110 | if "weights" in target: 111 | # FIXME should we update the area here if there are no boxes? 
112 | target['weights'] = target['weights'][:, i:i + h, j:j + w] 113 | fields.append("weights") 114 | 115 | return cropped_image, target 116 | 117 | 118 | def hflip(clip, target): 119 | flipped_image = [] 120 | for image in clip: 121 | flipped_image.append(F.hflip(image)) 122 | 123 | w, h = clip[0].size 124 | 125 | target = target.copy() 126 | if "boxes" in target: 127 | boxes = target["boxes"] 128 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) 129 | target["boxes"] = boxes 130 | 131 | if "masks" in target: 132 | target['masks'] = target['masks'].flip(-1) 133 | 134 | if "weights" in target: 135 | target['weights'] = target['weights'].flip(-1) 136 | 137 | return flipped_image, target 138 | 139 | def vflip(clip, target): 140 | flipped_image = [] 141 | for image in clip: 142 | flipped_image.append(F.vflip(image)) 143 | w, h = clip[0].size 144 | target = target.copy() 145 | if "boxes" in target: 146 | boxes = target["boxes"] 147 | boxes = boxes[:, [0, 3, 2, 1]] * torch.as_tensor([1, -1, 1, -1]) + torch.as_tensor([0, h, 0, h]) 148 | target["boxes"] = boxes 149 | 150 | if "masks" in target: 151 | target['masks'] = target['masks'].flip(1) 152 | 153 | if "weights" in target: 154 | target['weights'] = target['weights'].flip(1) 155 | 156 | return flipped_image, target 157 | 158 | def resize(clip, target, size, max_size=None): 159 | # size can be min_size (scalar) or (w, h) tuple 160 | 161 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 162 | w, h = image_size 163 | if max_size is not None: 164 | min_original_size = float(min((w, h))) 165 | max_original_size = float(max((w, h))) 166 | if max_original_size / min_original_size * size > max_size: 167 | size = int(round(max_size * min_original_size / max_original_size)) 168 | 169 | if (w <= h and w == size) or (h <= w and h == size): 170 | return (h, w) 171 | 172 | if w < h: 173 | ow = size 174 | oh = int(size * h / w) 175 | else: 176 | oh = size 177 | ow = 
int(size * w / h) 178 | 179 | return (oh, ow) 180 | 181 | def get_size(image_size, size, max_size=None): 182 | if isinstance(size, (list, tuple)): 183 | return size[::-1] 184 | else: 185 | return get_size_with_aspect_ratio(image_size, size, max_size) 186 | 187 | size = get_size(clip[0].size, size, max_size) 188 | rescaled_image = [] 189 | for image in clip: 190 | rescaled_image.append(F.resize(image, size)) 191 | 192 | if target is None: 193 | return rescaled_image, None 194 | 195 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image[0].size, clip[0].size)) 196 | ratio_width, ratio_height = ratios 197 | 198 | target = target.copy() 199 | if "boxes" in target: 200 | boxes = target["boxes"] 201 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 202 | target["boxes"] = scaled_boxes 203 | 204 | if "area" in target: 205 | area = target["area"] 206 | scaled_area = area * (ratio_width * ratio_height) 207 | target["area"] = scaled_area 208 | 209 | h, w = size 210 | target["size"] = torch.tensor([h, w]) 211 | 212 | if "masks" in target: 213 | if target['masks'].shape[0]>0: 214 | target['masks'] = interpolate( 215 | target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 216 | else: 217 | target['masks'] = torch.zeros((target['masks'].shape[0],h,w)) 218 | 219 | if "weights" in target: 220 | if target['weights'].shape[0]>0: 221 | target['weights'] = interpolate( 222 | target['weights'][:, None].float(), size, mode="nearest")[:, 0] 223 | else: 224 | target['weights'] = torch.ones((target['weights'].shape[0],h,w)) 225 | return rescaled_image, target 226 | 227 | 228 | def pad(clip, target, padding): 229 | # assumes that we only pad on the bottom right corners 230 | padded_image = [] 231 | for image in clip: 232 | padded_image.append(F.pad(image, (0, 0, padding[0], padding[1]))) 233 | if target is None: 234 | return padded_image, None 235 | target = target.copy() 236 | # should we do something wrt 
the original size? 237 | target["size"] = torch.tensor(padded_image[0].size[::-1]) 238 | if "masks" in target: 239 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) 240 | if "weights" in target: 241 | target['weights'] = torch.nn.functional.pad(target['weights'], (0, padding[0], 0, padding[1]), value=1) 242 | return padded_image, target 243 | 244 | 245 | class RandomCrop(object): 246 | def __init__(self, size): 247 | self.size = size 248 | 249 | def __call__(self, img, target): 250 | region = T.RandomCrop.get_params(img, self.size) 251 | return crop(img, target, region) 252 | 253 | 254 | class RandomSizeCrop(object): 255 | def __init__(self, min_size: int, max_size: int): 256 | self.min_size = min_size 257 | self.max_size = max_size 258 | 259 | def __call__(self, img: PIL.Image.Image, target: dict): 260 | w = random.randint(self.min_size, min(img[0].width, self.max_size)) 261 | h = random.randint(self.min_size, min(img[0].height, self.max_size)) 262 | region = T.RandomCrop.get_params(img[0], [h, w]) 263 | return crop(img, target, region) 264 | 265 | 266 | class CenterCrop(object): 267 | def __init__(self, size): 268 | self.size = size 269 | 270 | def __call__(self, img, target): 271 | image_width, image_height = img.size 272 | crop_height, crop_width = self.size 273 | crop_top = int(round((image_height - crop_height) / 2.)) 274 | crop_left = int(round((image_width - crop_width) / 2.)) 275 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 276 | 277 | 278 | class MinIoURandomCrop(object): 279 | def __init__(self, min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3): 280 | self.min_ious = min_ious 281 | self.sample_mode = (1, *min_ious, 0) 282 | self.min_crop_size = min_crop_size 283 | 284 | def __call__(self, img, target): 285 | w,h = img.size 286 | while True: 287 | mode = random.choice(self.sample_mode) 288 | self.mode = mode 289 | if mode == 1: 290 | return img,target 291 | min_iou = mode 292 | 
boxes = target['boxes'].numpy() 293 | labels = target['labels'] 294 | 295 | for i in range(50): 296 | new_w = rand.uniform(self.min_crop_size * w, w) 297 | new_h = rand.uniform(self.min_crop_size * h, h) 298 | if new_h / new_w < 0.5 or new_h / new_w > 2: 299 | continue 300 | left = rand.uniform(w - new_w) 301 | top = rand.uniform(h - new_h) 302 | patch = np.array((int(left), int(top), int(left + new_w), int(top + new_h))) 303 | if patch[2] == patch[0] or patch[3] == patch[1]: 304 | continue 305 | overlaps = bbox_overlaps(patch.reshape(-1, 4), boxes.reshape(-1, 4)).reshape(-1) 306 | if len(overlaps) > 0 and overlaps.min() < min_iou: 307 | continue 308 | 309 | if len(overlaps) > 0: 310 | def is_center_of_bboxes_in_patch(boxes, patch): 311 | center = (boxes[:, :2] + boxes[:, 2:]) / 2 312 | mask = ((center[:, 0] > patch[0]) * (center[:, 1] > patch[1]) * (center[:, 0] < patch[2]) * (center[:, 1] < patch[3])) 313 | return mask 314 | mask = is_center_of_bboxes_in_patch(boxes, patch) 315 | if False in mask: 316 | continue 317 | #TODO: use no center boxes 318 | #if not mask.any(): 319 | # continue 320 | 321 | boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) 322 | boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) 323 | boxes -= np.tile(patch[:2], 2) 324 | target['boxes'] = torch.tensor(boxes) 325 | 326 | img = np.asarray(img)[patch[1]:patch[3], patch[0]:patch[2]] 327 | img = Image.fromarray(img) 328 | width, height = img.size 329 | target['orig_size'] = torch.tensor([height,width]) 330 | target['size'] = torch.tensor([height,width]) 331 | return img,target 332 | 333 | 334 | class RandomContrast(object): 335 | def __init__(self, lower=0.5, upper=1.5): 336 | self.lower = lower 337 | self.upper = upper 338 | assert self.upper >= self.lower, "contrast upper must be >= lower." 339 | assert self.lower >= 0, "contrast lower must be non-negative." 
340 | def __call__(self, image, target): 341 | 342 | if rand.randint(2): 343 | alpha = rand.uniform(self.lower, self.upper) 344 | image *= alpha 345 | return image, target 346 | 347 | class RandomBrightness(object): 348 | def __init__(self, delta=32): 349 | assert delta >= 0.0 350 | assert delta <= 255.0 351 | self.delta = delta 352 | def __call__(self, image, target): 353 | if rand.randint(2): 354 | delta = rand.uniform(-self.delta, self.delta) 355 | image += delta 356 | return image, target 357 | 358 | class RandomSaturation(object): 359 | def __init__(self, lower=0.5, upper=1.5): 360 | self.lower = lower 361 | self.upper = upper 362 | assert self.upper >= self.lower, "contrast upper must be >= lower." 363 | assert self.lower >= 0, "contrast lower must be non-negative." 364 | 365 | def __call__(self, image, target): 366 | if rand.randint(2): 367 | image[:, :, 1] *= rand.uniform(self.lower, self.upper) 368 | return image, target 369 | 370 | class RandomHue(object): # 371 | def __init__(self, delta=18.0): 372 | assert delta >= 0.0 and delta <= 360.0 373 | self.delta = delta 374 | 375 | def __call__(self, image, target): 376 | if rand.randint(2): 377 | image[:, :, 0] += rand.uniform(-self.delta, self.delta) 378 | image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0 379 | image[:, :, 0][image[:, :, 0] < 0.0] += 360.0 380 | return image, target 381 | 382 | class RandomLightingNoise(object): 383 | def __init__(self): 384 | self.perms = ((0, 1, 2), (0, 2, 1), 385 | (1, 0, 2), (1, 2, 0), 386 | (2, 0, 1), (2, 1, 0)) 387 | def __call__(self, image, target): 388 | if rand.randint(2): 389 | swap = self.perms[rand.randint(len(self.perms))] 390 | shuffle = SwapChannels(swap) # shuffle channels 391 | image = shuffle(image) 392 | return image, target 393 | 394 | class ConvertColor(object): 395 | def __init__(self, current='BGR', transform='HSV'): 396 | self.transform = transform 397 | self.current = current 398 | 399 | def __call__(self, image, target): 400 | if self.current == 'BGR' 
and self.transform == 'HSV': 401 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 402 | elif self.current == 'HSV' and self.transform == 'BGR': 403 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 404 | else: 405 | raise NotImplementedError 406 | return image, target 407 | 408 | class SwapChannels(object): 409 | def __init__(self, swaps): 410 | self.swaps = swaps 411 | def __call__(self, image): 412 | image = image[:, :, self.swaps] 413 | return image 414 | 415 | class PhotometricDistort(object): 416 | def __init__(self): 417 | self.pd = [ 418 | RandomContrast(), 419 | ConvertColor(transform='HSV'), 420 | RandomSaturation(), 421 | RandomHue(), 422 | ConvertColor(current='HSV', transform='BGR'), 423 | RandomContrast() 424 | ] 425 | self.rand_brightness = RandomBrightness() 426 | self.rand_light_noise = RandomLightingNoise() 427 | 428 | def __call__(self,clip,target): 429 | imgs = [] 430 | for img in clip: 431 | img = np.asarray(img).astype('float32') 432 | img, target = self.rand_brightness(img, target) 433 | if rand.randint(2): 434 | distort = Compose(self.pd[:-1]) 435 | else: 436 | distort = Compose(self.pd[1:]) 437 | img, target = distort(img, target) 438 | img, target = self.rand_light_noise(img, target) 439 | imgs.append(Image.fromarray(img.astype('uint8'))) 440 | return imgs, target 441 | ''' 442 | # NOTICE: if used for mask, need to change 443 | class Expand(object): 444 | def __init__(self, mean): 445 | self.mean = mean 446 | def __call__(self, clip, target): 447 | if rand.randint(2): 448 | return clip,target 449 | imgs = [] 450 | masks = [] 451 | image = np.asarray(clip[0]).astype('float32') 452 | height, width, depth = image.shape 453 | ratio = rand.uniform(1, 4) 454 | left = rand.uniform(0, width*ratio - width) 455 | top = rand.uniform(0, height*ratio - height) 456 | for i in range(len(clip)): 457 | image = np.asarray(clip[i]).astype('float32') 458 | expand_image = np.zeros((int(height*ratio), int(width*ratio), depth),dtype=image.dtype) 459 | 
expand_image[:, :, :] = self.mean 460 | expand_image[int(top):int(top + height),int(left):int(left + width)] = image 461 | imgs.append(Image.fromarray(expand_image.astype('uint8'))) 462 | expand_mask = torch.zeros((int(height*ratio), int(width*ratio)),dtype=torch.uint8) 463 | expand_mask[int(top):int(top + height),int(left):int(left + width)] = target['masks'][i] 464 | masks.append(expand_mask) 465 | boxes = target['boxes'].numpy() 466 | boxes[:, :2] += (int(left), int(top)) 467 | boxes[:, 2:] += (int(left), int(top)) 468 | target['boxes'] = torch.tensor(boxes) 469 | target['masks']=torch.stack(masks) 470 | return imgs, target 471 | ''' 472 | 473 | class RandomHorizontalFlip(object): 474 | def __init__(self, p=0.5): 475 | self.p = p 476 | 477 | def __call__(self, img, target): 478 | if random.random() < self.p: 479 | # NOTE: caption for 'left' and 'right' should also change 480 | caption = target['caption'] 481 | target['caption'] = caption.replace('left', '@').replace('right', 'left').replace('@', 'right') 482 | return hflip(img, target) 483 | return img, target 484 | 485 | class RandomVerticalFlip(object): 486 | def __init__(self, p=0.5): 487 | self.p = p 488 | 489 | def __call__(self, img, target): 490 | if random.random() < self.p: 491 | return vflip(img, target) 492 | return img, target 493 | 494 | 495 | class RandomResize(object): 496 | def __init__(self, sizes, max_size=None): 497 | assert isinstance(sizes, (list, tuple)) 498 | self.sizes = sizes 499 | self.max_size = max_size 500 | 501 | def __call__(self, img, target=None): 502 | size = random.choice(self.sizes) 503 | return resize(img, target, size, self.max_size) 504 | 505 | 506 | class RandomPad(object): 507 | def __init__(self, max_pad): 508 | self.max_pad = max_pad 509 | 510 | def __call__(self, img, target): 511 | pad_x = random.randint(0, self.max_pad) 512 | pad_y = random.randint(0, self.max_pad) 513 | return pad(img, target, (pad_x, pad_y)) 514 | 515 | 516 | class RandomSelect(object): 517 | """ 
518 | Randomly selects between transforms1 and transforms2, 519 | with probability p for transforms1 and (1 - p) for transforms2 520 | """ 521 | def __init__(self, transforms1, transforms2, p=0.5): 522 | self.transforms1 = transforms1 523 | self.transforms2 = transforms2 524 | self.p = p 525 | 526 | def __call__(self, img, target): 527 | if random.random() < self.p: 528 | return self.transforms1(img, target) 529 | return self.transforms2(img, target) 530 | 531 | 532 | class ToTensor(object): 533 | def __call__(self, clip, target): 534 | img = [] 535 | for im in clip: 536 | img.append(F.to_tensor(im)) 537 | return img, target 538 | 539 | 540 | class RandomErasing(object): 541 | 542 | def __init__(self, *args, **kwargs): 543 | self.eraser = T.RandomErasing(*args, **kwargs) 544 | 545 | def __call__(self, img, target): 546 | return self.eraser(img), target 547 | 548 | 549 | class Normalize(object): 550 | def __init__(self, mean, std): 551 | self.mean = mean 552 | self.std = std 553 | 554 | def __call__(self, clip, target=None): 555 | image = [] 556 | for im in clip: 557 | image.append(F.normalize(im, mean=self.mean, std=self.std)) 558 | if target is None: 559 | return image, None 560 | target = target.copy() 561 | h, w = image[0].shape[-2:] 562 | if "boxes" in target: 563 | boxes = target["boxes"] 564 | boxes = box_xyxy_to_cxcywh(boxes) 565 | boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) 566 | target["boxes"] = boxes 567 | return image, target 568 | 569 | 570 | class Compose(object): 571 | def __init__(self, transforms): 572 | self.transforms = transforms 573 | 574 | def __call__(self, image, target): 575 | for t in self.transforms: 576 | image, target = t(image, target) 577 | return image, target 578 | 579 | def __repr__(self): 580 | format_string = self.__class__.__name__ + "(" 581 | for t in self.transforms: 582 | format_string += "\n" 583 | format_string += " {0}".format(t) 584 | format_string += "\n)" 585 | return format_string 586 | 
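The left/right caption swap performed by `RandomHorizontalFlip` above can be isolated as a small sketch. `swap_left_right_naive` mirrors the chained `replace` calls from the transform. `swap_left_right_words` is a hypothetical whole-word variant (an assumption for illustration, not part of this repository), included only to show that the chained version also rewrites substrings such as the `right` inside `bright`.

```python
import re


def swap_left_right_naive(caption: str) -> str:
    # Same chained replace as RandomHorizontalFlip: '@' is a temporary
    # placeholder so the two substitutions do not overwrite each other.
    return caption.replace('left', '@').replace('right', 'left').replace('@', 'right')


# Hypothetical alternative (not repository code): swap whole words only.
_LEFT_RIGHT = re.compile(r'\b(left|right)\b')


def swap_left_right_words(caption: str) -> str:
    # Replace each standalone 'left' with 'right' and vice versa in one pass.
    return _LEFT_RIGHT.sub(
        lambda m: 'right' if m.group(1) == 'left' else 'left', caption)


print(swap_left_right_naive('left hand used in the action of pick up spoon'))
# right hand used in the action of pick up spoon
print(swap_left_right_naive('bright bowl'))   # bleft bowl
print(swap_left_right_words('bright bowl'))   # bright bowl
```

Whether the whole-word variant is preferable depends on the noun vocabulary of the dataset; if no object name contains `left` or `right` as a substring, the two functions behave identically.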
-------------------------------------------------------------------------------- /actionvos_metrics.py: -------------------------------------------------------------------------------- 1 | # Evaluation metrics for ActionVOS, modified from GRES (https://github.com/henghuiding/ReLA). 2 | # Computes IoUs and accuracy for positive/negative objects between two folders. 3 | # Every object in each frame is treated as a separate sample. 4 | 5 | # NOTE: our annotators found some errors in the original 330-video val_human split, 6 | # which was used for the paper's results. 7 | # These videos are marked as 8 | # 'missing == True' or 'redundant == True' or 'other == True' (see val_human.json). 9 | # In this file, we filter out these videos 10 | # and use 294 videos as val_human for evaluation. 11 | # This changes the metrics by less than 1% 12 | # and does not affect any conclusions in our paper. 13 | 14 | # NOTE: due to different random seeds in training, 15 | # a ~1% difference from the tables in our paper is normal. 16 | # If your reproduced results are far from ours, 17 | # please contact me by email: oyly(at)iis.u-tokyo.ac.jp. 
18 | 19 | import numpy as np 20 | from tqdm import tqdm 21 | import os 22 | from PIL import Image 23 | import json 24 | import argparse 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--pred_path', type=str, required=True) 28 | parser.add_argument('--gt_path', type=str, required=True) 29 | parser.add_argument('--split_json', type=str, required=True) 30 | 31 | def cal_gres_all(gt_path,pred_path,split_json,filter=True): 32 | meta = json.load(open(split_json)) 33 | mious_pos = [] 34 | mious_neg = [] 35 | cis_pos = [] 36 | cus_pos = [] 37 | cis_neg = [] 38 | cus_neg = [] 39 | gious = [] 40 | TN,TP,FN,FP = 0,0,0,0 41 | N_Video = 0 42 | for seq in tqdm(meta): 43 | if filter: # skip some errored seq 44 | if seq['redundant'] or seq['other'] or seq['missing']: 45 | continue 46 | N_Video += 1 47 | folder_name = '{:08d}_{}_{}_{}'.format(seq['seq_id'],seq['video'],seq['verb'],seq['noun']) 48 | for frame in seq['sparse_frames']: 49 | mask_gt_path = os.path.join(gt_path,folder_name,frame.replace('jpg','png')) 50 | mask_pred_path = os.path.join(pred_path,folder_name,frame.replace('jpg','png')) 51 | gt = np.array(Image.open(mask_gt_path)) 52 | pred = np.array(Image.open(mask_pred_path)) 53 | for k in seq['object_classes'].keys(): 54 | # for positive obj 55 | if seq['object_classes'][k]['positive']: 56 | color = int(k) 57 | p = np.where(pred==color,1,0).astype(float) 58 | g = np.where(gt==color,1,0).astype(float) 59 | intersection = (p*g).sum() 60 | union = (p+g).sum()-intersection 61 | if union <= 0: 62 | # no-target and no detected 63 | # NOTE: for union = 0 samples, we do not count for mIoUs 64 | # but count them for gIoU and accs. 
65 | gious.append(1) 66 | TN += 1 67 | else: 68 | mious_pos.append(intersection/union) 69 | cis_pos.append(intersection) 70 | cus_pos.append(union) 71 | #some cases, pos obj may not visible in this frame 72 | if g.sum() <= 0: 73 | # no-target but detected 74 | FP += 1 75 | gious.append(0) 76 | else: 77 | gious.append(intersection/union) 78 | if p.sum() <= 0: 79 | # has target but not detected 80 | FN += 1 81 | else: 82 | # has target and detected 83 | TP += 1 84 | else: 85 | color = int(k) 86 | p = np.where(pred==color,1,0).astype(float) 87 | g = np.where(gt==color,1,0).astype(float) 88 | intersection = (p*g).sum() 89 | union = (p+g).sum()-intersection 90 | if union <= 0: 91 | # no-target and no detected 92 | gious.append(1) 93 | TN += 1 94 | else: 95 | mious_neg.append(intersection/union) 96 | cis_neg.append(intersection) 97 | cus_neg.append(union) 98 | #some cases, neg obj may not visible in this frame 99 | if g.sum() <= 0: 100 | # no-target but detected 101 | FP += 1 102 | gious.append(0) 103 | else: 104 | #gious.append(intersection/union) 105 | if p.sum() <= 0: 106 | # neg target and no detected 107 | gious.append(1) 108 | TN += 1 109 | else: 110 | # neg target and detected 111 | FP += 1 112 | gious.append(0) 113 | print(f'----- evaluation on {N_Video} videos -----') 114 | print('pos-mIoU: ', sum(mious_pos)/len(mious_pos)) 115 | print('neg-mIoU: ', sum(mious_neg)/len(mious_neg)) 116 | print('pos-cIoU: ', sum(cis_pos)/sum(cus_pos)) 117 | print('neg-cIoU: ', sum(cis_neg)/sum(cus_neg)) 118 | print('gIoU: ', sum(gious)/len(gious)) 119 | print('acc: ',(TN+TP)/(TN+TP+FN+FP)) 120 | print('TN,TP,FN,FP',TN,TP,FN,FP) 121 | print(f'------------------------------------') 122 | 123 | def main(): 124 | args = parser.parse_args() 125 | cal_gres_all(args.gt_path,args.pred_path,args.split_json,True) 126 | 127 | if __name__ == '__main__': 128 | main() -------------------------------------------------------------------------------- /annotations/00000.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ut-vision/ActionVOS/1a12aae100d0a84adcba72f859d81398e10e839a/annotations/00000.png -------------------------------------------------------------------------------- /copy_rf_actionvos_files.py: -------------------------------------------------------------------------------- 1 | # move modified files to referformer 2 | file_list = [ 3 | "datasets/__init__.py", 4 | "datasets/actionvos.py", 5 | "datasets/transforms_video_actionvos.py", 6 | "opts.py", 7 | "main_actionvos.py", 8 | "inference_actionvos.py", 9 | "scripts/train_actionvos.sh", 10 | "scripts/test_actionvos.sh", 11 | "models/referformer.py", 12 | "models/segmentation.py", 13 | "models/criterion.py", 14 | ] 15 | 16 | import shutil 17 | import os 18 | import json 19 | from tqdm import tqdm 20 | 21 | def move_files(src_path='RF_ActionVOS',dst_path='ReferFormer'): 22 | for file in file_list: 23 | filename = file.split('/')[-1] 24 | src_file = os.path.join(src_path,filename) 25 | dst_file = os.path.join(dst_path,file) 26 | if os.path.exists(dst_file): 27 | os.remove(dst_file) 28 | shutil.copy(src_file, dst_file) 29 | 30 | move_files() 31 | 32 | # generate objects_category.json for RF 33 | # generate meta_expressions.json for RF 34 | # if you want to change input expressions, 35 | # i.e., ablation study of language prompts, change meta_expressions.json files 36 | 37 | def generate_meta_json(actionvos_path='dataset_visor',language_prompt='promptaction',split='train'): 38 | assert language_prompt in ['noaction','action','promptaction'] 39 | # for meta_expressions.json 40 | datas = json.load(open(os.path.join(actionvos_path,'ImageSets',f'{split}.json'))) 41 | meta_exp = {} 42 | meta_exp['videos'] = {} 43 | obj_cls = {} 44 | obj_cls['videos'] = {} 45 | for seq in tqdm(datas): 46 | exp_dict = {} 47 | exp_dict['expressions'] = {} 48 | obj_dict = {} 49 | obj_dict['objects'] = {} 50 | for i,k in 
enumerate(seq['object_classes'].keys()): 51 | exp_dict['expressions'][str(i)] = {} 52 | if language_prompt == 'noaction': 53 | exp_dict['expressions'][str(i)]['exp'] = seq['object_classes'][k]['name'] 54 | elif language_prompt == 'action': 55 | exp_dict['expressions'][str(i)]['exp'] = seq['object_classes'][k]['name']+', '+seq['narration'] 56 | elif language_prompt == 'promptaction': 57 | exp_dict['expressions'][str(i)]['exp'] = seq['object_classes'][k]['name']+' used in the action of '+seq['narration'] 58 | else: 59 | pass 60 | exp_dict['expressions'][str(i)]['obj_id'] = k 61 | exp_dict['expressions'][str(i)]['positive'] = seq['object_classes'][k]['positive'] 62 | exp_dict['expressions'][str(i)]['class_id'] = seq['object_classes'][k]['class_id'] 63 | obj_dict['objects'][k]={'category':seq['object_classes'][k]['name'],'positive':seq['object_classes'][k]['positive'],'class_id':seq['object_classes'][k]['class_id']} 64 | exp_dict['frames'] = [] 65 | for f in seq['sparse_frames']: 66 | exp_dict['frames'].append(f.replace('.jpg','')) 67 | folder_name = '{:08d}_{}_{}_{}'.format(seq['seq_id'],seq['video'],seq['verb'],seq['noun']) 68 | meta_exp['videos'][folder_name] = exp_dict 69 | obj_cls['videos'][folder_name] = obj_dict 70 | 71 | meta_exp_json_object = json.dumps(meta_exp) 72 | obj_cls_json_dist = os.path.join(actionvos_path,'ImageSets', f'{split}_meta_expressions_{language_prompt}.json') 73 | with open(obj_cls_json_dist, "w") as f: 74 | f.write(meta_exp_json_object) 75 | 76 | obj_cls_json_object = json.dumps(obj_cls) 77 | obj_cls_json_dist = os.path.join(actionvos_path,'ImageSets', f'{split}_objects_category.json') 78 | with open(obj_cls_json_dist, "w") as f: 79 | f.write(obj_cls_json_object) 80 | 81 | generate_meta_json(split='train') 82 | generate_meta_json(split='val') -------------------------------------------------------------------------------- /data_prepare_visor.py: -------------------------------------------------------------------------------- 1 | # this 
script prepares ActionVOS data from VISOR 2 | # Action-aware labeling and action-guided weights are also generated by this script 3 | # part of codes from VISOR-VOS repo 4 | import numpy as np 5 | import pandas as pd 6 | import json 7 | import os 8 | import functools 9 | import cv2 10 | from tqdm import tqdm 11 | from PIL import Image 12 | import argparse 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--VISOR_PATH', type=str, required=True) 16 | args = parser.parse_args() 17 | VISOR_PATH = args.VISOR_PATH 18 | 19 | SAVE_PATH = 'dataset_visor' 20 | 21 | # weight_only=True # only save weight 22 | frame_mapping_json = os.path.join(VISOR_PATH,'frame_mapping.json') 23 | 24 | EK_100_action_csv = {'train':'annotations/EPIC_100_train.csv', 25 | 'val':'annotations/EPIC_100_validation.csv',} 26 | visor_hos_json = {'train':'annotations/visor_hos_train.json', 27 | 'val':'annotations/visor_hos_val.json',} 28 | 29 | os.makedirs(SAVE_PATH,exist_ok=True) 30 | os.makedirs(os.path.join(SAVE_PATH,'JPEGImages_Sparse'),exist_ok=True) 31 | os.makedirs(os.path.join(SAVE_PATH,'JPEGImages_Sparse','train'),exist_ok=True) 32 | os.makedirs(os.path.join(SAVE_PATH,'JPEGImages_Sparse','val'),exist_ok=True) 33 | os.makedirs(os.path.join(SAVE_PATH,'ImageSets'),exist_ok=True) 34 | os.makedirs(os.path.join(SAVE_PATH,'Annotations_Sparse'),exist_ok=True) 35 | os.makedirs(os.path.join(SAVE_PATH,'Annotations_Sparse','train'),exist_ok=True) 36 | os.makedirs(os.path.join(SAVE_PATH,'Annotations_Sparse','val'),exist_ok=True) 37 | os.makedirs(os.path.join(SAVE_PATH,'Weights_Sparse'),exist_ok=True) 38 | os.makedirs(os.path.join(SAVE_PATH,'Weights_Sparse','train'),exist_ok=True) 39 | os.makedirs(os.path.join(SAVE_PATH,'Weights_Sparse','val'),exist_ok=True) 40 | 41 | palette = Image.open('annotations/00000.png').getpalette() 42 | 43 | def generate_action_with_mask(split='train'): 44 | ''' 45 | find VISOR mask annotations for each action in EK-100 46 | return list of dicted RGB frames and 
annotation masks and weights 47 | [ 48 | { 49 | 'seq_id': int to identify sequence id, 50 | 'video': str to identify video source, 51 | 'verb''verb_class''noun''noun_class': verb noun annotation from EK-100, 52 | 'narration': narration from EK-100, 53 | 'start': start frame index from EK-100, 54 | 'end': end frame index from EK-100, 55 | 'object_classes': {1:{'name':xxx,'class_id':xxx,'positive':0/1,'handbox':0/1,'narration':0/1}, } for mask object labels in annotations, 56 | 'sparse_frames': ['xxx.jpg','yyy.jpg'] for the SparseAnnotated frames in VISOR, but in EK-100 index 57 | }, ... 58 | ] 59 | for weights: 60 | *3 for negative obj mask. 61 | *2 for hand | narration obj mask. 62 | *4 for hand & narration obj mask. 63 | *1 for other areas 64 | ''' 65 | assert split in ['train','val'] 66 | # read visor, EK-100, visor-hos 67 | frame_mapping_visor2ek = json.load(open(frame_mapping_json)) 68 | EK_100_actions = pd.concat([pd.read_csv(EK_100_action_csv['train']),pd.read_csv(EK_100_action_csv['val'])]) 69 | hos_annotations = map_hos(json.load(open(visor_hos_json[split]))) 70 | 71 | list_of_dict = [] 72 | seq_id = 0 73 | all_sparse_frames = 0 74 | for video in tqdm(frame_mapping_visor2ek.keys()): 75 | EK_100_actions_video = EK_100_actions[EK_100_actions['video_id']==video] 76 | if len(EK_100_actions_video) == 0: 77 | # not in EK-100 action train-val 78 | continue 79 | visor_annotations_json = os.path.join(VISOR_PATH,'GroundTruth-SparseAnnotations/annotations_corrected',split,video+'.json') 80 | if not os.path.exists(visor_annotations_json): 81 | # not in this split 82 | continue 83 | frame_mapping_ek2visor = reverse_mapping(frame_mapping_visor2ek[video]) 84 | EK_100_actions_list = EK_100_actions_video.values.tolist() 85 | EK_100_actions_list = sorted(EK_100_actions_list, key=functools.cmp_to_key(cmp)) 86 | frame_ek_np = np.array(sorted(list(frame_mapping_ek2visor.keys()))) 87 | 88 | visor_annotations_sparse = json.load(open(visor_annotations_json)) 89 | action_index = 0 
90 | while action_index < len(EK_100_actions_list): 91 | action = EK_100_actions_list[action_index] 92 | dict = {} 93 | dict['seq_id'] = seq_id+1 94 | dict['video'] = video 95 | dict['start'] = action[6] 96 | dict['narration'] = action[8] 97 | dict['verb'] = action[9] 98 | dict['verb_class'] = action[10] 99 | dict['noun'] = action[11] 100 | dict['noun_class'] = action[12] 101 | 102 | dict['end'] = EK_100_actions_list[action_index][7] 103 | 104 | # check sparse annotation 105 | start_frame_ek = 'frame_{:010d}.jpg'.format(dict['start']) 106 | end_frame_ek = 'frame_{:010d}.jpg'.format(dict['end']) 107 | start = np.searchsorted(frame_ek_np,start_frame_ek) 108 | end = np.searchsorted(frame_ek_np,end_frame_ek) 109 | annotated_frames = [] 110 | if start >= len(frame_ek_np): 111 | annotated_frames.append('no_annotation_for_this_action') 112 | elif start == end: 113 | if frame_ek_np[end] == end_frame_ek: 114 | annotated_frames.append(end_frame_ek) 115 | else: 116 | annotated_frames.append('no_annotation_for_this_action') 117 | else: 118 | for i in range(start,end): 119 | annotated_frames.append(frame_ek_np[i]) 120 | if end < len(frame_ek_np) and frame_ek_np[end] == end_frame_ek: 121 | annotated_frames.append(end_frame_ek) 122 | dict['sparse_frames'] = annotated_frames 123 | if len(annotated_frames) <= 1: 124 | # filter actions with 0 or 1 sparse annotations 125 | action_index += 1 126 | continue 127 | # generate sparse masks and pseudo labels 128 | res = generate_sparse_masks_pseudo_labels_weights(visor_annotations_sparse,frame_mapping_ek2visor,annotated_frames,hos_annotations=hos_annotations,all_nouns_cls=action[14]) 129 | if res is None: 130 | # filter actions with no objects 131 | action_index += 1 132 | continue 133 | seq_id += 1 134 | # prepare paths 135 | # EK_path_img = os.path.join(EK_PATH,video.split('_')[0],'rgb_frames',video) 136 | save_path_img_sparse = 
os.path.join(SAVE_PATH,'JPEGImages_Sparse',split,'{:08d}_{}_{}_{}'.format(seq_id,video,dict['verb'],dict['noun'])) 137 | os.makedirs(save_path_img_sparse,exist_ok=True) 138 | save_path_png_sparse = save_path_img_sparse.replace('JPEGImages_Sparse','Annotations_Sparse') 139 | os.makedirs(save_path_png_sparse,exist_ok=True) 140 | save_path_weights_sparse = save_path_img_sparse.replace('JPEGImages_Sparse','Weights_Sparse') 141 | os.makedirs(save_path_weights_sparse,exist_ok=True) 142 | # save masks 143 | for i,m in enumerate(res[0]): 144 | m_pil = Image.fromarray(m,mode='P') 145 | m_pil = m_pil.resize((854,480),Image.NEAREST) 146 | m_pil.putpalette(palette) 147 | m_pil.save(os.path.join(save_path_png_sparse,annotated_frames[i].replace('jpg','png'))) 148 | img = cv2.imread(os.path.join(VISOR_PATH,'GroundTruth-SparseAnnotations','rgb_frames',split,video,frame_mapping_ek2visor[annotated_frames[i]])) 149 | img = cv2.resize(img, (854, 480), interpolation=cv2.INTER_LINEAR) 150 | cv2.imwrite(os.path.join(save_path_img_sparse,annotated_frames[i]),img) 151 | all_sparse_frames += 1 152 | # save weights 153 | for i,m in enumerate(res[2]): 154 | m_pil = Image.fromarray(m,mode='P') 155 | m_pil = m_pil.resize((854,480),Image.NEAREST) 156 | m_pil.putpalette(palette) 157 | m_pil.save(os.path.join(save_path_weights_sparse,annotated_frames[i].replace('jpg','png'))) 158 | dict['object_classes'] = res[1] 159 | 160 | list_of_dict.append(dict) 161 | action_index += 1 162 | json_object = json.dumps(list_of_dict) 163 | with open(os.path.join(SAVE_PATH,'ImageSets',split+".json"), "w") as f: 164 | f.write(json_object) 165 | print('finished for {} set. 
{:d} actions, {:d} sparse frames'.format(split,seq_id,all_sparse_frames)) 166 | 167 | def cmp(x,y): 168 | a_x,b_x,c_x = map(int,x[0][1:].split('_')) 169 | a_y,b_y,c_y = map(int,y[0][1:].split('_')) 170 | if (a_x,b_x,c_x)<(a_y,b_y,c_y): 171 | return -1 172 | return 1 173 | 174 | def reverse_mapping(dict): 175 | # change the key:value of the dict to value:key 176 | res = {} 177 | for k in dict.keys(): 178 | res[dict[k]]=k 179 | return res 180 | 181 | def map_hos(hos_annotations): 182 | # get hos annotations to {'file_name':{'bboxs':[bbox1,bbox2],'masks':[segment1,segment2]}} 183 | res = {} 184 | id2file = {} 185 | for img in hos_annotations['images']: 186 | id2file[img['id']] = img['file_name'] 187 | for ann in hos_annotations['annotations']: 188 | if ann['category_id'] == 2: 189 | file_name = id2file[ann['image_id']] 190 | if file_name not in res: 191 | res[file_name] = {'bboxs':[],'masks':[]} 192 | res[file_name]['bboxs'].append(ann['bbox']) 193 | res[file_name]['masks'].append(ann['segmentation']) 194 | return res 195 | 196 | def generate_sparse_masks_pseudo_labels_weights(visor_annotations,frame_mapping_ek2visor,annotated_frames,hos_annotations=None,all_nouns_cls=[]): 197 | ''' 198 | generate VISOR masks, pseudo cls labels, and loss weights for the action sequence 199 | read sparse annotated mask from VISOR 200 | and save the obj-color mapping 201 | 202 | annotated_frames = [framexxxxxxxxxx.jpg,...,] 203 | 204 | return ([mask_np],obj_map,[weight_np]) 205 | ''' 206 | # get all annotated objects 207 | all_obj_map = get_all_obj_map(visor_annotations,frame_mapping_ek2visor,annotated_frames) 208 | if all_obj_map is None: 209 | return None 210 | # the following part is action-aware labeling 211 | # get objects in hand boxs 212 | hand_obj = get_hand_obj(visor_annotations,frame_mapping_ek2visor,annotated_frames,hos_annotations) 213 | # get objects in narrations 214 | narration_obj = get_narration_obj(all_obj_map,all_nouns_cls) 215 | # label postive objects 216 | for obj 
in all_obj_map.keys(): 217 | # deal with hand 218 | if hand_obj is None: 219 | all_obj_map[obj].append(0) 220 | else: 221 | if obj in hand_obj: 222 | all_obj_map[obj].append(1) 223 | else: 224 | all_obj_map[obj].append(0) 225 | # deal with narration 226 | if narration_obj is None: 227 | all_obj_map[obj].append(0) 228 | else: 229 | if obj in narration_obj: 230 | all_obj_map[obj].append(1) 231 | else: 232 | all_obj_map[obj].append(0) 233 | # deal with pos 234 | if all_obj_map[obj][-1] or all_obj_map[obj][-2]: 235 | all_obj_map[obj].append(1) 236 | else: 237 | all_obj_map[obj].append(0) 238 | # the following part is mask and action-guided weights 239 | masks = [] 240 | weights = [] 241 | for frame in annotated_frames: 242 | frame_visor = frame_mapping_ek2visor[frame] 243 | # draw pseudo mask 244 | for ann in visor_annotations['video_annotations']: 245 | if ann['image']['name'] == frame_visor: 246 | mask_np = np.zeros([1080,1920],dtype=np.uint8) 247 | for entity in ann['annotations']: 248 | polygons = [] 249 | for segment in entity['segments']: 250 | polygons.append(segment) 251 | ps = [] 252 | #store the polygons in one list. 
One object may has more than 1 polygon 253 | for poly in polygons: 254 | if poly == []: 255 | poly = [[0.0, 0.0]] 256 | ps.append(np.array(poly, dtype=np.int32)) 257 | if (entity['name'] in all_obj_map.keys()): 258 | color = all_obj_map[entity['name']][0] 259 | cv2.fillPoly(mask_np, ps, (color, color, color)) 260 | masks.append(mask_np) 261 | break 262 | # draw action-guided weight 263 | for ann in visor_annotations['video_annotations']: 264 | if ann['image']['name'] == frame_visor: 265 | weight_np = np.zeros([1080,1920],dtype=np.uint8) 266 | if frame_visor in hos_annotations: 267 | # draw hand mask 268 | hand_mask_np = np.zeros([1080,1920],dtype=np.uint8) 269 | hand_bbox_np = np.zeros([1080,1920],dtype=np.uint8) 270 | for hom in hos_annotations[frame_visor]['masks']: 271 | ps = [] 272 | for seg in hom: 273 | seg = np.array(seg, dtype=np.int32) 274 | poly = np.reshape(seg,(-1,2)) 275 | ps.append(poly) 276 | cv2.fillPoly(hand_mask_np, ps, (2, 2, 2)) 277 | # draw hand bboxs 278 | bboxs = hos_annotations[frame_visor]['bboxs']#xywh 279 | for bbox in bboxs: 280 | #bbox_xyxy = [bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3]] 281 | hand_bbox_np[bbox[1]:bbox[1]+bbox[3],bbox[0]:bbox[0]+bbox[2]] = 1 282 | for entity in ann['annotations']: 283 | polygons = [] 284 | for segment in entity['segments']: 285 | polygons.append(segment) 286 | ps = [] 287 | #store the polygons in one list. 
One object may has more than 1 polygon 288 | for poly in polygons: 289 | if poly == []: 290 | poly = [[0.0, 0.0]] 291 | ps.append(np.array(poly, dtype=np.int32)) 292 | # for narration objs 293 | if ((narration_obj is not None) and (entity['name'] in narration_obj)): 294 | cv2.fillPoly(weight_np, ps, (2, 2, 2)) 295 | else: 296 | # for negative objs 297 | if entity['name'] in all_obj_map.keys(): 298 | if frame_visor in hos_annotations: 299 | if not segment_in_box(ps,hos_annotations[frame_visor]['bboxs']): 300 | cv2.fillPoly(weight_np, ps, (3, 3, 3)) 301 | else: 302 | if all_obj_map[entity['name']][-1] == 0: 303 | cv2.fillPoly(weight_np, ps, (3, 3, 3)) 304 | # for hand objs 305 | if frame_visor in hos_annotations: 306 | weight_np = np.where(hand_bbox_np==1, weight_np*2, weight_np) 307 | weight_np = np.where(weight_np == 0, hand_mask_np, weight_np) 308 | weight_np = np.where(weight_np == 0, 1, weight_np) 309 | weights.append(weight_np) 310 | break 311 | # change all_obj_map to {color:{'name': ,'class_id':, 'positive':, 'handbox':, 'narration':}} 312 | obj_map_res = {} 313 | for name in all_obj_map.keys(): 314 | color = all_obj_map[name][0] 315 | obj_map_res[color] = {'name':name,'class_id':all_obj_map[name][1],'handbox':all_obj_map[name][2], 'narration':all_obj_map[name][3],'positive':all_obj_map[name][4]} 316 | return (masks,obj_map_res,weights) 317 | 318 | def segment_in_box(ps,bboxs): 319 | ''' 320 | check if a segmentation mask (in ps) is inside any of the bboxs. 321 | ''' 322 | segment = np.zeros([1080,1920],dtype=np.uint8) 323 | cv2.fillPoly(segment, ps, (1, 1, 1)) 324 | box = np.zeros([1080,1920],dtype=np.uint8) 325 | for bbox in bboxs: 326 | box[bbox[1]:bbox[1]+bbox[3],bbox[0]:bbox[0]+bbox[2]] = 1 327 | return (segment*box).any() 328 | 329 | def get_all_obj_map(visor_annotations,frame_mapping_ek2visor,annotated_frames): 330 | ''' 331 | get obj_map: {name:[color,class_id]}. would be added to [color,class_id,handbox,narration,pos] 332 | should be sorted here. 
set() would random hash every run 333 | ''' 334 | objects = set() 335 | cls_id = {} 336 | for frame in annotated_frames: 337 | frame_visor = frame_mapping_ek2visor[frame] 338 | for ann in visor_annotations['video_annotations']: 339 | if ann['image']['name'] == frame_visor: 340 | for entity in ann['annotations']: #loop over each object 341 | if len(entity["segments"]) > 0: #if there is annotation for this object, add it 342 | objects.add(entity["name"]) 343 | cls_id[entity["name"]] = entity["class_id"] 344 | break 345 | if len(objects) == 0: 346 | return None 347 | objects = list(objects) 348 | objects = sorted(objects)# must be sorted! 349 | obj_map = {} 350 | for i,obj in enumerate(objects): 351 | obj_map[obj] = [i+1,cls_id[obj]] 352 | return obj_map 353 | 354 | def get_hand_obj(visor_annotations,frame_mapping_ek2visor,annotated_frames,hos_annotations): 355 | ''' 356 | get obj: set of positive objs (in class name) 357 | ''' 358 | objects = set() 359 | for frame in annotated_frames: 360 | frame_visor = frame_mapping_ek2visor[frame] 361 | for ann in visor_annotations['video_annotations']: 362 | if ann['image']['name'] == frame_visor: 363 | # draw all annotation 364 | mask_np = np.zeros([1080,1920],dtype=np.uint8)# yx 365 | tmp_map = {} 366 | for i,entity in enumerate(ann['annotations']): 367 | polygons = [] 368 | for segment in entity['segments']: 369 | polygons.append(segment) 370 | ps = [] 371 | #store the polygons in one list. 
One object may has more than 1 polygon 372 | for poly in polygons: 373 | if poly == []: 374 | poly = [[0.0, 0.0]] 375 | ps.append(np.array(poly, dtype=np.int32)) 376 | cv2.fillPoly(mask_np, ps, (i+1, i+1, i+1)) 377 | tmp_map[i+1] = entity['name'] 378 | if entity['name'] in ['left hand','right hand']: 379 | objects.add(entity['name']) 380 | # get bboxs and bbox objs 381 | if frame_visor in hos_annotations: 382 | bboxs = hos_annotations[frame_visor]['bboxs']#xywh 383 | for bbox in bboxs: 384 | #bbox_xyxy = [bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3]] 385 | crop = mask_np[bbox[1]:bbox[1]+bbox[3],bbox[0]:bbox[0]+bbox[2]] 386 | unique = np.unique(crop) 387 | for u in unique: 388 | if u != 0: 389 | objects.add(tmp_map[u]) 390 | break 391 | if len(objects) == 0: 392 | return None 393 | return objects 394 | 395 | def get_narration_obj(all_obj_map,all_nouns_cls): 396 | # EK-100 annotation all_nouns_cls 397 | objects = set() 398 | all_nouns_cls = eval(all_nouns_cls) 399 | for obj in all_obj_map.keys(): 400 | if all_obj_map[obj][1] in all_nouns_cls: 401 | objects.add(obj) 402 | if len(objects) == 0: 403 | return None 404 | return objects 405 | 406 | def main(): 407 | generate_action_with_mask('train') 408 | generate_action_with_mask('val') 409 | 410 | if __name__ == '__main__': 411 | main() -------------------------------------------------------------------------------- /data_prepare_vost.py: -------------------------------------------------------------------------------- 1 | # change VOST and VSCOS to ActionVOS(RVOS) settings 2 | # val ~ 3 | import numpy as np 4 | import pandas as pd 5 | import yaml 6 | import json 7 | import os 8 | import functools 9 | import cv2 10 | import shutil 11 | from tqdm import tqdm 12 | from PIL import Image 13 | import argparse 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--VOST_PATH', type=str, required=True) 17 | args = parser.parse_args() 18 | VOST_PATH = args.VOST_PATH 19 | 20 | SAVE_PATH = 'dataset_vost' 21 
| os.makedirs(SAVE_PATH,exist_ok=True) 22 | os.makedirs(os.path.join(SAVE_PATH,'JPEGImages'),exist_ok=True) 23 | os.makedirs(os.path.join(SAVE_PATH,'JPEGImages','train'),exist_ok=True) 24 | os.makedirs(os.path.join(SAVE_PATH,'JPEGImages','val'),exist_ok=True) 25 | os.makedirs(os.path.join(SAVE_PATH,'ImageSets'),exist_ok=True) 26 | os.makedirs(os.path.join(SAVE_PATH,'Annotations'),exist_ok=True) 27 | os.makedirs(os.path.join(SAVE_PATH,'Annotations','train'),exist_ok=True) 28 | os.makedirs(os.path.join(SAVE_PATH,'Annotations','val'),exist_ok=True) 29 | 30 | palette = Image.open('annotations/00000.png').getpalette() 31 | 32 | def copy_files(src_folder, dest_folder): 33 | if not os.path.exists(dest_folder): 34 | os.makedirs(dest_folder) 35 | files = os.listdir(src_folder) 36 | for file_name in files: 37 | src_path = os.path.join(src_folder, file_name) 38 | dest_path = os.path.join(dest_folder, file_name) 39 | shutil.copy2(src_path, dest_path) 40 | 41 | def generate_action_with_mask(split='train',sampling='all'): 42 | ''' 43 | write a JSON list of dicts, one per action, describing its RGB frames, annotation masks and object labels 44 | [ 45 | { 46 | 'seq_id': int to identify sequence id, 47 | 'src': str to identify video source (not written by this script), 48 | 'video': original video name from VOST/VSCOS, 49 | 'verb','verb_class','noun','noun_class': verb/noun annotation from EK-100, -1 if not available, 50 | 'narration': narration from VOST/VSCOS, 51 | 'start': start frame index from EK-100, 52 | 'end': end frame index from EK-100; these 2 fields are zero for VOST, 53 | 'object_classes': {'1':{'name':xxx,'class_id':xxx,'positive':0/1,'handbox':0/1,'narration':0/1}, } for mask object labels in annotations, 54 | 'sparse_frames': ['xxx.jpg','yyy.jpg'] for the frame file names 55 | }, ...
56 | ] 57 | ''' 58 | assert split in ['train','val'] 59 | assert sampling in ['all'] 60 | # read VOST 61 | with open(os.path.join(VOST_PATH,'ImageSets',split+'.txt')) as f: 62 | VOST_actions = f.readlines() 63 | VOST_actions = [a.strip() for a in VOST_actions] 64 | 65 | list_of_dict = [] 66 | all_sparse_frames = 0 67 | seq_id = 0 68 | for video in tqdm(VOST_actions): 69 | seq_id += 1 70 | # copy images 71 | images_path = os.path.join(VOST_PATH,'JPEGImages',video) 72 | copy_files(images_path,os.path.join(SAVE_PATH,'JPEGImages',split,'{:08d}_{}'.format(seq_id,'_'.join(video.split('_')[1:])))) 73 | # copy masks, set ambiguous pixels (255) to zero, merge all instances into label 1 74 | src_mask_path = os.path.join(VOST_PATH,'Annotations_raw',video) 75 | dst_mask_path = os.path.join(SAVE_PATH,'Annotations',split,'{:08d}_{}'.format(seq_id,'_'.join(video.split('_')[1:]))) 76 | src_masks = os.listdir(src_mask_path) 77 | os.makedirs(dst_mask_path,exist_ok=True) 78 | for src_mask in src_masks: 79 | mask = Image.open(os.path.join(src_mask_path,src_mask)) 80 | mask_np = np.array(mask) 81 | mask_np[mask_np==255] = 0 82 | mask_np[mask_np>0] = 1 83 | mask_save = Image.fromarray(mask_np,mode='P') 84 | mask_save.putpalette(palette) 85 | mask_save.save(os.path.join(dst_mask_path,src_mask)) 86 | # save dict 87 | dict = {} 88 | dict['seq_id'] = seq_id 89 | dict['video'] = video 90 | dict['start'] = 0 91 | dict['end'] = 0 92 | dict['narration'] = ' '.join(video.split('_')[1:]) 93 | dict['verb'] = video.split('_')[1] 94 | dict['verb_class'] = -1 95 | dict['noun'] = ' '.join(video.split('_')[2:]) 96 | dict['noun_class'] = -1 97 | dict['sparse_frames'] = sorted(os.listdir(images_path)) 98 | dict['object_classes'] = {"1":{"name": dict['noun'], "class_id": -1, "handbox": 1, "narration": 1, "positive": 1}} 99 | list_of_dict.append(dict) 100 | all_sparse_frames += len(dict['sparse_frames']) 101 | #if seq_id>=2: 102 | #break 103 | 104 | # save json 105 | json_object = json.dumps(list_of_dict) 106 | with
open(os.path.join(SAVE_PATH,'ImageSets',split+'_'+sampling+".json"), "w") as f: 107 | f.write(json_object) 108 | print('finished for {}_{} set. {:d} actions, {:d} sparse frames'.format(split,sampling,seq_id,all_sparse_frames)) 109 | 110 | def main(): 111 | generate_action_with_mask('val') 112 | 113 | if __name__ == '__main__': 114 | main() -------------------------------------------------------------------------------- /data_prepare_vscos.py: -------------------------------------------------------------------------------- 1 | # change VOST and VSCOS to ActionVOS(RVOS) settings 2 | # val ~ 3 | import numpy as np 4 | import pandas as pd 5 | import yaml 6 | import json 7 | import os 8 | import functools 9 | import cv2 10 | import shutil 11 | from tqdm import tqdm 12 | from PIL import Image 13 | import argparse 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--VSCOS_PATH', type=str, required=True) 17 | args = parser.parse_args() 18 | VSCOS_PATH = args.VSCOS_PATH 19 | 20 | SAVE_PATH = 'dataset_vscos' 21 | os.makedirs(SAVE_PATH,exist_ok=True) 22 | os.makedirs(os.path.join(SAVE_PATH,'JPEGImages'),exist_ok=True) 23 | os.makedirs(os.path.join(SAVE_PATH,'JPEGImages','train'),exist_ok=True) 24 | os.makedirs(os.path.join(SAVE_PATH,'JPEGImages','val'),exist_ok=True) 25 | os.makedirs(os.path.join(SAVE_PATH,'ImageSets'),exist_ok=True) 26 | os.makedirs(os.path.join(SAVE_PATH,'Annotations'),exist_ok=True) 27 | os.makedirs(os.path.join(SAVE_PATH,'Annotations','train'),exist_ok=True) 28 | os.makedirs(os.path.join(SAVE_PATH,'Annotations','val'),exist_ok=True) 29 | 30 | palette = Image.open('annotations/00000.png').getpalette() 31 | 32 | def copy_files(src_folder, dest_folder): 33 | if not os.path.exists(dest_folder): 34 | os.makedirs(dest_folder) 35 | files = os.listdir(src_folder) 36 | for file_name in files: 37 | src_path = os.path.join(src_folder, file_name) 38 | dest_path = os.path.join(dest_folder, file_name) 39 | shutil.copy2(src_path, dest_path) 40 | 41 | 
def generate_action_with_mask(split='train',sampling='all'): 42 | ''' 43 | write a JSON list of dicts, one per action, describing its RGB frames, annotation masks and object labels 44 | [ 45 | { 46 | 'seq_id': int to identify sequence id, 47 | 'src': str to identify video source (not written by this script), 48 | 'video': original video name from VOST/VSCOS, 49 | 'verb','verb_class','noun','noun_class': verb/noun annotation from EK-100, -1 if not available, 50 | 'narration': narration from VOST/VSCOS, 51 | 'start': start frame index from EK-100, 52 | 'end': end frame index from EK-100; these 2 fields are zero for VOST, 53 | 'object_classes': {'1':{'name':xxx,'class_id':xxx,'positive':0/1,'handbox':0/1,'narration':0/1}, } for mask object labels in annotations, 54 | 'sparse_frames': ['xxx.jpg','yyy.jpg'] for the frame file names 55 | }, ... 56 | ] 57 | ''' 58 | assert split in ['train','val'] 59 | assert sampling in ['all'] 60 | list_of_dict = [] 61 | all_sparse_frames = 0 62 | seq_id = 0 63 | # read VSCOS 64 | with open(os.path.join(VSCOS_PATH,'EPIC_{}_split'.format(split),'EPIC100_state_positive_{}.yaml'.format(split)), 'r') as f: 65 | VSCOS_actions = yaml.safe_load(f) 66 | for video in tqdm(VSCOS_actions.keys()): 67 | P = VSCOS_actions[video]['participant_id'] 68 | vid = VSCOS_actions[video]['video_id'] 69 | seq_id += 1 70 | src_mask_path = os.path.join(VSCOS_PATH,'EPIC_{}_split'.format(split),P,'anno_masks',vid,video) 71 | src_img_path = src_mask_path.replace('anno_masks','rgb_frames') 72 | dst_mask_path = os.path.join(SAVE_PATH,'Annotations',split,'{:08d}_{}_{}'.format(seq_id,VSCOS_actions[video]['verb'],VSCOS_actions[video]['noun'])) 73 | dst_img_path = dst_mask_path.replace('Annotations','JPEGImages') 74 | os.makedirs(dst_mask_path,exist_ok=True) 75 | os.makedirs(dst_img_path,exist_ok=True) 76 | for src_mask in os.listdir(src_mask_path): 77 | # copy masks 78 | mask = Image.open(os.path.join(src_mask_path,src_mask)).convert('P') 79 | mask.putpalette(palette) 80 | mask.save(os.path.join(dst_mask_path,src_mask)) 81 | # copy images
82 | src_img = src_mask.replace('png','jpg') 83 | shutil.copy2(os.path.join(src_img_path,src_img),os.path.join(dst_img_path,src_img)) 84 | # save dict 85 | dict = {} 86 | dict['seq_id'] = seq_id 87 | dict['video'] = video 88 | dict['start'] = VSCOS_actions[video]['start_frame'] 89 | dict['end'] = VSCOS_actions[video]['stop_frame'] 90 | dict['narration'] = VSCOS_actions[video]['narration'] 91 | dict['verb'] = VSCOS_actions[video]['verb'] 92 | dict['verb_class'] = VSCOS_actions[video]['verb_class'] 93 | dict['noun'] = VSCOS_actions[video]['noun'] 94 | dict['noun_class'] = VSCOS_actions[video]['noun_class'] 95 | dict['sparse_frames'] = sorted([m.replace('png','jpg') for m in os.listdir(src_mask_path)]) 96 | dict['object_classes'] = {"1":{"name": dict['noun'], "class_id": dict['noun_class'], "handbox": 1, "narration": 1, "positive": 1}} 97 | list_of_dict.append(dict) 98 | all_sparse_frames += len(dict['sparse_frames']) 99 | 100 | # save json 101 | json_object = json.dumps(list_of_dict) 102 | with open(os.path.join(SAVE_PATH,'ImageSets',split+'_'+sampling+".json"), "w") as f: 103 | f.write(json_object) 104 | print('finished for {}_{} set. 
{:d} actions, {:d} sparse frames'.format(split,sampling,seq_id,all_sparse_frames)) 105 | 106 | def main(): 107 | generate_action_with_mask('val') 108 | 109 | if __name__ == '__main__': 110 | main() -------------------------------------------------------------------------------- /dataset_visor/ImageSets/val_novel.json: -------------------------------------------------------------------------------- 1 | [{"seq_id": 177, "video": "P02_09", "start": 23199, "narration": "add flour to mix", "verb": "add", "verb_class": 46, "noun": "flour", "noun_class": 75, "end": 23397, "sparse_frames": ["frame_0000023240.jpg", "frame_0000023348.jpg"], "object_classes": {"1": {"name": "bowl", "class_id": 7, "handbox": 1, "narration": 0, "positive": 1}, "2": {"name": "flour", "class_id": 75, "handbox": 0, "narration": 1, "positive": 1}, "3": {"name": "flour package", "class_id": 19, "handbox": 0, "narration": 0, "positive": 1}, "4": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "5": {"name": "mixture", "class_id": 71, "handbox": 1, "narration": 0, "positive": 1}, "6": {"name": "oil bottle", "class_id": 15, "handbox": 0, "narration": 0, "positive": 0}, "7": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 187, "video": "P02_09", "start": 28213, "narration": "form bowl from mixture", "verb": "form-from", "verb_class": 74, "noun": "bowl", "noun_class": 7, "end": 29419, "sparse_frames": ["frame_0000028386.jpg", "frame_0000028439.jpg"], "object_classes": {"1": {"name": "bowl", "class_id": 7, "handbox": 1, "narration": 1, "positive": 1}, "2": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "3": {"name": "mixture", "class_id": 71, "handbox": 1, "narration": 1, "positive": 1}, "4": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": 
false}, {"seq_id": 382, "video": "P02_09", "start": 101080, "narration": "carry glass", "verb": "carry", "verb_class": 89, "noun": "glass", "noun_class": 10, "end": 101233, "sparse_frames": ["frame_0000101114.jpg", "frame_0000101203.jpg"], "object_classes": {"1": {"name": "chicken stock", "class_id": 95, "handbox": 1, "narration": 0, "positive": 1}, "2": {"name": "glass", "class_id": 10, "handbox": 1, "narration": 1, "positive": 1}, "3": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "4": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "5": {"name": "spoon", "class_id": 1, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 578, "video": "P02_12", "start": 15777, "narration": "get thyme", "verb": "get", "verb_class": 0, "noun": "thyme", "noun_class": 215, "end": 15907, "sparse_frames": ["frame_0000015798.jpg", "frame_0000015857.jpg"], "object_classes": {"1": {"name": "cupboard", "class_id": 3, "handbox": 1, "narration": 0, "positive": 0}, "2": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "3": {"name": "thyme jar", "class_id": 40, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 724, "video": "P02_12", "start": 66179, "narration": "mix thyme and olive oil in plate", "verb": "mix", "verb_class": 10, "noun": "thyme", "noun_class": 215, "end": 66528, "sparse_frames": ["frame_0000066244.jpg", "frame_0000066437.jpg"], "object_classes": {"1": {"name": "dough ball", "class_id": 25, "handbox": 0, "narration": 0, "positive": 0}, "2": {"name": "jar", "class_id": 40, "handbox": 0, "narration": 0, "positive": 0}, "3": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "4": {"name": "mixture", "class_id": 71, "handbox": 1, "narration": 0, "positive": 1}, "5": {"name": "plate", "class_id": 2, "handbox": 1, 
"narration": 1, "positive": 1}, "6": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "7": {"name": "spoon", "class_id": 1, "handbox": 0, "narration": 0, "positive": 0}, "8": {"name": "thyme", "class_id": 215, "handbox": 0, "narration": 1, "positive": 1}, "9": {"name": "tray", "class_id": 35, "handbox": 1, "narration": 0, "positive": 0}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 733, "video": "P02_12", "start": 69147, "narration": "scoop olive oil and thyme", "verb": "scoop", "verb_class": 16, "noun": "oil:olive", "noun_class": 31, "end": 69297, "sparse_frames": ["frame_0000069161.jpg", "frame_0000069251.jpg"], "object_classes": {"1": {"name": "dough ball", "class_id": 25, "handbox": 1, "narration": 0, "positive": 1}, "2": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "3": {"name": "mixture", "class_id": 71, "handbox": 1, "narration": 0, "positive": 1}, "4": {"name": "pie", "class_id": 137, "handbox": 1, "narration": 0, "positive": 0}, "5": {"name": "plate", "class_id": 2, "handbox": 1, "narration": 0, "positive": 1}, "6": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "7": {"name": "tray", "class_id": 35, "handbox": 1, "narration": 0, "positive": 0}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 765, "video": "P02_12", "start": 78409, "narration": "push oven tray", "verb": "push", "verb_class": 21, "noun": "tray:oven", "noun_class": 35, "end": 78532, "sparse_frames": ["frame_0000078436.jpg", "frame_0000078499.jpg"], "object_classes": {"1": {"name": "drawer", "class_id": 8, "handbox": 1, "narration": 0, "positive": 1}, "2": {"name": "left glove", "class_id": 303, "handbox": 1, "narration": 0, "positive": 1}, "3": {"name": "oven", "class_id": 46, "handbox": 1, "narration": 0, "positive": 1}, "4": {"name": "right glove", "class_id": 304, "handbox": 0, "narration": 0, "positive": 1}, "5": {"name": "right 
hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 792, "video": "P03_120", "start": 5652, "narration": "tap weight tray", "verb": "tap", "verb_class": 32, "noun": "tray:weight", "noun_class": 35, "end": 5826, "sparse_frames": ["frame_0000005672.jpg", "frame_0000005697.jpg", "frame_0000005722.jpg", "frame_0000005747.jpg", "frame_0000005772.jpg", "frame_0000005797.jpg", "frame_0000005822.jpg", "frame_0000005826.jpg"], "object_classes": {"1": {"name": "bowl", "class_id": 7, "handbox": 1, "narration": 0, "positive": 1}, "2": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "3": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "4": {"name": "sugar", "class_id": 102, "handbox": 1, "narration": 0, "positive": 1}, "5": {"name": "weight tray", "class_id": 35, "handbox": 1, "narration": 1, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 926, "video": "P04_06", "start": 5491, "narration": "pick up spinach", "verb": "pick-up", "verb_class": 0, "noun": "spinach", "noun_class": 198, "end": 5764, "sparse_frames": ["frame_0000005559.jpg", "frame_0000005679.jpg"], "object_classes": {"1": {"name": "food", "class_id": 34, "handbox": 1, "narration": 0, "positive": 0}, "2": {"name": "fridge", "class_id": 12, "handbox": 1, "narration": 0, "positive": 0}, "3": {"name": "hob/cooktop/stovetop", "class_id": 24, "handbox": 1, "narration": 0, "positive": 0}, "4": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "5": {"name": "pan", "class_id": 5, "handbox": 1, "narration": 0, "positive": 0}, "6": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "7": {"name": "spinach", "class_id": 198, "handbox": 1, "narration": 1, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 1024, "video": "P04_13", "start": 
10393, "narration": "look at cup", "verb": "look-at", "verb_class": 38, "noun": "cup", "noun_class": 13, "end": 10564, "sparse_frames": ["frame_0000010440.jpg", "frame_0000010516.jpg"], "object_classes": {"1": {"name": "cup", "class_id": 13, "handbox": 1, "narration": 1, "positive": 1}, "2": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 1078, "video": "P04_24", "start": 24013, "narration": "cut beef", "verb": "cut", "verb_class": 7, "noun": "beef", "noun_class": 217, "end": 25609, "sparse_frames": ["frame_0000024417.jpg", "frame_0000025240.jpg"], "object_classes": {"1": {"name": "beef", "class_id": 217, "handbox": 1, "narration": 1, "positive": 1}, "2": {"name": "cooker bowl", "class_id": 7, "handbox": 1, "narration": 0, "positive": 0}, "3": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "4": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "5": {"name": "scissors", "class_id": 79, "handbox": 1, "narration": 0, "positive": 1}, "6": {"name": "spatula", "class_id": 20, "handbox": 1, "narration": 0, "positive": 0}, "7": {"name": "spice paste", "class_id": 148, "handbox": 1, "narration": 0, "positive": 0}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 1096, "video": "P06_03", "start": 2511, "narration": "check flour", "verb": "check", "verb_class": 24, "noun": "flour", "noun_class": 75, "end": 2826, "sparse_frames": ["frame_0000002598.jpg", "frame_0000002744.jpg"], "object_classes": {"1": {"name": "flour", "class_id": 75, "handbox": 1, "narration": 1, "positive": 1}, "2": {"name": "flour packet", "class_id": 26, "handbox": 1, "narration": 0, "positive": 1}, "3": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "4": {"name": "pot", "class_id": 29, "handbox": 1, "narration": 0, "positive": 1}, "5": {"name": "right hand", "class_id": 301, 
"handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 1142, "video": "P06_03", "start": 19775, "narration": "drop oil", "verb": "drop", "verb_class": 59, "noun": "oil", "noun_class": 31, "end": 19854, "sparse_frames": ["frame_0000019790.jpg", "frame_0000019846.jpg"], "object_classes": {"1": {"name": "ingredient", "class_id": 34, "handbox": 1, "narration": 0, "positive": 0}, "2": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "3": {"name": "lid", "class_id": 6, "handbox": 1, "narration": 0, "positive": 1}, "4": {"name": "oil", "class_id": 31, "handbox": 1, "narration": 1, "positive": 1}, "5": {"name": "oil bottle", "class_id": 15, "handbox": 1, "narration": 0, "positive": 1}, "6": {"name": "pot", "class_id": 29, "handbox": 1, "narration": 0, "positive": 0}, "7": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 1299, "video": "P06_05", "start": 68698, "narration": "put bacon in fridge", "verb": "put-in", "verb_class": 5, "noun": "bacon", "noun_class": 163, "end": 68881, "sparse_frames": ["frame_0000068746.jpg", "frame_0000068851.jpg"], "object_classes": {"1": {"name": "bacon package", "class_id": 26, "handbox": 1, "narration": 0, "positive": 1}, "2": {"name": "fridge", "class_id": 12, "handbox": 1, "narration": 1, "positive": 1}, "3": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "4": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 1358, "video": "P06_10", "start": 3949, "narration": "open potatoes", "verb": "open", "verb_class": 3, "noun": "potato", "noun_class": 30, "end": 4583, "sparse_frames": ["frame_0000004076.jpg", "frame_0000004403.jpg"], "object_classes": {"1": {"name": "cupboard", "class_id": 3, "handbox": 0, 
"narration": 0, "positive": 0}, "2": {"name": "glove", "class_id": 60, "handbox": 0, "narration": 0, "positive": 0}, "3": {"name": "hob", "class_id": 24, "handbox": 0, "narration": 0, "positive": 0}, "4": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "5": {"name": "lid", "class_id": 6, "handbox": 0, "narration": 0, "positive": 0}, "6": {"name": "pot", "class_id": 29, "handbox": 0, "narration": 0, "positive": 0}, "7": {"name": "potato package", "class_id": 26, "handbox": 1, "narration": 0, "positive": 1}, "8": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "9": {"name": "scissors", "class_id": 79, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 1391, "video": "P06_106", "start": 418, "narration": "shake rice", "verb": "shake", "verb_class": 15, "noun": "rice", "noun_class": 55, "end": 515, "sparse_frames": ["frame_0000000451.jpg", "frame_0000000484.jpg"], "object_classes": {"1": {"name": "fork", "class_id": 14, "handbox": 1, "narration": 0, "positive": 1}, "2": {"name": "hob", "class_id": 24, "handbox": 1, "narration": 0, "positive": 0}, "3": {"name": "lid", "class_id": 6, "handbox": 0, "narration": 0, "positive": 0}, "4": {"name": "pan", "class_id": 5, "handbox": 1, "narration": 0, "positive": 1}, "5": {"name": "rice", "class_id": 55, "handbox": 1, "narration": 1, "positive": 1}, "6": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 1491, "video": "P06_108", "start": 25309, "narration": "adjust foil", "verb": "adjust", "verb_class": 17, "noun": "foil", "noun_class": 123, "end": 25410, "sparse_frames": ["frame_0000025344.jpg", "frame_0000025388.jpg"], "object_classes": {"1": {"name": "hob", "class_id": 24, "handbox": 1, "narration": 0, "positive": 0}, "2": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, 
"positive": 1}, "3": {"name": "pan", "class_id": 5, "handbox": 1, "narration": 0, "positive": 0}, "4": {"name": "pizza", "class_id": 91, "handbox": 1, "narration": 0, "positive": 0}, "5": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "6": {"name": "silver foil", "class_id": 123, "handbox": 1, "narration": 1, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 1494, "video": "P06_108", "start": 26271, "narration": "take out phone", "verb": "take-out", "verb_class": 12, "noun": "phone", "noun_class": 165, "end": 26458, "sparse_frames": ["frame_0000026317.jpg", "frame_0000026398.jpg"], "object_classes": {"1": {"name": "hob", "class_id": 24, "handbox": 0, "narration": 0, "positive": 0}, "2": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 0}, "3": {"name": "pan", "class_id": 5, "handbox": 0, "narration": 0, "positive": 0}, "4": {"name": "phone", "class_id": 165, "handbox": 1, "narration": 1, "positive": 1}, "5": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "6": {"name": "silver foil", "class_id": 123, "handbox": 0, "narration": 0, "positive": 0}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 1506, "video": "P06_108", "start": 30212, "narration": "fill bowl with water", "verb": "fill-with", "verb_class": 26, "noun": "bowl", "noun_class": 7, "end": 30597, "sparse_frames": ["frame_0000030226.jpg", "frame_0000030254.jpg", "frame_0000030301.jpg", "frame_0000030576.jpg"], "object_classes": {"1": {"name": "bowl", "class_id": 7, "handbox": 1, "narration": 1, "positive": 1}, "2": {"name": "can opener", "class_id": 174, "handbox": 1, "narration": 0, "positive": 0}, "3": {"name": "dish soap", "class_id": 22, "handbox": 1, "narration": 0, "positive": 0}, "4": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "5": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, 
"positive": 0}, "6": {"name": "sink", "class_id": 63, "handbox": 1, "narration": 0, "positive": 0}, "7": {"name": "tap", "class_id": 0, "handbox": 1, "narration": 0, "positive": 0}, "8": {"name": "water", "class_id": 27, "handbox": 1, "narration": 1, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 1630, "video": "P07_110", "start": 11596, "narration": "shake glass", "verb": "shake", "verb_class": 15, "noun": "glass", "noun_class": 10, "end": 11682, "sparse_frames": ["frame_0000011609.jpg", "frame_0000011643.jpg"], "object_classes": {"1": {"name": "glass", "class_id": 10, "handbox": 0, "narration": 1, "positive": 1}, "2": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "3": {"name": "mixer", "class_id": 82, "handbox": 0, "narration": 0, "positive": 0}, "4": {"name": "mixer base", "class_id": 82, "handbox": 0, "narration": 0, "positive": 1}, "5": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 1780, "video": "P09_07", "start": 150, "narration": "open mixer", "verb": "open", "verb_class": 3, "noun": "mixer", "noun_class": 82, "end": 250, "sparse_frames": ["frame_0000000181.jpg", "frame_0000000227.jpg"], "object_classes": {"1": {"name": "drawer", "class_id": 8, "handbox": 1, "narration": 0, "positive": 0}, "2": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "3": {"name": "mixer", "class_id": 82, "handbox": 1, "narration": 1, "positive": 1}, "4": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 1782, "video": "P09_07", "start": 417, "narration": "take supplement", "verb": "take", "verb_class": 0, "noun": "supplement", "noun_class": 171, "end": 488, "sparse_frames": ["frame_0000000441.jpg", "frame_0000000471.jpg"], "object_classes": {"1": {"name": 
"cupboard", "class_id": 3, "handbox": 1, "narration": 0, "positive": 0}, "2": {"name": "drawer", "class_id": 8, "handbox": 1, "narration": 0, "positive": 0}, "3": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "4": {"name": "supplement", "class_id": 171, "handbox": 1, "narration": 1, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 1812, "video": "P09_103", "start": 12149, "narration": "add condiments to chicken", "verb": "add-to", "verb_class": 46, "noun": "condiment", "noun_class": 47, "end": 12472, "sparse_frames": ["frame_0000012195.jpg", "frame_0000012401.jpg"], "object_classes": {"1": {"name": "chicken", "class_id": 57, "handbox": 1, "narration": 1, "positive": 1}, "2": {"name": "chicken package", "class_id": 26, "handbox": 1, "narration": 0, "positive": 1}, "3": {"name": "condiment", "class_id": 47, "handbox": 1, "narration": 1, "positive": 1}, "4": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "5": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 1857, "video": "P09_104", "start": 12941, "narration": "add oil to pan", "verb": "add-to", "verb_class": 46, "noun": "oil", "noun_class": 31, "end": 13099, "sparse_frames": ["frame_0000012976.jpg", "frame_0000013063.jpg"], "object_classes": {"1": {"name": "chorizo", "class_id": 86, "handbox": 0, "narration": 0, "positive": 0}, "2": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 0}, "3": {"name": "oil bottle", "class_id": 15, "handbox": 0, "narration": 0, "positive": 1}, "4": {"name": "pan", "class_id": 5, "handbox": 0, "narration": 0, "positive": 1}, "5": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "6": {"name": "stove", "class_id": 24, "handbox": 0, "narration": 0, "positive": 0}}, "redundant": false, "missing": false, "other": 
false}, {"seq_id": 2017, "video": "P18_01", "start": 8551, "narration": "separate them", "verb": "separate", "verb_class": 52, "noun": "machine:coffee", "noun_class": 50, "end": 8656, "sparse_frames": ["frame_0000008596.jpg", "frame_0000008630.jpg"], "object_classes": {"1": {"name": "bread package", "class_id": 26, "handbox": 0, "narration": 0, "positive": 0}, "2": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "3": {"name": "plate", "class_id": 2, "handbox": 1, "narration": 0, "positive": 1}, "4": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "5": {"name": "spreading knife", "class_id": 4, "handbox": 1, "narration": 0, "positive": 0}, "6": {"name": "toast bread", "class_id": 33, "handbox": 1, "narration": 0, "positive": 1}, "7": {"name": "toaster", "class_id": 186, "handbox": 0, "narration": 0, "positive": 0}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 2049, "video": "P21_01", "start": 5983, "narration": "selecting the power", "verb": "select", "verb_class": 85, "noun": "power", "noun_class": 278, "end": 6352, "sparse_frames": ["frame_0000006056.jpg", "frame_0000006276.jpg"], "object_classes": {"1": {"name": "microwave", "class_id": 90, "handbox": 1, "narration": 0, "positive": 1}, "2": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 2053, "video": "P21_01", "start": 8338, "narration": "taking out the grapes", "verb": "take-out", "verb_class": 12, "noun": "grape", "noun_class": 202, "end": 8418, "sparse_frames": ["frame_0000008352.jpg", "frame_0000008388.jpg"], "object_classes": {"1": {"name": "grape", "class_id": 202, "handbox": 1, "narration": 1, "positive": 1}, "2": {"name": "grocery", "class_id": 34, "handbox": 1, "narration": 0, "positive": 0}, "3": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "4": {"name": 
"shopping bag", "class_id": 19, "handbox": 1, "narration": 0, "positive": 0}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 2213, "video": "P22_107", "start": 20890, "narration": "put down blackberry", "verb": "put-down", "verb_class": 1, "noun": "blackberry", "noun_class": 219, "end": 21039, "sparse_frames": ["frame_0000020934.jpg", "frame_0000020996.jpg"], "object_classes": {"1": {"name": "blackberry", "class_id": 219, "handbox": 1, "narration": 1, "positive": 1}, "2": {"name": "bowl", "class_id": 7, "handbox": 1, "narration": 0, "positive": 1}, "3": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 0}, "4": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "5": {"name": "sink", "class_id": 63, "handbox": 1, "narration": 0, "positive": 0}, "6": {"name": "spoon", "class_id": 1, "handbox": 1, "narration": 0, "positive": 0}, "7": {"name": "yogurt", "class_id": 98, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 2217, "video": "P22_107", "start": 21731, "narration": "rinse blackberries", "verb": "rinse", "verb_class": 2, "noun": "blackberry", "noun_class": 219, "end": 21803, "sparse_frames": ["frame_0000021755.jpg", "frame_0000021768.jpg", "frame_0000021781.jpg"], "object_classes": {"1": {"name": "blackberry", "class_id": 219, "handbox": 1, "narration": 1, "positive": 1}, "2": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "3": {"name": "sink", "class_id": 63, "handbox": 0, "narration": 0, "positive": 0}, "4": {"name": "tap", "class_id": 0, "handbox": 1, "narration": 0, "positive": 1}, "5": {"name": "water", "class_id": 27, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 2320, "video": "P24_09", "start": 42609, "narration": "move potatoes to the pan", "verb": "move-to", "verb_class": 11, "noun": "potato", 
"noun_class": 30, "end": 43044, "sparse_frames": ["frame_0000042728.jpg", "frame_0000042962.jpg"], "object_classes": {"1": {"name": "chopping board", "class_id": 18, "handbox": 0, "narration": 0, "positive": 0}, "2": {"name": "egg", "class_id": 53, "handbox": 0, "narration": 0, "positive": 0}, "3": {"name": "fork", "class_id": 14, "handbox": 0, "narration": 0, "positive": 0}, "4": {"name": "hob", "class_id": 24, "handbox": 1, "narration": 0, "positive": 0}, "5": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "6": {"name": "onion", "class_id": 16, "handbox": 0, "narration": 0, "positive": 0}, "7": {"name": "pan", "class_id": 5, "handbox": 1, "narration": 1, "positive": 1}, "8": {"name": "plate", "class_id": 2, "handbox": 0, "narration": 0, "positive": 0}, "9": {"name": "potato", "class_id": 30, "handbox": 1, "narration": 1, "positive": 1}, "10": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "11": {"name": "wooden spoon", "class_id": 1, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 2355, "video": "P24_09", "start": 59121, "narration": "move potato in pan", "verb": "move", "verb_class": 11, "noun": "potato", "noun_class": 30, "end": 59889, "sparse_frames": ["frame_0000059334.jpg", "frame_0000059697.jpg"], "object_classes": {"1": {"name": "egg", "class_id": 53, "handbox": 1, "narration": 0, "positive": 0}, "2": {"name": "fork", "class_id": 14, "handbox": 0, "narration": 0, "positive": 0}, "3": {"name": "hob", "class_id": 24, "handbox": 1, "narration": 0, "positive": 0}, "4": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "5": {"name": "lid", "class_id": 6, "handbox": 0, "narration": 0, "positive": 0}, "6": {"name": "pan", "class_id": 5, "handbox": 1, "narration": 1, "positive": 1}, "7": {"name": "plate", "class_id": 2, "handbox": 1, "narration": 0, "positive": 0}, "8": {"name": "potato", 
"class_id": 30, "handbox": 1, "narration": 1, "positive": 1}, "9": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "10": {"name": "wooden spoon", "class_id": 1, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 2724, "video": "P26_02", "start": 13282, "narration": "turn chicken breast", "verb": "turn", "verb_class": 23, "noun": "breast:chicken", "noun_class": 57, "end": 13491, "sparse_frames": ["frame_0000013346.jpg", "frame_0000013419.jpg"], "object_classes": {"1": {"name": "chicken breast", "class_id": 57, "handbox": 1, "narration": 1, "positive": 1}, "2": {"name": "fork", "class_id": 14, "handbox": 1, "narration": 0, "positive": 1}, "3": {"name": "hob", "class_id": 24, "handbox": 1, "narration": 0, "positive": 0}, "4": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "5": {"name": "plate", "class_id": 2, "handbox": 0, "narration": 0, "positive": 0}, "6": {"name": "pot", "class_id": 29, "handbox": 1, "narration": 0, "positive": 1}, "7": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 2914, "video": "P27_105", "start": 42644, "narration": "tear aluminium foil", "verb": "tear", "verb_class": 43, "noun": "foil:aluminium", "noun_class": 123, "end": 43232, "sparse_frames": ["frame_0000042792.jpg", "frame_0000043151.jpg"], "object_classes": {"1": {"name": "aluminium foil", "class_id": 123, "handbox": 1, "narration": 1, "positive": 1}, "2": {"name": "box", "class_id": 23, "handbox": 1, "narration": 0, "positive": 1}, "3": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "4": {"name": "paper piece", "class_id": 49, "handbox": 0, "narration": 0, "positive": 0}, "5": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "6": {"name": "small sandwich", "class_id": 164, 
"handbox": 0, "narration": 0, "positive": 0}, "7": {"name": "table", "class_id": 42, "handbox": 1, "narration": 0, "positive": 0}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 2918, "video": "P27_105", "start": 44024, "narration": "tear aluminium foil", "verb": "tear", "verb_class": 43, "noun": "foil:aluminium", "noun_class": 123, "end": 44506, "sparse_frames": ["frame_0000044181.jpg", "frame_0000044350.jpg"], "object_classes": {"1": {"name": "aluminium foil", "class_id": 123, "handbox": 1, "narration": 1, "positive": 1}, "2": {"name": "box", "class_id": 23, "handbox": 1, "narration": 0, "positive": 1}, "3": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "4": {"name": "paper piece", "class_id": 49, "handbox": 0, "narration": 0, "positive": 0}, "5": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "6": {"name": "small sandwich", "class_id": 164, "handbox": 0, "narration": 0, "positive": 0}, "7": {"name": "table", "class_id": 42, "handbox": 1, "narration": 0, "positive": 0}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 2963, "video": "P28_05", "start": 16105, "narration": "put spring onions in tray", "verb": "put-in", "verb_class": 5, "noun": "onion:spring", "noun_class": 128, "end": 16257, "sparse_frames": ["frame_0000016135.jpg", "frame_0000016209.jpg"], "object_classes": {"1": {"name": "chopping board", "class_id": 18, "handbox": 1, "narration": 0, "positive": 1}, "2": {"name": "knife", "class_id": 4, "handbox": 1, "narration": 0, "positive": 0}, "3": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "4": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "5": {"name": "spring onion piece", "class_id": 128, "handbox": 0, "narration": 1, "positive": 1}, "6": {"name": "tray", "class_id": 35, "handbox": 1, "narration": 1, "positive": 1}, "7": {"name": "vegetable", "class_id": 94, 
"handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 3005, "video": "P29_04", "start": 7248, "narration": "scoop tea into teapot", "verb": "scoop-into", "verb_class": 16, "noun": "tea", "noun_class": 132, "end": 7675, "sparse_frames": ["frame_0000007373.jpg", "frame_0000007612.jpg"], "object_classes": {"1": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "2": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "3": {"name": "tea", "class_id": 132, "handbox": 1, "narration": 1, "positive": 1}, "4": {"name": "tea box", "class_id": 23, "handbox": 0, "narration": 0, "positive": 0}, "5": {"name": "tea packet", "class_id": 26, "handbox": 1, "narration": 0, "positive": 1}, "6": {"name": "teapot", "class_id": 158, "handbox": 1, "narration": 1, "positive": 1}, "7": {"name": "teapot lid", "class_id": 6, "handbox": 1, "narration": 0, "positive": 1}, "8": {"name": "teaspoon", "class_id": 1, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 3151, "video": "P32_07", "start": 1464, "narration": "close toaster", "verb": "close", "verb_class": 4, "noun": "toaster", "noun_class": 186, "end": 1605, "sparse_frames": ["frame_0000001504.jpg", "frame_0000001566.jpg"], "object_classes": {"1": {"name": "bread package", "class_id": 26, "handbox": 0, "narration": 0, "positive": 0}, "2": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "3": {"name": "right hand", "class_id": 301, "handbox": 1, "narration": 0, "positive": 1}, "4": {"name": "sandwich bread", "class_id": 33, "handbox": 0, "narration": 0, "positive": 0}, "5": {"name": "toaster", "class_id": 186, "handbox": 1, "narration": 1, "positive": 1}}, "redundant": false, "missing": false, "other": false}, {"seq_id": 3221, "video": "P37_102", "start": 27527, "narration": "shake and rinse chicken thighs", "verb": 
"shake", "verb_class": 15, "noun": "thigh:chicken", "noun_class": 57, "end": 27800, "sparse_frames": ["frame_0000027582.jpg", "frame_0000027740.jpg"], "object_classes": {"1": {"name": "chicken thigh", "class_id": 57, "handbox": 1, "narration": 1, "positive": 1}, "2": {"name": "drainer", "class_id": 110, "handbox": 0, "narration": 0, "positive": 0}, "3": {"name": "left hand", "class_id": 300, "handbox": 1, "narration": 0, "positive": 1}, "4": {"name": "pot", "class_id": 29, "handbox": 1, "narration": 0, "positive": 1}, "5": {"name": "water", "class_id": 27, "handbox": 1, "narration": 0, "positive": 1}}, "redundant": false, "missing": false, "other": false}] -------------------------------------------------------------------------------- /demo_path/ImageSets/expression_file.json: -------------------------------------------------------------------------------- 1 | {"videos":{"00000012_P01_107_put-into_bag:cereal": {"expressions": {"0": {"exp": "bread used in the action of put cereals bag into cereals box"}, "1": {"exp": "cereal used in the action of put cereals bag into cereals box"}, "2": {"exp": "cereal bag used in the action of put cereals bag into cereals box"}, "3": {"exp": "cereal box used in the action of put cereals bag into cereals box"}, "4": {"exp": "cup used in the action of put cereals bag into cereals box"}, "5": {"exp": "left hand used in the action of put cereals bag into cereals box"}, "6": {"exp": "right hand used in the action of put cereals bag into cereals box"}, "7": {"exp": "soy milk used in the action of put cereals bag into cereals box"}}, "frames": ["frame_0000002521", "frame_0000002559"]},"00000223_P02_09_pick-up_spoon": {"expressions": {"0": {"exp": "lid used in the action of pick up two spoons"}, "1": {"exp": "oil bottle used in the action of pick up two spoons"}, "2": {"exp": "right hand used in the action of pick up two spoons"}, "3": {"exp": "spoon used in the action of pick up two spoons"}}, "frames": ["frame_0000040843", 
"frame_0000040873"]}}} -------------------------------------------------------------------------------- /demo_path/JPEGImages_Sparse/val/00000012_P01_107_put-into_bag:cereal/frame_0000002521.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ut-vision/ActionVOS/1a12aae100d0a84adcba72f859d81398e10e839a/demo_path/JPEGImages_Sparse/val/00000012_P01_107_put-into_bag:cereal/frame_0000002521.jpg -------------------------------------------------------------------------------- /demo_path/JPEGImages_Sparse/val/00000012_P01_107_put-into_bag:cereal/frame_0000002559.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ut-vision/ActionVOS/1a12aae100d0a84adcba72f859d81398e10e839a/demo_path/JPEGImages_Sparse/val/00000012_P01_107_put-into_bag:cereal/frame_0000002559.jpg -------------------------------------------------------------------------------- /demo_path/JPEGImages_Sparse/val/00000223_P02_09_pick-up_spoon/frame_0000040843.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ut-vision/ActionVOS/1a12aae100d0a84adcba72f859d81398e10e839a/demo_path/JPEGImages_Sparse/val/00000223_P02_09_pick-up_spoon/frame_0000040843.jpg -------------------------------------------------------------------------------- /demo_path/JPEGImages_Sparse/val/00000223_P02_09_pick-up_spoon/frame_0000040873.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ut-vision/ActionVOS/1a12aae100d0a84adcba72f859d81398e10e839a/demo_path/JPEGImages_Sparse/val/00000223_P02_09_pick-up_spoon/frame_0000040873.jpg -------------------------------------------------------------------------------- /figures/ActionVOS.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ut-vision/ActionVOS/1a12aae100d0a84adcba72f859d81398e10e839a/figures/ActionVOS.png -------------------------------------------------------------------------------- /figures/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ut-vision/ActionVOS/1a12aae100d0a84adcba72f859d81398e10e839a/figures/method.png -------------------------------------------------------------------------------- /figures/weights.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ut-vision/ActionVOS/1a12aae100d0a84adcba72f859d81398e10e839a/figures/weights.png --------------------------------------------------------------------------------