├── .gitignore
├── README.md
├── RF_ActionVOS
│   ├── __init__.py
│   ├── actionvos.py
│   ├── configs.md
│   ├── criterion.py
│   ├── inference_actionvos.py
│   ├── main_actionvos.py
│   ├── opts.py
│   ├── referformer.py
│   ├── segmentation.py
│   ├── test_actionvos.sh
│   ├── train_actionvos.sh
│   └── transforms_video_actionvos.py
├── actionvos_metrics.py
├── annotations
│   ├── 00000.png
│   ├── EPIC_100_train.csv
│   └── EPIC_100_validation.csv
├── copy_rf_actionvos_files.py
├── data_prepare_visor.py
├── data_prepare_vost.py
├── data_prepare_vscos.py
├── dataset_visor
│   └── ImageSets
│       ├── val_human.json
│       └── val_novel.json
├── demo_path
│   ├── ImageSets
│   │   └── expression_file.json
│   └── JPEGImages_Sparse
│       └── val
│           ├── 00000012_P01_107_put-into_bag:cereal
│           │   ├── frame_0000002521.jpg
│           │   └── frame_0000002559.jpg
│           └── 00000223_P02_09_pick-up_spoon
│               ├── frame_0000040843.jpg
│               └── frame_0000040873.jpg
└── figures
    ├── ActionVOS.png
    ├── method.png
    └── weights.png
/.gitignore:
--------------------------------------------------------------------------------
1 | annotations/visor_hos_train.json
2 | annotations/visor_hos_val.json
3 | dataset_visor/ImageSets/train.json
4 | dataset_visor/ImageSets/val.json
5 | dataset_visor/ImageSets/train_objects_category.json
6 | dataset_visor/ImageSets/val_objects_category.json
7 | dataset_visor/ImageSets/train_meta_expressions_promptaction.json
8 | dataset_visor/ImageSets/val_meta_expressions_promptaction.json
9 | ReferFormer/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # **ActionVOS: Actions as Prompts for Video Object Segmentation**
2 |
3 | Our [paper](https://arxiv.org/abs/2407.07402) was accepted to **ECCV-2024** as an [**oral**](https://eccv.ecva.net/virtual/2024/oral/1604) **(2.3%)** presentation!
4 |
5 |
6 |
7 | **Picture:** *Overview of the proposed ActionVOS setting.*
8 |
9 |
10 |
11 | **Picture:** *The proposed method in our paper.*
12 |
13 | ---
14 |
15 | This repository contains the official PyTorch implementation of the following paper:
16 |
17 | > **ActionVOS: Actions as Prompts for Video Object Segmentation**
18 | Liangyang Ouyang, Ruicong Liu, Yifei Huang, Ryosuke Furuta, and Yoichi Sato
19 | >
20 | >**Abstract:**
21 | Delving into the realm of egocentric vision, the advancement of referring video object segmentation (RVOS) stands as pivotal in understanding human activities. However, existing RVOS task primarily relies on static attributes such as object names to segment target objects, posing challenges in distinguishing target objects from background objects and in identifying objects undergoing state changes. To address these problems, this work proposes a novel action-aware RVOS setting called ActionVOS, aiming at segmenting only active objects in egocentric videos using human actions as a key language prompt. This is because human actions precisely describe the behavior of humans, thereby helping to identify the objects truly involved in the interaction and to understand possible state changes. We also build a method tailored to work under this specific setting. Specifically, we develop an action-aware labeling module with an efficient action-guided focal loss. Such designs enable ActionVOS model to prioritize active objects with existing readily-available annotations. Experimental results on VISOR dataset reveal that ActionVOS significantly reduces the mis-segmentation of inactive objects, confirming that actions help the ActionVOS model understand objects' involvement. Further evaluations on VOST and VSCOS datasets show that the novel ActionVOS setting enhances segmentation performance when encountering challenging circumstances involving object state changes.
22 |
23 | ## Resources
24 |
25 | Material related to our paper is available via the following links:
26 |
27 | - [**Paper**](https://arxiv.org/abs/2407.07402)
28 | - [**Video**](https://youtu.be/dt-zDQKzq1I)
29 | - [VISOR dataset](https://epic-kitchens.github.io/VISOR/)
30 | - [VOST dataset](https://www.vostdataset.org/data.html)
31 | - [VSCOS dataset](https://github.com/venom12138/VSCOS)
32 | - [ReferFormer Model](https://github.com/wjn922/ReferFormer)
33 |
34 | ## Requirements
35 |
36 | * Our experiments were tested with Python 3.8 and PyTorch 1.11.0.
37 | * Our experiments with ReferFormer used 4 V100 GPUs and took 6-12 hours to train 6 epochs on VISOR.
38 | * Check the **Training** instructions for the packages required by ReferFormer (RF).
39 |
40 | ## Playing with ActionVOS
41 |
42 | ### **Data preparation (Pseudo-labeling and Weight-generation)**
43 |
44 | For the videos and masks, please download the VISOR-VOS, VSCOS, and VOST datasets from the links below. We recommend downloading VISOR-VOS first, since it is used for both training and testing.
45 |
46 | - [**VISOR-VOS (28.4GB)**](https://data.bris.ac.uk/data/dataset/2v6cgv1x04ol22qp9rm9x2j6a7)
47 | - [VSCOS (20GB)](https://github.com/venom12138/VSCOS)
48 | - [VOST (50GB)](https://www.vostdataset.org/data.html)
49 |
50 | [Action narration annotations](./annotations/EPIC_100_train.csv) are obtained from [EK-100](https://github.com/epic-kitchens/epic-kitchens-100-annotations). (We already include them in this repository, so you don't need to download them.)
51 |
52 | [Hand-object annotations](./annotations/visor_hos_train.json) are obtained from [VISOR-HOS](https://github.com/epic-kitchens/VISOR-HOS). (Please download them from Google Drive ([link1](https://drive.google.com/file/d/1Op-QtoweJ-2M0nuMqtbBHAsJ4Ep-g6nU/view?usp=sharing), [link2](https://drive.google.com/file/d/1KkQ-BOC4E0P087D2hyTN9eUxMmNPq_Ot/view?usp=sharing)) and put them under `annotations/`.)
53 |
54 | Then run data_prepare_visor.py to generate the data, annotations, action-aware pseudo-labels, and action-guided weights for ActionVOS.
55 |
56 | ```
57 | python data_prepare_visor.py --VISOR_PATH your_visor_epick_path
58 | ```
59 |
60 | Processing the data takes 1-2 hours. After that, the dataset_visor folder will have the following structure:
61 |
62 | ```
63 | - dataset_visor
64 | - Annotations_Sparse
65 | - train
66 | - 00000001_xxx
67 | - obj_masks.png
68 | - 00000002_xxx
69 | - val
70 | - JPEGImages_Sparse
71 | - train
72 | - 00000001_xxx
73 | - rgb_frames.jpg
74 | - 00000002_xxx
75 | - val
76 | - Weights_Sparse
77 | - train
78 | - 00000001_xxx
79 | - action-guided-weights.png
80 | - 00000002_xxx
81 | - val (not used)
82 | - ImageSets
83 | - train.json
84 | - val.json
85 | - val_human.json
86 | - val_novel.json
87 | ```
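
A quick sanity check on the prepared folders is to count clips and frames per split. The minimal sketch below only assumes the directory layout shown above:

```python
from pathlib import Path

root = Path("dataset_visor")
for split in ["train", "val"]:
    clips = sorted((root / "JPEGImages_Sparse" / split).iterdir())
    n_frames = sum(len(list(c.glob("*.jpg"))) for c in clips)
    n_masks = sum(len(list(c.glob("*.png")))
                  for c in (root / "Annotations_Sparse" / split).iterdir())
    print(f"{split}: {len(clips)} clips, {n_frames} frames, {n_masks} mask images")
```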
88 |
89 | There are 2 special files, val_human.json and val_novel.json. These files contain the splits used for the results in our experiments: val_human contains actions annotated by humans, and val_novel contains novel actions that are unseen during training.
90 |
91 | ### **How to find action-aware pseudo labels**
92 |
93 | Check [train.json](./dataset_visor/ImageSets/train.json). For each object name in each video, the JSON file contains a map such as {"name": "food container", "class_id": 21, "handbox": 0, "narration": 1, "positive": 1}.
94 | 
95 | handbox = 1 if the object mask intersects with a hand-object bounding box.
96 | 
97 | narration = 1 if the object name is mentioned in the action narration.
98 | 
99 | positive = 1 if the object is a pseudo-positive object.
100 | 
101 | Note that the object masks under Annotations_Sparse cover all objects; we combine them with the class labels in our experiments.
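
As an illustration, the sketch below tallies these flags over train.json. It assumes the file maps each video to its object entries in the format shown above; adjust the loop if the actual nesting differs:

```python
import json
from collections import Counter

with open("dataset_visor/ImageSets/train.json") as f:
    meta = json.load(f)

counts = Counter()
videos = meta.items() if isinstance(meta, dict) else enumerate(meta)  # assumed layout
for video_id, objects in videos:
    entries = objects.values() if isinstance(objects, dict) else objects
    for obj in entries:
        if isinstance(obj, dict) and "positive" in obj:   # keep only per-object entries
            counts["handbox"] += obj["handbox"]
            counts["narration"] += obj["narration"]
            counts["positive"] += obj["positive"]
print(counts)
```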
102 |
103 | ### **How to find action-guided weights**
104 |
105 | Each image under Weights_Sparse is an action-guided weight map.
106 |
107 |
108 |
109 | **Picture:** *Action-guided Weights*
110 |
111 | ```
112 | 3 (yellow): negative object masks
113 | 2 (green): hand | narration object masks
114 | 4 (blue): hand & narration object masks
115 | 1 (red): all other areas
116 | ```
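
As an example, a weight PNG can be turned into a per-pixel loss-weight map as in the minimal sketch below. It assumes the PNG stores the region label (1-4) per pixel, and the numeric weights are placeholders rather than the values used in the paper; see RF_ActionVOS/criterion.py for the actual action-guided focal loss.

```python
import numpy as np
from PIL import Image

# Placeholder weights per region label (1: other, 2: hand|narration,
# 3: negative object, 4: hand&narration) -- tune these for your own setup.
REGION_WEIGHT = {1: 1.0, 2: 2.0, 3: 2.0, 4: 3.0}

def load_pixel_weights(weight_png):
    labels = np.array(Image.open(weight_png))        # per-pixel labels in {1, 2, 3, 4}
    weights = np.ones(labels.shape, dtype=np.float32)
    for label, w in REGION_WEIGHT.items():
        weights[labels == label] = w
    return weights                                   # same H x W as the mask
```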
117 |
118 | ### **Training**
119 |
120 | ActionVOS is an action-aware setting for RVOS, and any RVOS model with an extra class head can be trained for ActionVOS. In our experiments, we take ReferFormer-ResNet101 as the base RVOS model.
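
For intuition, a hypothetical sketch of such an extra class head is shown below; it is illustrative only and not the actual implementation in RF_ActionVOS/referformer.py.

```python
import torch.nn as nn

class PositiveObjectHead(nn.Module):
    """Toy extra head: predicts whether the referred object is 'positive'
    (actively involved in the action) from a per-query embedding."""
    def __init__(self, hidden_dim=256):
        super().__init__()
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, query_embed):              # (batch, num_queries, hidden_dim)
        return self.fc(query_embed).squeeze(-1)  # positive-object logits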
121 |
122 | Clone [ReferFormer](https://github.com/wjn922/ReferFormer) repository and download their [pretrained checkpoints](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/wjn922_connect_hku_hk/EShgDd650nBBsfoNEiUbybcB84Ma5NydxOucISeCrZmzHw?e=YOSszd).
123 | ```
124 | git clone https://github.com/wjn922/ReferFormer.git
125 | cd ReferFormer
126 | mkdir pretrained_weights
127 | # download r101_refytvos_joint.pth from the link above into pretrained_weights/
128 | ```
129 |
130 | Install the necessary packages for ReferFormer.
131 |
132 | ```
133 | cd ReferFormer
134 | pip install -r requirements.txt
135 | pip install 'git+https://github.com/facebookresearch/fvcore'
136 | pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
137 | cd models/ops
138 | python setup.py build install
139 | ```
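
To verify that the deformable attention op compiled, you can try importing it. The extension name below follows the Deformable-DETR-style ops used by ReferFormer, so treat this as a best-guess check:

```python
import torch
# The build step installs a CUDA extension; the name below is the usual one
# for Deformable-DETR-style ops and raises ImportError if the build failed.
import MultiScaleDeformableAttention as MSDA

print("CUDA available:", torch.cuda.is_available())
print("Extension loaded from:", MSDA.__file__)
```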
140 |
141 | Copy the modified files into the ReferFormer folders.
142 |
143 | ```
144 | python copy_rf_actionvos_files.py
145 | ```
146 |
147 | Run the training script. If you want to change the training configs, check [RF_ActionVOS/configs.md](RF_ActionVOS/configs.md). The following example trains ActionVOS on a single GPU (GPU 0).
148 |
149 | ```
150 | cd ReferFormer
151 | bash scripts/train_actionvos.sh actionvos_dirs/r101 pretrained_weights/r101_refytvos_joint.pth 1 0 29500 --backbone resnet101 --expression_file train_meta_expressions_promptaction.json --use_weights --use_positive_cls --actionvos_path ../dataset_visor --epochs 6 --lr_drop 3 5 --save_interval 3
152 | ```
153 |
154 | After the training process, the weights will be saved to actionvos_dirs/r101/checkpoint.pth.
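
To confirm what was saved, you can inspect the checkpoint; the key names below follow ReferFormer's usual checkpoint format and may differ:

```python
import torch

ckpt = torch.load("actionvos_dirs/r101/checkpoint.pth", map_location="cpu")
print(list(ckpt.keys()))    # typically ['model', 'optimizer', 'lr_scheduler', 'epoch', 'args']
print(ckpt.get("epoch"))    # last finished epoch, if stored
```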
155 |
156 | ### **Inference**
157 |
158 | For a quick start with ActionVOS models, we provide a trained RF-R101 checkpoint at [this link](https://drive.google.com/file/d/140gfK4GkI5iBSVFqoi_CAfL6d0J39nOW/view?usp=sharing).
159 |
160 | #### **Inference on VISOR**
161 |
162 | ```
163 | cd ReferFormer
164 | bash scripts/test_actionvos.sh actionvos_dirs/r101 pretrained_weights/actionvos_rf_r101.pth 0 29500 --backbone resnet101 --expression_file val_meta_expressions_promptaction.json --use_positive_cls --pos_cls_thres 0.75 --actionvos_path ../dataset_visor
165 | ```
166 |
167 | The output masks will be saved in ReferFormer/actionvos_dirs/r101/val.
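
To eyeball a single prediction, the small overlay sketch below can help. It assumes each saved mask is a single-channel PNG aligned with its RGB frame; pass the actual frame/mask paths from the output folder.

```python
import numpy as np
from PIL import Image

def overlay(frame_path, mask_path, out_path, color=(255, 0, 0), alpha=0.5):
    frame = np.array(Image.open(frame_path).convert("RGB")).astype(np.float32)
    mask = np.array(Image.open(mask_path).convert("L")) > 0   # binary object mask
    frame[mask] = (1 - alpha) * frame[mask] + alpha * np.array(color, dtype=np.float32)
    Image.fromarray(frame.astype(np.uint8)).save(out_path)

# e.g. overlay("path/to/frame.jpg", "path/to/pred_mask.png", "overlay.jpg")
```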
168 |
169 | #### **Inference on your own videos and prompts**
170 |
171 | Arrange your videos and prompts into an actionvos_path directory structure like:
172 |
173 | ```
174 | - demo_path
175 | - JPEGImages_Sparse
176 | - val
177 | - video_name
178 | - rgb_frames.jpg
179 | - ImageSets
180 | - expression_file.json
181 | ```
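
If your input is a raw video, the sketch below dumps frames into the expected folder. The use of OpenCV and the frame_XXXXXXXXXX.jpg naming (mirroring the demo folder in this repo) are assumptions; rename the output to match your own prompts if needed.

```python
import os
import cv2  # assumption: OpenCV is available for decoding the video

def extract_frames(video_file, out_dir, step=30):
    """Save every `step`-th frame into out_dir as frame_<10-digit index>.jpg."""
    os.makedirs(out_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_file)
    idx = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if idx % step == 0:  # keep a sparse subset of frames
            cv2.imwrite(os.path.join(out_dir, f"frame_{idx:010d}.jpg"), frame)
        idx += 1
    cap.release()

# e.g. extract_frames("my_video.mp4", "demo_path/JPEGImages_Sparse/val/my_video")
```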
182 |
183 | Check the [example JSON file](demo_path/ImageSets/expression_file.json) for the prompt format.
184 |
185 | ```
186 | cd ReferFormer
187 | bash scripts/test_actionvos.sh actionvos_dirs/demo pretrained_weights/actionvos_rf_r101.pth 0 29500 --backbone resnet101 --expression_file expression_file.json --use_positive_cls --pos_cls_thres 0.75 --actionvos_path ../demo_path
188 | ```
189 |
190 | The output masks will be saved in ReferFormer/actionvos_dirs/demo/val.
191 |
192 | #### **Evaluation Metrics**
193 |
194 | We use 6 metrics (p-mIoU, n-mIoU, p-cIoU, n-cIoU, gIoU, and accuracy) to evaluate ActionVOS performance on the VISOR val_human split.
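
Here p-/n- refer to positive and negative objects. For reference, the sketch below shows the usual distinction between mIoU and cIoU; check actionvos_metrics.py for the exact definitions used in our evaluation.

```python
import numpy as np

def miou(per_object_ious):
    # mIoU: average the per-object IoU values
    return float(np.mean(per_object_ious))

def ciou(intersections, unions):
    # cIoU: pool intersections and unions over all objects before dividing
    return float(np.sum(intersections) / np.sum(unions))
```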
195 |
196 | ```
197 | python actionvos_metrics.py --pred_path ReferFormer/actionvos_dirs/r101/val --gt_path dataset_visor/Annotations_Sparse/val --split_json dataset_visor/ImageSets/val_human.json
198 | ```
199 |
200 | If you correctly generated object masks with [this checkpoint](https://drive.google.com/file/d/140gfK4GkI5iBSVFqoi_CAfL6d0J39nOW/view?usp=sharing), you should get the results below:
201 |
202 | | Model | Split | p-mIoU | n-mIoU | p-cIoU | n-cIoU | gIoU | Acc |
203 | |----------|-----------|-----------|---------|-----------|-----------|---------|-----------|
204 | | [RF_R101](https://drive.google.com/file/d/140gfK4GkI5iBSVFqoi_CAfL6d0J39nOW/view?usp=sharing) | val_human* | 66.1 | 18.6 | 72.7 | 32.2 | 71.2 | 83.0 |
205 |
206 | \* Note that the val_human split here only uses 294 videos. Check [actionvos_metrics.py](actionvos_metrics.py) for details.
207 |
208 | ## Citation
209 |
210 | If this work or code is helpful in your research, please cite:
211 |
212 | ```latex
213 | @inproceedings{ouyang2024actionvos,
214 | title={ActionVOS: Actions as Prompts for Video Object Segmentation},
215 | author={Ouyang, Liangyang and Liu, Ruicong and Huang, Yifei and Furuta, Ryosuke and Sato, Yoichi},
216 | booktitle={European Conference on Computer Vision},
217 | pages={216--235},
218 | year={2024}
219 | }
220 | ```
221 |
222 | If you are using the data and annotations from [VISOR](https://proceedings.neurips.cc/paper_files/paper/2022/hash/590a7ebe0da1f262c80d0188f5c4c222-Abstract-Datasets_and_Benchmarks.html), [VSCOS](https://openaccess.thecvf.com/content/ICCV2023/html/Yu_Video_State-Changing_Object_Segmentation_ICCV_2023_paper.html), or [VOST](https://openaccess.thecvf.com/content/CVPR2023/html/Tokmakov_Breaking_the_Object_in_Video_Object_Segmentation_CVPR_2023_paper.html), please cite the original papers.
223 |
224 | If you are using the training, inference and evaluation code, please cite [ReferFormer](https://openaccess.thecvf.com/content/CVPR2022/html/Wu_Language_As_Queries_for_Referring_Video_Object_Segmentation_CVPR_2022_paper.html) and [GRES](https://openaccess.thecvf.com/content/CVPR2023/html/Liu_GRES_Generalized_Referring_Expression_Segmentation_CVPR_2023_paper.html).
225 |
226 |
227 | ## Contact
228 |
229 | For any questions, including algorithms and datasets, feel free to contact me by email: `oyly(at)iis.u-tokyo.ac.jp`
--------------------------------------------------------------------------------
/RF_ActionVOS/__init__.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | import torchvision
3 |
4 | from .ytvos import build as build_ytvos
5 | from .davis import build as build_davis
6 | from .a2d import build as build_a2d
7 | from .jhmdb import build as build_jhmdb
8 | from .refexp import build as build_refexp
9 | from .concat_dataset import build as build_joint
10 | from .actionvos import build as build_actionvos
11 | # from .actionvos_allpos import build as build_actionvos_allpos
12 | # from .actionvos_state import build as build_state
13 |
14 | def get_coco_api_from_dataset(dataset):
15 | for _ in range(10):
16 | # if isinstance(dataset, torchvision.datasets.CocoDetection):
17 | # break
18 | if isinstance(dataset, torch.utils.data.Subset):
19 | dataset = dataset.dataset
20 | if isinstance(dataset, torchvision.datasets.CocoDetection):
21 | return dataset.coco
22 |
23 |
24 | def build_dataset(dataset_file: str, image_set: str, args):
25 | if dataset_file == 'ytvos':
26 | return build_ytvos(image_set, args)
27 | if dataset_file == 'davis':
28 | return build_davis(image_set, args)
29 | if dataset_file == 'a2d':
30 | return build_a2d(image_set, args)
31 | if dataset_file == 'jhmdb':
32 | return build_jhmdb(image_set, args)
33 | # for pretraining
34 | if dataset_file == "refcoco" or dataset_file == "refcoco+" or dataset_file == "refcocog":
35 | return build_refexp(dataset_file, image_set, args)
36 | # for joint training of refcoco and ytvos
37 | if dataset_file == 'joint':
38 | return build_joint(image_set, args)
39 | if dataset_file == 'actionvos':
40 | return build_actionvos(image_set, args)
41 |     # if dataset_file == 'actionvos_allpos':        # needs the actionvos_allpos import above
42 |     #     return build_actionvos_allpos(image_set, args)
43 |     # if dataset_file == 'vost':                     # needs the actionvos_state import above
44 |     #     return build_state(image_set, args, 'vost')
45 |     # if dataset_file == 'vscos':
46 |     #     return build_state(image_set, args, 'vscos')
47 | raise ValueError(f'dataset {dataset_file} not supported')
48 |
--------------------------------------------------------------------------------
/RF_ActionVOS/actionvos.py:
--------------------------------------------------------------------------------
1 | """
2 | actionvos data loader
3 | Note that we adjust the transform file (for data augmentation)
4 | # TODO check possible bug when box [0,0,0,0] goes to augmentation.
5 | """
6 | from pathlib import Path
7 |
8 | import torch
9 | from torch.autograd.grad_mode import F
10 | from torch.utils.data import Dataset
11 | import datasets.transforms_video_actionvos as T
12 |
13 | import os
14 | from PIL import Image
15 | import json
16 | import numpy as np
17 | import random
18 |
19 | #from datasets.categories import ytvos_category_dict as category_dict
20 |
21 |
22 | class ActionVOSDataset(Dataset):
23 | """
24 | In this version, sampling every