├── README.md
├── annotations
│   ├── diving48_id2label.pkl
│   ├── diving48_label2id.pkl
│   ├── diving48_vocab.json
│   ├── gym288_anno.pkl
│   ├── gym288_id2label.pkl
│   ├── gym288_label2id.pkl
│   ├── gym99_anno.pkl
│   ├── gym99_id2label.pkl
│   └── gym99_label2id.pkl
├── build_venv.sh
├── configs
│   ├── Diving48_first_stage.yaml
│   ├── Diving48_second_stage.yaml
│   ├── Gym288_first_stage.yaml
│   ├── Gym288_second_stage.yaml
│   ├── Gym99_first_stage.yaml
│   └── Gym99_second_stage.yaml
├── data
│   ├── data.md
│   └── dataloader.py
├── engine
│   └── engine.py
├── models
│   ├── TQN.py
│   └── transformer.py
├── requirements.txt
├── scripts
│   ├── construct_SUFB.py
│   ├── test.py
│   ├── train_1st_stage.py
│   └── train_2nd_stage.py
└── utils
    ├── augmentation.py
    ├── plot_utils.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # [Temporal Query Networks for Fine-grained Video Understanding](https://www.robots.ox.ac.uk/~vgg/research/tqn/)
2 |
3 | 📋 This repository contains the implementation of the CVPR 2021 paper [Temporal Query Networks for Fine-grained Video Understanding](https://arxiv.org/pdf/2104.09496.pdf).
4 |
5 | # Abstract
6 |
7 |
8 |
9 |
10 |
11 | Our objective in this work is fine-grained classification of actions in untrimmed videos, where the actions may be temporally extended or may span only a few frames of the video. We cast this into a query-response mechanism, where each query addresses a particular question, and has its own response label set.
12 |
13 | We make the following four contributions: (i) We propose a new model — a Temporal Query Network — which enables the query-response functionality, and a structural understanding of fine-grained actions. It attends to relevant segments for each query with a temporal attention mechanism, and can be trained using only the labels for each query. (ii) We propose a new way — stochastic feature bank update — to train a network on videos of various lengths with the dense sampling required to respond to fine-grained queries. (iii) We compare the TQN to other architectures and text supervision methods, and analyze their pros and cons. Finally, (iv) we evaluate the method extensively on the FineGym and Diving48 benchmarks for fine-grained action classification and surpass the state-of-the-art using only RGB features.
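
The query-response mechanism can be illustrated with a minimal, simplified sketch (hypothetical tensor names and sizes; the actual model is in `models/TQN.py`): a set of learnable queries cross-attends over per-clip backbone features via a transformer decoder, and each decoded query is classified into its response label set.

```
import torch
import torch.nn as nn

d_model, num_queries, attribute_set_size = 1024, 5, 25       # Diving48-style sizes
decoder_layer = nn.TransformerDecoderLayer(d_model, nhead=4, dim_feedforward=1024)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=4)
queries = nn.Embedding(num_queries, d_model)                  # learnable query vectors
attribute_head = nn.Linear(d_model, attribute_set_size)       # shared response classifier

clip_features = torch.randn(12, 2, d_model)                   # [num_clips, batch, d_model] from the S3D backbone
tgt = queries.weight.unsqueeze(1).repeat(1, 2, 1)             # [num_queries, batch, d_model]
responses = decoder(tgt, clip_features)                       # queries attend over the clips
logits = attribute_head(responses).transpose(0, 1)            # [batch, num_queries, attribute_set_size]
print(logits.shape)                                           # torch.Size([2, 5, 25])
```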
14 |
15 | # Getting Started
16 | 1. Clone this repository
17 | ```
18 | git clone https://github.com/Chuhanxx/Temporal_Query_Networks.git
19 | ```
20 | 2. Create a conda virtual environment and install the requirements
21 | (This implementation requires CUDA and Python 3.7+)
22 | ```
23 | cd Temporal_Query_Networks
24 | source build_venv.sh
25 | ```
26 |
27 | # Prepare Data and Weight Initialization
28 |
29 | Please refer to [data.md](https://github.com/Chuhanxx/Temporal_Query_Networks/blob/master/data/data.md) for data preparation.
30 |
31 |
32 | # Training
33 | You can start training the model with the following steps, taking the Diving48 dataset as an example:
34 |
35 | 1. First stage training:
36 | Set the paths in the `Diving48_first_stage.yaml` config file first, and then run:
37 |
38 | ```
39 | cd scripts
40 | python train_1st_stage.py --name $EXP_NAME --dataset diving48 --dataset_config ../configs/Diving48_first_stage.yaml --gpus 0,1 --batch_size 16
41 | ```
42 | 2. Construct stochastically updated feature banks:
43 |
44 | ```
45 | python construct_SUFB.py --dataset diving48 --dataset_config ../configs/Diving48_first_stage.yaml \
46 | --gpus 0 --resume_file $PATH_TO_BEST_FILE_FROM_1ST_STAGE --out_dir $DIR_FOR_SAVING_FEATURES
47 | ```
48 |
49 | 3. Second stage training:
50 | Set the paths in the `Diving48_second_stage.yaml` config file first, and then run:
51 |
52 | ```
53 | python train_2nd_stage.py --name $EXP_NAME --dataset diving48 \
54 | --dataset_config ../configs/Diving48_second_stage.yaml \
55 | --batch_size 16 --gpus 0,1
56 | ```
57 |
58 | # Test
59 |
60 | ```
61 | python test.py --name $EXP_NAME --dataset diving48 --batch_size 1 \
62 | --dataset_config ../configs/Diving48_second_stage.yaml
63 | ```
64 |
65 | # Citation
66 |
67 | If you use this code, please cite the following paper:
68 |
69 | ```
70 | @inproceedings{zhangtqn,
71 | title={Temporal Query Networks for Fine-grained Video Understanding},
72 |   author={Chuhan Zhang and Ankush Gupta and Andrew Zisserman},
73 | booktitle={Conference on Computer Vision and Pattern Recognition (CVPR)},
74 | year={2021}
75 | }
76 | ```
77 |
78 | If you have any questions, please contact czhang@robots.ox.ac.uk.
79 |
--------------------------------------------------------------------------------
/annotations/diving48_id2label.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chuhanxx/Temporal_Query_Networks/ac73e3375753463956f90037d59af5d3d50967e4/annotations/diving48_id2label.pkl
--------------------------------------------------------------------------------
/annotations/diving48_label2id.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chuhanxx/Temporal_Query_Networks/ac73e3375753463956f90037d59af5d3d50967e4/annotations/diving48_label2id.pkl
--------------------------------------------------------------------------------
/annotations/diving48_vocab.json:
--------------------------------------------------------------------------------
1 | [
2 | [
3 | "Back",
4 | "15som",
5 | "05Twis",
6 | "FREE"
7 | ],
8 | [
9 | "Back",
10 | "15som",
11 | "15Twis",
12 | "FREE"
13 | ],
14 | [
15 | "Back",
16 | "15som",
17 | "25Twis",
18 | "FREE"
19 | ],
20 | [
21 | "Back",
22 | "15som",
23 | "NoTwis",
24 | "PIKE"
25 | ],
26 | [
27 | "Back",
28 | "15som",
29 | "NoTwis",
30 | "TUCK"
31 | ],
32 | [
33 | "Back",
34 | "25som",
35 | "15Twis",
36 | "PIKE"
37 | ],
38 | [
39 | "Back",
40 | "25som",
41 | "25Twis",
42 | "PIKE"
43 | ],
44 | [
45 | "Back",
46 | "25som",
47 | "NoTwis",
48 | "PIKE"
49 | ],
50 | [
51 | "Back",
52 | "25som",
53 | "NoTwis",
54 | "TUCK"
55 | ],
56 | [
57 | "Back",
58 | "2som",
59 | "15Twis",
60 | "FREE"
61 | ],
62 | [
63 | "Back",
64 | "2som",
65 | "25Twis",
66 | "FREE"
67 | ],
68 | [
69 | "Back",
70 | "35som",
71 | "NoTwis",
72 | "PIKE"
73 | ],
74 | [
75 | "Back",
76 | "35som",
77 | "NoTwis",
78 | "TUCK"
79 | ],
80 | [
81 | "Back",
82 | "3som",
83 | "NoTwis",
84 | "PIKE"
85 | ],
86 | [
87 | "Back",
88 | "3som",
89 | "NoTwis",
90 | "TUCK"
91 | ],
92 | [
93 | "Back",
94 | "Dive",
95 | "NoTwis",
96 | "PIKE"
97 | ],
98 | [
99 | "Back",
100 | "Dive",
101 | "NoTwis",
102 | "TUCK"
103 | ],
104 | [
105 | "Forward",
106 | "15som",
107 | "1Twis",
108 | "FREE"
109 | ],
110 | [
111 | "Forward",
112 | "15som",
113 | "2Twis",
114 | "FREE"
115 | ],
116 | [
117 | "Forward",
118 | "15som",
119 | "NoTwis",
120 | "PIKE"
121 | ],
122 | [
123 | "Forward",
124 | "1som",
125 | "NoTwis",
126 | "PIKE"
127 | ],
128 | [
129 | "Forward",
130 | "25som",
131 | "1Twis",
132 | "PIKE"
133 | ],
134 | [
135 | "Forward",
136 | "25som",
137 | "2Twis",
138 | "PIKE"
139 | ],
140 | [
141 | "Forward",
142 | "25som",
143 | "3Twis",
144 | "PIKE"
145 | ],
146 | [
147 | "Forward",
148 | "25som",
149 | "NoTwis",
150 | "PIKE"
151 | ],
152 | [
153 | "Forward",
154 | "25som",
155 | "NoTwis",
156 | "TUCK"
157 | ],
158 | [
159 | "Forward",
160 | "35som",
161 | "NoTwis",
162 | "PIKE"
163 | ],
164 | [
165 | "Forward",
166 | "35som",
167 | "NoTwis",
168 | "TUCK"
169 | ],
170 | [
171 | "Forward",
172 | "45som",
173 | "NoTwis",
174 | "TUCK"
175 | ],
176 | [
177 | "Forward",
178 | "Dive",
179 | "NoTwis",
180 | "PIKE"
181 | ],
182 | [
183 | "Forward",
184 | "Dive",
185 | "NoTwis",
186 | "STR"
187 | ],
188 | [
189 | "Inward",
190 | "15som",
191 | "NoTwis",
192 | "PIKE"
193 | ],
194 | [
195 | "Inward",
196 | "15som",
197 | "NoTwis",
198 | "TUCK"
199 | ],
200 | [
201 | "Inward",
202 | "25som",
203 | "NoTwis",
204 | "PIKE"
205 | ],
206 | [
207 | "Inward",
208 | "25som",
209 | "NoTwis",
210 | "TUCK"
211 | ],
212 | [
213 | "Inward",
214 | "35som",
215 | "NoTwis",
216 | "TUCK"
217 | ],
218 | [
219 | "Inward",
220 | "Dive",
221 | "NoTwis",
222 | "PIKE"
223 | ],
224 | [
225 | "Reverse",
226 | "15som",
227 | "05Twis",
228 | "FREE"
229 | ],
230 | [
231 | "Reverse",
232 | "15som",
233 | "15Twis",
234 | "FREE"
235 | ],
236 | [
237 | "Reverse",
238 | "15som",
239 | "25Twis",
240 | "FREE"
241 | ],
242 | [
243 | "Reverse",
244 | "15som",
245 | "35Twis",
246 | "FREE"
247 | ],
248 | [
249 | "Reverse",
250 | "15som",
251 | "NoTwis",
252 | "PIKE"
253 | ],
254 | [
255 | "Reverse",
256 | "25som",
257 | "15Twis",
258 | "PIKE"
259 | ],
260 | [
261 | "Reverse",
262 | "25som",
263 | "NoTwis",
264 | "PIKE"
265 | ],
266 | [
267 | "Reverse",
268 | "25som",
269 | "NoTwis",
270 | "TUCK"
271 | ],
272 | [
273 | "Reverse",
274 | "35som",
275 | "NoTwis",
276 | "TUCK"
277 | ],
278 | [
279 | "Reverse",
280 | "Dive",
281 | "NoTwis",
282 | "PIKE"
283 | ],
284 | [
285 | "Reverse",
286 | "Dive",
287 | "NoTwis",
288 | "TUCK"
289 | ]
290 | ]
--------------------------------------------------------------------------------
/annotations/gym288_anno.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chuhanxx/Temporal_Query_Networks/ac73e3375753463956f90037d59af5d3d50967e4/annotations/gym288_anno.pkl
--------------------------------------------------------------------------------
/annotations/gym288_id2label.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chuhanxx/Temporal_Query_Networks/ac73e3375753463956f90037d59af5d3d50967e4/annotations/gym288_id2label.pkl
--------------------------------------------------------------------------------
/annotations/gym288_label2id.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chuhanxx/Temporal_Query_Networks/ac73e3375753463956f90037d59af5d3d50967e4/annotations/gym288_label2id.pkl
--------------------------------------------------------------------------------
/annotations/gym99_anno.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chuhanxx/Temporal_Query_Networks/ac73e3375753463956f90037d59af5d3d50967e4/annotations/gym99_anno.pkl
--------------------------------------------------------------------------------
/annotations/gym99_id2label.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chuhanxx/Temporal_Query_Networks/ac73e3375753463956f90037d59af5d3d50967e4/annotations/gym99_id2label.pkl
--------------------------------------------------------------------------------
/annotations/gym99_label2id.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chuhanxx/Temporal_Query_Networks/ac73e3375753463956f90037d59af5d3d50967e4/annotations/gym99_label2id.pkl
--------------------------------------------------------------------------------
/build_venv.sh:
--------------------------------------------------------------------------------
1 | export CONDA_ENV_NAME=tqn
2 | echo $CONDA_ENV_NAME
3 |
4 | conda create -n $CONDA_ENV_NAME python=3.7
5 |
6 | eval "$(conda shell.bash hook)"
7 | conda activate $CONDA_ENV_NAME
8 |
9 | pip install -r requirements.txt
--------------------------------------------------------------------------------
/configs/Diving48_first_stage.yaml:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": "diving48",
3 | "num_classes": 48,
4 | "num_queries": 5,
5 | "attribute_set_size": 25,
6 | "max_length": 128,
7 | "downsample": 2,
8 | "root": "../diving48/",
9 | "save_folder": "../exps",
10 | "tbx_folder": "../tbx",
11 | "pretrained_weights_path": "../S3D_K400.pth.tar"
12 | }
13 |
--------------------------------------------------------------------------------
/configs/Diving48_second_stage.yaml:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": "diving48",
3 | "num_classes": 48,
4 | "num_queries": 5,
5 | "attribute_set_size": 25,
6 | "max_length": 100000000,
7 | "K": 10,
8 | "downsample": 2,
9 | "root": "../diving48/",
10 | "save_folder": "../exps",
11 | "tbx_folder": "../tbx",
12 | "feature_file": "../features/diving48_all_features.pkl",
13 | "pretrained_weights_path": ""
14 | }
15 |
--------------------------------------------------------------------------------
/configs/Gym288_first_stage.yaml:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": "gym288",
3 | "num_classes": 288,
4 | "num_queries": 13,
5 | "attribute_set_size": 98,
6 | "max_length": 48,
7 | "downsample": 1,
8 | "root": "../FineGym/",
9 | "save_folder": "../exps",
10 | "tbx_folder": "../tbx",
11 | "pretrained_weights_path": "../S3D_K400.pth.tar"
12 | }
13 |
--------------------------------------------------------------------------------
/configs/Gym288_second_stage.yaml:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": "gym288",
3 | "num_classes": 288,
4 | "num_queries": 13,
5 | "attribute_set_size": 98,
6 | "max_length": 100000000,
7 | "K": 6,
8 | "downsample": 1,
9 | "root": "../FineGym/",
10 | "save_folder": "../exps",
11 | "tbx_folder": "../tbx",
12 | "feature_file": "../features/gym288_all_features.pkl",
13 | "pretrained_weights_path": ""
14 | }
15 |
--------------------------------------------------------------------------------
/configs/Gym99_first_stage.yaml:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": "gym99",
3 | "num_classes": 99,
4 | "num_queries": 12,
5 | "attribute_set_size": 66,
6 | "max_length": 48,
7 | "downsample": 1,
8 | "root": "../FineGym/",
9 | "save_folder": "../exps",
10 | "tbx_folder": "../tbx",
11 | "pretrained_weights_path": "../S3D_K400.pth.tar"
12 | }
13 |
--------------------------------------------------------------------------------
/configs/Gym99_second_stage.yaml:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": "gym99",
3 | "num_classes": 99,
4 | "num_queries": 12,
5 | "attribute_set_size": 66,
6 | "max_length": 100000000,
7 | "K": 6,
8 | "downsample": 1,
9 | "root": "../FineGym/",
10 | "save_folder": "../exps",
11 | "tbx_folder": "../tbx",
12 | "feature_file": "../features/gym99_all_features.pkl",
13 | "pretrained_weights_path": ""
14 | }
15 |
--------------------------------------------------------------------------------
/data/data.md:
--------------------------------------------------------------------------------
1 | # Diving48
2 | Please download RGB data and annotations (cleaned Version 2 updated on 10/30/2020) from [the Diving48 webpage](http://www.svcl.ucsd.edu/projects/resound/dataset.html).
3 |
4 | After downloading the data to a `root` path named diving48, make sure that the folder tree looks like:
5 |
6 | diving48
7 | ├── frames
8 | │ └── OFxuiqI5G44_00247
9 | │ └── image_000001.jpg
10 | │ └── image_000002.jpg
11 | │ └── ......
12 | │ └── OFxuiqI5G44_00248
13 | │ └── image_000001.jpg
14 | │ └── image_000002.jpg
15 | │ └── .....
16 | ├── Diving48_V2_train.json
17 | └── Diving48_V2_test.json
18 |
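The dataloader expects frames named `image_000001.jpg`, `image_000002.jpg`, ... under `frames/<video_id>/`. If you downloaded videos rather than pre-extracted frames, a minimal sketch like the one below (assuming one MP4 per video segment stored in an `rgb/` folder, and `ffmpeg` on your PATH) can produce this layout:

```
import os
import glob
import subprocess

root = "diving48"                                    # the dataset root described above
for video in sorted(glob.glob(os.path.join(root, "rgb", "*.mp4"))):   # assumed video location
    vid = os.path.splitext(os.path.basename(video))[0]
    out_dir = os.path.join(root, "frames", vid)
    os.makedirs(out_dir, exist_ok=True)
    # %06d matches the image_000001.jpg naming expected by data/dataloader.py
    subprocess.run(["ffmpeg", "-i", video, os.path.join(out_dir, "image_%06d.jpg")], check=True)
```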
19 |
20 | Then set the `root` path in the `configs/*.yaml` files to the path to your Diving48 folder.
21 |
22 |
23 | # FineGym
24 | The [official FineGym dataset webpage](https://sdolivia.github.io/FineGym/) provides the URLs of the original YouTube videos for download.
25 |
26 | The videos are about 1 hour long and need to be cropped into segments using the provided annotations. Due to copyright concerns, we are not able to provide the cropped video segments/extracted frames for direct download. Please follow the instructions on the official webpage to conduct the pre-processing.
27 |
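As a rough sketch only (the official instructions are authoritative), one way to cut an action segment and extract its frames with `ffmpeg` is shown below. It assumes the full videos are saved as `videos/<video_id>.mp4` and that segment names encode event start/end times in the video and action start/end times within the event, all in seconds; please verify the timestamp conventions against the official annotation description.

```
import os
import subprocess

root = "FineGym"
videos_dir = "videos"                          # assumed location of the full downloaded videos

def extract_segment(segment_name):
    # e.g. "Z2T9B4qExzk_E_007618_007687_A_0020_0021"
    vid, times = segment_name.split("_E_")
    event, action = times.split("_A_")
    e_start, _ = map(int, event.split("_"))
    a_start, a_end = map(int, action.split("_"))
    start = e_start + a_start                  # action start, in seconds from the video start
    duration = max(a_end - a_start, 1)
    out_dir = os.path.join(root, "frames", segment_name)
    os.makedirs(out_dir, exist_ok=True)
    subprocess.run(["ffmpeg", "-ss", str(start), "-t", str(duration),
                    "-i", os.path.join(videos_dir, vid + ".mp4"),
                    os.path.join(out_dir, "image_%06d.jpg")], check=True)

extract_segment("Z2T9B4qExzk_E_007618_007687_A_0020_0021")
```
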
28 | After cropping the video segments and extracting the video frames, please create a `root` folder named `FineGym` and put the processed data into it, so that the folder tree looks like:
29 |
30 | FineGym
31 | ├── frames
32 | │ └── Z2T9B4qExzk_E_007618_007687_A_0020_0021
33 | │ └── image_000001.jpg
34 | │ └── image_000002.jpg
35 | │ └── image_000003.jpg
36 | │ └── ....
37 | │ └── zNL3kn3UBmg_E_008111_008200_A_0046_0048
38 | │ └── image_000001.jpg
39 | │ └── image_000002.jpg
40 | │ └── image_000003.jpg
41 | │ └── ....
42 | └── scripts
43 | └── gym99_train_element_v1.1.txt
44 | └── gym99_val_element.txt
45 | └── gym288_train_element_v1.1.txt
46 | └── gym288_val_element.txt
47 |
48 |
49 | Then set the `root` path in the `configs/*.yaml` files to the path to your FineGym folder.
50 |
51 | # Initialization of Weights
52 |
53 | S3D weights pretrained on Kinetics-400 can be downloaded [here](https://www.robots.ox.ac.uk/~vgg/research/tqn/K400-weights/S3D_K400.pth.tar) (~30.3 MB).
54 |
55 | Please set the `pretrained_weights_path` in the corresponding `configs/*_first_stage.yaml` files to the path to where the weights are saved.
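
As an illustration only (the exact loading code lives in the training scripts), the checkpoint can be loaded into the TQN backbone roughly as follows; the checkpoint key layout and the `args` fields shown here are assumptions and may need adjusting:

```
from argparse import Namespace
import torch
from models.TQN import TQN

# Hypothetical minimal args; in practice these come from the config files and script arguments.
args = Namespace(d_model=1024, H=4, N=4, num_queries=5, dropout=0.5,
                 attribute_set_size=25, num_classes=48)
net = TQN(args)

ckpt = torch.load("../S3D_K400.pth.tar", map_location="cpu")
# The file may hold a plain state_dict or wrap one under a key such as 'state_dict' (assumption).
state_dict = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
missing, unexpected = net.load_state_dict(state_dict, strict=False)
print(len(missing), "missing /", len(unexpected), "unexpected keys")
```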
56 |
--------------------------------------------------------------------------------
/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import os, sys, glob
2 |
3 | # import pickle
4 | import torch
5 | import random
6 | import math
7 | import time
8 | import json
9 |
10 | import torchvision
11 | import numpy as np
12 | import _pickle as cp
13 | import os.path as osp
14 |
15 | from PIL import Image
16 | from torch.nn.utils.rnn import pad_sequence
17 |
18 |
19 | class TQN_dataloader(object):
20 | def __init__(self, args, mode='train',transform=None,SUFB=False):
21 |
22 |
23 | self.root = args.root
24 | self.SUFB = SUFB
25 | self.mode = mode
26 | self.dataset = args.dataset
27 |
28 | self.transform = transform
29 | self.clip_len = args.clip_len
30 | self.downsample = args.downsample
31 | self.max_length = args.max_length
32 |
33 | self.label2id = cp.load(open(osp.join('../annotations/',args.dataset+'_label2id.pkl'),'rb'))
34 | self.id2label = cp.load(open(osp.join('../annotations/',args.dataset+'_id2label.pkl'),'rb'))
35 |
36 | if 'diving' in args.dataset:
37 |
38 | if mode=='train':
39 | self.gts = json.load(open(osp.join(self.root,'Diving48_V2_train.json'),'rb'))
40 | else:
41 | self.gts = json.load(open(osp.join(self.root,'Diving48_V2_test.json'),'rb'))
42 |
43 | self.vocab = json.load(open('../annotations/diving48_vocab.json','rb'))
44 |
45 | class_tokens = []
46 | for a in self.vocab:
47 | gt_class = torch.tensor([self.label2id[i] for i in a])
48 | class_tokens.append(gt_class)
49 |
50 | self.class_tokens = torch.stack(class_tokens,0)
51 |
52 | elif 'gym' in args.dataset:
53 |
54 | if mode =='train':
55 | self.gts = open(osp.join(self.root,'scripts',args.dataset+'_train_element_v1.1.txt'),'r').readlines()
56 | else:
57 | self.gts = open(osp.join(self.root,'scripts',args.dataset+'_val_element.txt'),'r').readlines()
58 |
59 | self.class_tokens = torch.stack([torch.tensor(i) for i in [*self.label2id.values()]],0)
60 |
61 | self.elements = self.preprocess(args.dataset)
62 |
63 | if self.SUFB:
64 | # Use the Stochastically Updated Feature Bank
65 | self.K = args.K
66 | self.vid2id = cp.load(open(args.feature_file.replace('features','vid2id'),'rb'))
67 |
68 |
69 | def __getitem__(self, index):
70 |
71 | gt = self.elements[index]
72 |
73 | if 'diving' in self.dataset:
74 | v_id = gt['vid_name']
75 | clabel = gt['label']
76 | frame_path = osp.join(self.root,'frames',v_id)
77 | total_frames = gt['end_frame'] - gt['start_frame']
78 | tokens = torch.tensor([self.label2id[i] for \
79 | i in self.vocab[clabel]])
80 |
81 | elif 'gym' in self.dataset:
82 | v_id,clabel,cname = gt
83 | frame_path = osp.join(self.root,'frames',v_id)
84 | total_frames = len(os.listdir(frame_path))
85 | tokens = torch.tensor(self.label2id[int(clabel)])
86 |
87 | downsample = self.set_downsample_rate(total_frames)
88 |
89 | if total_frames <=2:
90 | # skip broken samples
91 | return None,None,None,None
92 |
93 | elif self.mode != 'test':
94 | frames,ptr = self.sample_frames(total_frames,downsample)
95 | if len(frames) ==0:
96 | print(v_id,downsample,frames)
97 | seq = self.load_images(frame_path,frames)
98 |
99 | elif self.mode =='test':
100 | frames_list = self.sample_frames_test(total_frames,downsample)
101 |
102 | seq_list =[]
103 | for frames in frames_list:
104 | seq = self.load_images(frame_path,frames)
105 | seq_list.append(seq)
106 |
107 | # align and stack seqs in the lists
108 | min_chunks = min([s.shape[0] for s in seq_list])
109 | seq_list = [s[:min_chunks,:] for s in seq_list]
110 | seq = torch.stack(seq_list,dim=0)
111 |
112 | clabel = torch.tensor(int(clabel))
113 |
114 | if self.SUFB:
115 |
116 | v_id = self.vid2id[v_id]
117 | assert seq.shape[0] == self.K
118 | return v_id, seq, clabel, ptr, tokens
119 |
120 | return v_id, seq, clabel ,tokens
121 |
122 |
123 | def load_images(self,frame_path,frames):
124 | # load images and apply transformation
125 | seq_names = [os.path.join(frame_path, 'image_%06d.jpg' % (i+1)) for i in frames]
126 | seqs = [pil_loader(i) for i in seq_names]
127 | seqs = self.transform(seqs)
128 | seq = torch.stack(seqs, 1)
129 |
130 | C,T,H,W = seq.shape # [NUM_CLIPS, C, CLIP_LEN, H, W]
131 | seq = seq.view(C,-1,self.clip_len,H,W).transpose(0,1)
132 | return seq
133 |
134 |
135 |
136 | def sample_frames(self,total_frames,downsample):
137 | first_f = np.random.choice(np.arange(downsample+1))
138 | frames = np.arange(first_f,total_frames,downsample).tolist()
139 |
140 | if self.SUFB:
141 | # randomly choose a start point in the video to sample K clips
142 | n_clips = int(np.ceil(len(frames) / self.clip_len))
143 | ptr = np.random.choice(max(1,n_clips - self.K + 1))
144 | start = ptr * self.clip_len
145 | end = min([len(frames),(ptr + self.K) * self.clip_len])
146 | frames = frames[start:end]
147 |
148 | if self.mode == 'train':
149 | for _ in range(int(0.05*len(frames))+1):
150 | frames.remove(random.choice(frames))
151 |
152 | # pad the seq with the last frame to make the number of frames
153 | # sampled equal to K * clip_len,
154 |             # where K is the number of clips computed online
155 | # in each iteration in the SUFB
156 | frames = self.pad_seq(frames)
157 |
158 | else:
159 | # temporal jittering
160 | if self.mode == 'train':
161 | for _ in range(int(0.01*total_frames) + 1):
162 | frames.remove(random.choice(frames))
163 |
164 | # pad the seq with the last frame if the number of frames
165 |             # sampled is not divisible by clip_len
166 | frames = self.pad_seq(frames)
167 | ptr = None
168 |
169 | return frames,ptr
170 |
171 | def sample_frames_test(self,total_frames,downsample):
172 | # temporal jittering for testing
173 | frames = list(np.arange(0,total_frames,downsample))
174 | frames0 = self.pad_seq(frames)
175 | frames1 = self.pad_seq(self.drop_frames(frames))
176 |
177 | return [frames0,frames1]
178 |
179 |
180 | def pad_seq(self,frames):
181 |
182 | if not isinstance(frames,list):
183 | frames = frames.tolist()
184 |
185 | if self.SUFB:
186 | diff_T = self.clip_len * self.K - len(frames)
187 | else:
188 | hanging_T = len(frames) % self.clip_len
189 | diff_T = 0
190 | if hanging_T !=0:
191 | diff_T = self.clip_len - hanging_T
192 |
193 | for i in range(diff_T):
194 | frames.append(frames[-1])
195 | return frames
196 |
197 |
198 | def preprocess(self,dataset):
199 | # Filter the videos by length for the 1st stage training
200 | elements= []
201 | if 'diving' in dataset:
202 | for gt in self.gts:
203 | v_id, clabel, start_frame, end_frame = [*gt.values()]
204 |                 num_frames = end_frame - start_frame
205 | if num_frames < self.max_length:
206 | elements.append(gt)
207 |
208 | elif 'gym' in dataset:
209 | self.dict = cp.load(open(osp.join('../annotations',self.dataset+'_anno.pkl'),'rb'))
210 | for gt in self.gts:
211 | v_id,clabel = gt.split(' ')
212 | num_frames = int(self.dict[v_id]['num_frames'])
213 | cname = self.dict[v_id]['cname']
214 | if num_frames < self.max_length:
215 | elements.append((v_id,clabel,cname))
216 |
217 | return elements
218 |
219 |
220 | def drop_frames(self,frames):
221 | total_frames = len(frames)
222 | new = frames.copy()
223 | for _ in range(int(0.02*total_frames)+1):
224 | new.remove(random.choice(new))
225 | return new
226 |
227 | def set_downsample_rate(self,total_frames):
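        # Reduce the frame-sampling stride for short videos so that at least one full clip of clip_len frames can still be drawn.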
228 | downsample = self.downsample
229 | while total_frames - downsample * self.clip_len < 1 and downsample > 1 :
230 | downsample -=1
231 | return downsample
232 |
233 | def __len__(self):
234 | return len(self.elements)
235 |
236 |
237 | def pil_loader(path):
238 | with open(path, 'rb') as f:
239 | with Image.open(f) as img:
240 | return img.convert('RGB')
241 |
242 |
243 | def SUFB_collate(batch):
244 |
245 | ids = [b[0] for b in batch if b[0] is not None]
246 | if ids ==[]:
247 | return None,None,None,None
248 | else:
249 | seqs = [b[1] for b in batch if b[1] is not None]
250 | labels = [b[2] for b in batch if b[2] is not None]
251 | tokens = [b[-1] for b in batch if b[-1] is not None]
252 |
253 | seqs = torch.stack(seqs,dim=0)
254 | labels=torch.stack(labels,dim=0)
255 | tokens = pad_sequence(tokens,batch_first =True)
256 |
257 | if len(batch)>3:
258 |         # train or val mode (inferred from the batch size): also return the clip pointers
259 | ptrs = torch.tensor([b[3] for b in batch if b[3] is not None])
260 | return torch.tensor(ids),seqs,labels,ptrs,tokens
261 | else:
262 | # test mode
263 | return ids,seqs,labels,tokens
264 |
265 |
266 | def collate(batch):
267 | ids = [b[0] for b in batch if b[0] is not None]
268 | seq = [b[1] for b in batch if b[1] is not None]
269 | label = [b[2] for b in batch if b[2] is not None]
270 | tokens = [b[-1] for b in batch if b[-1] is not None]
271 |
272 | if len(seq) ==0:
273 | return None,None,None,None,None
274 | else:
275 | Ks = [s.shape[0] for s in seq]
276 | seq = pad_sequence(seq,batch_first=True)
277 | label=torch.stack(label,dim=0)
278 | tokens = pad_sequence(tokens,batch_first =True)
279 | return ids,seq,label,Ks,tokens
280 |
--------------------------------------------------------------------------------
/engine/engine.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import random
4 | from tqdm import tqdm
5 | import numpy as np
6 | from utils.utils import *
7 | from utils.plot_utils import *
8 | from torch import nn
9 |
10 | def train_one_epoch(args,epoch,net,optimizer,trainset,train_loader,SUFB = False):
11 | np.random.seed(epoch)
12 | random.seed(epoch)
13 | net.train()
14 |
15 | data_time = AverageMeter()
16 | batch_time = AverageMeter()
17 | losses = [AverageMeter()]
18 | accuracy = [AverageMeter(),AverageMeter()]
19 | criterion = nn.CrossEntropyLoss(reduction='mean')
20 |
21 | t0 = time.time()
22 |
23 | for j, batch_samples in enumerate(train_loader):
24 | data_time.update(time.time() - t0)
25 |
26 |
27 | # cls_targets: action class labels
28 | # att_targets: attribute labels
29 | if not SUFB:
30 | v_ids, seq, cls_targets, n_clips_per_video, att_targets = batch_samples
31 | if seq is None:
32 | continue
33 | mask = tfm_mask(n_clips_per_video)
34 | preds,cls_preds = net((seq,mask))
35 | else:
36 | # ptrs: clip pointers, where the online sampled clips start
37 | v_ids, seq, cls_targets, ptrs, att_targets = batch_samples
38 | preds,cls_preds = net((seq,v_ids,ptrs))
39 |
40 | cls_targets = cls_targets.cuda()
41 | match_acc = multihead_acc(preds, cls_targets, att_targets, \
42 | trainset.class_tokens, Q = args.num_queries)
43 |
44 | preds = preds.reshape(-1, args.attribute_set_size)
45 | att_targets = att_targets.view(-1).cuda()
46 | cls_acc = calc_topk_accuracy(cls_preds, cls_targets, (1,))[0]
47 |
48 | acc = [torch.stack([cls_acc, match_acc], 0).unsqueeze(0)]
49 | cls_acc, match_acc = torch.cat(acc, 0).mean(0)
50 |
51 | loss = criterion(preds, att_targets)
52 | loss += criterion(cls_preds, cls_targets)
53 |
54 | accuracy[0].update(match_acc.item(), args.batch_size)
55 | accuracy[1].update(cls_acc.item(), args.batch_size)
56 | losses[0].update(loss.item(), args.batch_size)
57 |
58 | optimizer.zero_grad()
59 | loss.backward()
60 | torch.nn.utils.clip_grad_norm_(net.parameters(), args.max_norm)
61 | optimizer.step()
62 |
63 | torch.cuda.empty_cache()
64 | batch_time.update(time.time() - t0)
65 | t0 = time.time()
66 |
67 | if j % (args.print_iter) == 0:
68 | t1 = time.time()
69 | print('Epoch: [{0}][{1}/{2}]\t'
70 | 'Loss {loss[0].val:.4f} Acc: {acc[0].val:.4f}\t'
71 | 'T-data:{dt.val:.2f} T-batch:{bt.val:.2f}\t'.format(
72 | epoch, j, len(train_loader),
73 | loss=losses, acc=accuracy, dt=data_time, bt=batch_time))
74 |
75 | args.train_plotter.add_data('local/loss', losses[0].local_avg, epoch*len(train_loader)+j)
76 | args.train_plotter.add_data('local/match_acc', accuracy[0].local_avg,epoch*len(train_loader)+j)
77 | args.train_plotter.add_data('local/cls_acc', accuracy[1].local_avg, epoch*len(train_loader)+j)
78 | torch.cuda.empty_cache()
79 |
80 | if epoch % args.save_epoch == 0:
81 | print('Saving state, epoch: %d iter:%d'%(epoch, j))
82 | save_ckpt(net,optimizer,args.best_acc,epoch,args.save_folder,str(epoch),SUFB)
83 |
84 | save_ckpt(net,optimizer,args.best_acc,epoch,args.save_folder,'latest',SUFB)
85 |
86 | train_acc = [i.avg for i in accuracy]
87 | args.train_plotter.add_data('global/loss', [i.avg for i in losses], epoch)
88 |     args.train_plotter.add_data('global/match_acc', train_acc[0], epoch)
89 |     args.train_plotter.add_data('global/cls_acc', train_acc[1], epoch)
90 |
91 |
92 |
93 |
94 | def eval_one_epoch(args,epoch,net,testset,test_loader,SUFB = False):
95 | net.eval()
96 | test_accuracy = [AverageMeter(),AverageMeter()]
97 | np.random.seed(epoch+1)
98 | random.seed(epoch+1)
99 |
100 | with torch.no_grad():
101 | for k, batch_samples in tqdm(enumerate(test_loader),total=len(test_loader)):
102 |
103 | # cls_targets: action class labels
104 | # att_targets: attribute labels
105 | if not SUFB:
106 | v_ids,seq,cls_targets,n_clips_per_video,att_targets = batch_samples
107 | if seq is None:
108 | continue
109 | mask = tfm_mask(n_clips_per_video)
110 | preds,cls_preds = net((seq,mask))
111 | else:
112 |
113 | # ptrs: clip pointers, where the online sampled clips start
114 | v_ids,seq,cls_targets,ptrs,att_targets = batch_samples
115 | preds,cls_preds = net((seq,v_ids,ptrs))
116 |
117 | cls_targets = cls_targets.cuda()
118 | match_acc = multihead_acc(preds,cls_targets, att_targets, \
119 | testset.class_tokens, Q=args.num_queries)
120 |
121 | preds = preds.reshape(-1,args.attribute_set_size)
122 | att_targets = att_targets.view(-1).cuda()
123 | cls_acc = calc_topk_accuracy(cls_preds, cls_targets, (1,))[0]
124 |
125 | acc = [torch.stack([cls_acc, match_acc], 0).unsqueeze(0)]
126 | cls_acc, match_acc = torch.cat(acc, 0).mean(0)
127 |
128 | test_accuracy[0].update(cls_acc.item(), args.batch_size)
129 | test_accuracy[1].update(match_acc.item(), args.batch_size)
130 |
131 | torch.cuda.empty_cache()
132 |
133 | test_acc = [i.avg for i in test_accuracy]
134 | args.val_plotter.add_data('global/cls_acc',test_acc[0], epoch)
135 | args.val_plotter.add_data('global/match_acc',test_acc[1], epoch)
136 |
137 |
138 | if test_acc[1] > args.best_acc:
139 | args.best_acc = test_acc[1]
140 | torch.save({'model_state_dict': net.state_dict(),\
141 | 'best_acc':test_acc[1]},\
142 | args.save_folder + '/' + 'best.pth')
143 |
144 |
145 |
146 | def save_ckpt(net,optimizer,best_acc,epoch,save_folder,name,SUFB):
147 | if SUFB:
148 | torch.save({'model_state_dict': net.state_dict(),
149 | 'optimizer_state_dict': optimizer.state_dict(),
150 | 'queue':net.module.queue,
151 | 'best_acc':best_acc,
152 | 'epoch':epoch},
153 | save_folder + '/' + name+'.pth')
154 |
155 | else:
156 | torch.save({'model_state_dict': net.state_dict(),
157 | 'optimizer_state_dict': optimizer.state_dict(),
158 | 'best_acc':best_acc,
159 | 'epoch':epoch},
160 | save_folder + '/' + name+'.pth')
161 |
--------------------------------------------------------------------------------
/models/TQN.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/tensorflow/models/blob/master/research/slim/nets/s3dg.py
2 | import torch.nn as nn
3 | import torch
4 | import math
5 | import numpy as np
6 | from torch.nn.utils.rnn import pad_sequence
7 | import torch.nn.functional as F
8 | from .transformer import *
9 | from utils.utils import tfm_mask
10 |
11 |
12 | class BasicConv3d(nn.Module):
13 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0,LayerNorm=False):
14 | super(BasicConv3d, self).__init__()
15 | self.conv = nn.Conv3d(in_planes, out_planes,
16 | kernel_size=kernel_size, stride=stride,
17 | padding=padding, bias=False)
18 |
19 | self.LayerNorm = LayerNorm
20 | if not self.LayerNorm:
21 | self.bn = nn.BatchNorm3d(out_planes)
22 | self.bn.weight.data.fill_(1)
23 | self.bn.bias.data.zero_()
24 | self.relu = nn.ReLU(inplace=True)
25 | self.conv.weight.data.fill_(1)
26 |
27 | def forward(self, x):
28 | x = self.conv(x)
29 | if not self.LayerNorm:
30 | x = self.bn(x)
31 | x = self.relu(x)
32 | return x
33 |
34 |
35 | class STConv3d(nn.Module):
36 | def __init__(self,in_planes,out_planes,kernel_size,stride,padding=0):
37 | super(STConv3d, self).__init__()
38 | self.conv1 = nn.Conv3d(in_planes, out_planes, kernel_size=(1,kernel_size,kernel_size),
39 | stride=(1,stride,stride),padding=(0,padding,padding), bias=False)
40 | self.conv2 = nn.Conv3d(out_planes,out_planes,kernel_size=(kernel_size,1,1),
41 | stride=(stride,1,1),padding=(padding,0,0), bias=False)
42 |
43 | self.bn1=nn.BatchNorm3d(out_planes)
44 | self.bn2=nn.BatchNorm3d(out_planes)
45 | self.relu = nn.ReLU(inplace=True)
46 |
47 | # init
48 | self.conv1.weight.data.normal_(mean=0, std=0.01)
49 | self.conv2.weight.data.normal_(mean=0, std=0.01)
50 |
51 | self.bn1.weight.data.fill_(1)
52 | self.bn1.bias.data.zero_()
53 | self.bn2.weight.data.fill_(1)
54 | self.bn2.bias.data.zero_()
55 |
56 | def forward(self,x):
57 | x=self.conv1(x)
58 | x=self.bn1(x)
59 | x=self.relu(x)
60 | x=self.conv2(x)
61 | x=self.bn2(x)
62 | x=self.relu(x)
63 | return x
64 |
65 |
66 | class SelfGating(nn.Module):
67 | def __init__(self, input_dim):
68 | super(SelfGating, self).__init__()
69 | self.fc = nn.Linear(input_dim, input_dim)
70 |
71 | def forward(self, input_tensor):
72 | """Feature gating as used in S3D-G"""
73 | spatiotemporal_average = torch.mean(input_tensor, dim=[2, 3, 4])
74 | weights = self.fc(spatiotemporal_average)
75 | weights = torch.sigmoid(weights)
76 | return weights[:, :, None, None, None] * input_tensor
77 |
78 |
79 | class SepInception(nn.Module):
80 | def __init__(self, in_planes, out_planes, gating=False,LayerNorm=False):
81 | super(SepInception, self).__init__()
82 |
83 | assert len(out_planes) == 6
84 | assert isinstance(out_planes, list)
85 |
86 | [num_out_0_0a,
87 | num_out_1_0a, num_out_1_0b,
88 | num_out_2_0a, num_out_2_0b,
89 | num_out_3_0b] = out_planes
90 |
91 | self.branch0 = nn.Sequential(
92 | BasicConv3d(in_planes, num_out_0_0a, kernel_size=1, stride=1),
93 | )
94 | self.branch1 = nn.Sequential(
95 | BasicConv3d(in_planes, num_out_1_0a, kernel_size=1, stride=1),
96 | STConv3d(num_out_1_0a, num_out_1_0b, kernel_size=3, stride=1, padding=1),
97 | )
98 | self.branch2 = nn.Sequential(
99 | BasicConv3d(in_planes, num_out_2_0a, kernel_size=1, stride=1),
100 | STConv3d(num_out_2_0a, num_out_2_0b, kernel_size=3, stride=1, padding=1),
101 | )
102 | self.branch3 = nn.Sequential(
103 | nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
104 | BasicConv3d(in_planes, num_out_3_0b, kernel_size=1, stride=1,LayerNorm=LayerNorm),
105 | )
106 |
107 | self.out_channels = sum([num_out_0_0a, num_out_1_0b, num_out_2_0b, num_out_3_0b])
108 |
109 | self.gating = gating
110 | if gating:
111 | self.gating_b0 = SelfGating(num_out_0_0a)
112 | self.gating_b1 = SelfGating(num_out_1_0b)
113 | self.gating_b2 = SelfGating(num_out_2_0b)
114 | self.gating_b3 = SelfGating(num_out_3_0b)
115 |
116 | def forward(self, x):
117 | if isinstance(x,tuple):
118 | x = x[0]
119 |
120 | x0 = self.branch0(x)
121 | x1 = self.branch1(x)
122 | x2 = self.branch2(x)
123 | x3 = self.branch3(x)
124 | if self.gating:
125 | x0 = self.gating_b0(x0)
126 | x1 = self.gating_b1(x1)
127 | x2 = self.gating_b2(x2)
128 | x3 = self.gating_b3(x3)
129 | out = torch.cat((x0, x1, x2, x3), 1)
130 | return out
131 |
132 |
133 |
134 |
135 | class TQN(nn.Module):
136 |
137 | def __init__(self, args,first_channel=3,features_out =False,gating=False,SUFB=False,mode='train'):
138 | super(TQN, self).__init__()
139 |
140 | self.gating = gating
141 | self.features_out = features_out
142 | self.d_model = args.d_model
143 | self.SUFB = SUFB
144 | self.mode = mode
145 |
146 | if SUFB:
147 | self.K =args.K
148 |
149 | ###################################
150 | '''S3D'''
151 | ###################################
152 |
153 | self.Conv_1a = STConv3d(first_channel, 64, kernel_size=7, stride=2, padding=3)
154 | self.block1 = nn.Sequential(self.Conv_1a) # (64, 32, 112, 112)
155 |
156 | self.MaxPool_2a = nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1))
157 | self.Conv_2b = BasicConv3d(64, 64, kernel_size=1, stride=1)
158 | self.Conv_2c = STConv3d(64, 192, kernel_size=3, stride=1, padding=1)
159 |
160 | self.block2 = nn.Sequential(
161 | self.MaxPool_2a, # (64, 32, 56, 56)
162 | self.Conv_2b, # (64, 32, 56, 56)
163 | self.Conv_2c) # (192, 32, 56, 56)
164 |
165 |
166 | self.MaxPool_3a = nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1))
167 | self.Mixed_3b = SepInception(in_planes=192, out_planes=[64, 96, 128, 16, 32, 32], gating=gating)
168 | self.Mixed_3c = SepInception(in_planes=256, out_planes=[128, 128, 192, 32, 96, 64], gating=gating)
169 |
170 | self.block3 = nn.Sequential(
171 | self.MaxPool_3a, # (192, 32 , 28, 28)
172 | self.Mixed_3b, # (256, 32, 28, 28)
173 | self.Mixed_3c) # (480, 32, 28, 28)
174 |
175 | self.MaxPool_4a = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
176 | self.Mixed_4b = SepInception(in_planes=480, out_planes=[192, 96, 208, 16, 48, 64], gating=gating)
177 | self.Mixed_4c = SepInception(in_planes=512, out_planes=[160, 112, 224, 24, 64, 64], gating=gating)
178 | self.Mixed_4d = SepInception(in_planes=512, out_planes=[128, 128, 256, 24, 64, 64], gating=gating)
179 | self.Mixed_4e = SepInception(in_planes=512, out_planes=[112, 144, 288, 32, 64, 64], gating=gating)
180 | self.Mixed_4f = SepInception(in_planes=528, out_planes=[256, 160, 320, 32, 128, 128], gating=gating)
181 |
182 | self.block4 = nn.Sequential(
183 | self.MaxPool_4a, # (480, 16, 14, 14)
184 | self.Mixed_4b, # (512, 16, 14, 14)
185 | self.Mixed_4c, # (512, 16, 14, 14)
186 | self.Mixed_4d, # (512, 16, 14, 14)
187 | self.Mixed_4e, # (528, 16, 14, 14)
188 | self.Mixed_4f) # (832, 16, 14, 14)
189 |
190 | self.MaxPool_5a = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0))
191 | self.Mixed_5b = SepInception(in_planes=832, out_planes=[256, 160, 320, 32, 128, 128], gating=gating)
192 | self.Mixed_5c = SepInception(in_planes=832, out_planes=[384, 192, 384, 48, 128, 128], gating=gating)
193 |
194 | self.block5 = nn.Sequential(
195 | self.MaxPool_5a, # (832, 8, 7, 7)
196 | self.Mixed_5b, # (832, 8, 7, 7)
197 | self.Mixed_5c) # (1024, 8, 7, 7)
198 |
199 | self.AvgPool_0a = nn.AvgPool3d(kernel_size=(1, 7, 7), stride=1)
200 |
201 |
202 |
203 | ###################################
204 | ''' Query Decoder'''
205 | ###################################
206 |
207 | if not self.features_out:
208 |
209 | # Decoder Layers
210 | self.H = args.H
211 | decoder_layer = TransformerDecoderLayer(self.d_model, args.H, 1024,
212 | 0.1, 'relu',normalize_before=True)
213 | decoder_norm = nn.LayerNorm(self.d_model)
214 | self.decoder = TransformerDecoder(decoder_layer, args.N, decoder_norm,
215 | return_intermediate=False)
216 |
217 | # Learnable Queries
218 | self.query_embed = nn.Embedding(args.num_queries,self.d_model)
219 | self.dropout_feas = nn.Dropout(args.dropout)
220 |
221 | # Attribute classifier
222 | self.classifier = nn.Linear(self.d_model,args.attribute_set_size)
223 |
224 | # Class classifier
225 | self.cls_classifier = nn.Linear(self.d_model,args.num_classes)
226 |
227 |
228 | self.apply(self._init_weights)
229 |
230 |
231 |
232 |
233 | def forward(self, input):
234 |
235 | ''' Reshape Input Sequences '''
236 | if not self.SUFB:
237 | x, mask = input
238 | if len(x.shape) ==5:
239 | # the First stage training
240 | BK, C, T, H, W =x.shape
241 | seg_per_video = mask.shape[-1] - mask.sum(1)
242 |
243 | else:
244 | # Feature extraction mode for full video sequence
245 | B, K, C, T, H, W = x.shape
246 | x = x.reshape(B*K,C,T,H,W)
247 | seg_per_video = None
248 |
249 | else:
250 | # Training with a Stochastically Updated Feature Bank
251 | x, vids, ptrs = input
252 | B, K, C, T, H, W = x.shape
253 | x = x.reshape(B*K,C,T,H,W)
254 | seg_per_video = None
255 |
256 |
257 | ''' Visual Backbone '''
258 | x = self.block1(x)
259 | x = self.block2(x)
260 | x = self.block3(x)
261 | x = self.block4(x)
262 | x = self.block5(x)
263 |
264 | features = self.AvgPool_0a(x).squeeze()
265 |
266 | if self.SUFB:
267 | features,Ts,mask = self.fill_SUFB(features,vids,ptrs)
268 |
269 | if self.features_out:
270 | return features
271 |
272 | else:
273 | ''' Query Decoder '''
274 | if seg_per_video is not None:
275 | # first stage training
276 | features = self.reshape_features(features.squeeze(),
277 | seg_per_video)
278 | B = len(seg_per_video)
279 | K = int(BK // B)
280 |
281 | elif not self.SUFB:
282 | features = features.reshape(B,K,-1)
283 |
284 | if mask is not None:
285 | mask = mask.view(B,-1)
286 |
287 | features = features.transpose(0,1)
288 | query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, B, 1)
289 | features = self.decoder(query_embed, features,
290 | memory_key_padding_mask=mask, pos=None, query_pos=None)
291 |
292 | out = self.dropout_feas(features) # [T,B,C]
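        # The first (num_queries - 1) query outputs go through the attribute classifier; the last query output predicts the action class.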
293 | x= self.classifier(out[:-1]).transpose(0,1)
294 | x_cls = self.cls_classifier(out[-1])
295 |
296 | return x, x_cls
297 |
298 |
299 | def reshape_features(self,features,seg_per_video):
300 | reshaped_features = []
301 | counter = 0
302 | for n_seg in seg_per_video:
303 | reshaped_features.append(features[counter:counter+n_seg])
304 | counter += n_seg
305 | return pad_sequence(reshaped_features,batch_first=True)
306 |
307 |
308 | def fill_SUFB(self,features,vids,ptrs):
309 | fea_dim = features.shape[-1]
310 |
311 | if self.mode =='train':
312 | # Update newly computed features in the SUFB,
313 | # And read all the features from the SUFB
314 | full_features = []
315 | features = features.view(-1,self.K,fea_dim)
316 | features_split = torch.split(features, 1, dim=0)
317 |
318 | for f, vid, ptr in zip(features_split, vids, ptrs):
319 | vid = vid.item()
320 | end = min([len(self.queue[vid]), ptr + self.K])
321 |
322 | self.queue[vid][ptr:end] = f[0,:(end-ptr),:]
323 | full_features.append(self.queue[vid])
324 | self.queue[vid] = self.queue[vid].detach()
325 |
326 |
327 | Ts = [f.shape[0] for f in full_features]
328 | mask = tfm_mask(Ts).cuda()
329 | features = pad_sequence(full_features,batch_first=True).cuda()
330 |
331 |
332 | elif self.mode == 'test':
333 | # Test mode, compute all features online
334 |             features = features.view(len(vids),-1,fea_dim).cuda()  # infer the batch size from the video ids
335 |             Ts = [features[i].shape[0] for i in range(len(vids))]
336 | mask = tfm_mask(Ts).cuda()
337 |
338 | return features,Ts,mask
339 |
340 |
341 | @staticmethod
342 | def _init_weights(module):
343 | r"""Initialize weights like BERT - N(0.0, 0.02), bias = 0."""
344 |
345 | if isinstance(module, nn.Linear):
346 | module.weight.data.normal_(mean=0.0, std=0.02)
347 |
348 | elif isinstance(module, nn.MultiheadAttention):
349 | module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
350 | module.out_proj.weight.data.normal_(mean=0.0, std=0.02)
351 |
352 | elif isinstance(module, nn.Embedding):
353 | module.weight.data.normal_(mean=0.0, std=0.02)
354 | if module.padding_idx is not None:
355 | module.weight.data[module.padding_idx].zero_()
356 |
357 |
358 |
--------------------------------------------------------------------------------
/models/transformer.py:
--------------------------------------------------------------------------------
1 | """
2 | Code modified from the DETR transformer:
3 | https://github.com/facebookresearch/detr
4 | Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
5 |
6 | """
7 |
8 | import copy
9 | from typing import Optional, List
10 | import pickle as cp
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from torch import nn, Tensor
15 |
16 |
17 | class TransformerDecoder(nn.Module):
18 |
19 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
20 | super().__init__()
21 | self.layers = _get_clones(decoder_layer, num_layers)
22 | self.num_layers = num_layers
23 | self.norm = norm
24 | self.return_intermediate = return_intermediate
25 |
26 | def forward(self, tgt, memory,
27 | tgt_mask: Optional[Tensor] = None,
28 | memory_mask: Optional[Tensor] = None,
29 | tgt_key_padding_mask: Optional[Tensor] = None,
30 | memory_key_padding_mask: Optional[Tensor] = None,
31 | pos: Optional[Tensor] = None,
32 | query_pos: Optional[Tensor] = None):
33 | output = tgt
34 | T,B,C = memory.shape
35 | intermediate = []
36 |
37 | for n,layer in enumerate(self.layers):
38 |
39 | residual=True
40 | output,ws = layer(output, memory, tgt_mask=tgt_mask,
41 | memory_mask=memory_mask,
42 | tgt_key_padding_mask=tgt_key_padding_mask,
43 | memory_key_padding_mask=memory_key_padding_mask,
44 | pos=pos, query_pos=query_pos,residual=residual)
45 |
46 | if self.return_intermediate:
47 | intermediate.append(self.norm(output))
48 | if self.norm is not None:
49 | output = self.norm(output)
50 | if self.return_intermediate:
51 | intermediate.pop()
52 | intermediate.append(output)
53 |
54 | if self.return_intermediate:
55 | return torch.stack(intermediate)
56 | return output
57 |
58 |
59 |
60 | class TransformerDecoderLayer(nn.Module):
61 |
62 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
63 | activation="relu", normalize_before=False):
64 | super().__init__()
65 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
66 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
67 | # Implementation of Feedforward model
68 | self.linear1 = nn.Linear(d_model, dim_feedforward)
69 | self.dropout = nn.Dropout(dropout)
70 | self.linear2 = nn.Linear(dim_feedforward, d_model)
71 |
72 | self.norm1 = nn.LayerNorm(d_model)
73 | self.norm2 = nn.LayerNorm(d_model)
74 | self.norm3 = nn.LayerNorm(d_model)
75 | self.dropout1 = nn.Dropout(dropout)
76 | self.dropout2 = nn.Dropout(dropout)
77 | self.dropout3 = nn.Dropout(dropout)
78 |
79 | self.activation = _get_activation_fn(activation)
80 | self.normalize_before = normalize_before
81 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
82 | return tensor if pos is None else tensor + pos
83 |
84 | def forward_post(self, tgt, memory,
85 | tgt_mask: Optional[Tensor] = None,
86 | memory_mask: Optional[Tensor] = None,
87 | tgt_key_padding_mask: Optional[Tensor] = None,
88 | memory_key_padding_mask: Optional[Tensor] = None,
89 | pos: Optional[Tensor] = None,
90 | query_pos: Optional[Tensor] = None,
91 | residual=True):
92 | q = k = self.with_pos_embed(tgt, query_pos)
93 | tgt2,ws = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
94 | key_padding_mask=tgt_key_padding_mask)
95 |         tgt = self.norm1(tgt + self.dropout1(tgt2))  # add the self-attention residual before normalising, as in DETR
96 | tgt2,ws = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
97 | key=self.with_pos_embed(memory, pos),
98 | value=memory, attn_mask=memory_mask,
99 | key_padding_mask=memory_key_padding_mask)
100 |
101 |
102 | # attn_weights [B,NUM_Q,T]
103 | tgt = tgt + self.dropout2(tgt2)
104 | tgt = self.norm2(tgt)
105 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
106 | tgt = tgt + self.dropout3(tgt2)
107 | tgt = self.norm3(tgt)
108 | return tgt,ws
109 |
110 | def forward_pre(self, tgt, memory,
111 | tgt_mask: Optional[Tensor] = None,
112 | memory_mask: Optional[Tensor] = None,
113 | tgt_key_padding_mask: Optional[Tensor] = None,
114 | memory_key_padding_mask: Optional[Tensor] = None,
115 | pos: Optional[Tensor] = None,
116 | query_pos: Optional[Tensor] = None):
117 | tgt2 = self.norm1(tgt)
118 | q = k = self.with_pos_embed(tgt2, query_pos)
119 | tgt2,ws = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
120 | key_padding_mask=tgt_key_padding_mask)
121 | tgt = tgt + self.dropout1(tgt2)
122 | tgt2 = self.norm2(tgt)
123 | tgt2,attn_weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
124 | key=self.with_pos_embed(memory, pos),
125 | value=memory, attn_mask=memory_mask,
126 | key_padding_mask=memory_key_padding_mask)
127 | tgt = tgt + self.dropout2(tgt2)
128 | tgt2 = self.norm3(tgt)
129 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
130 | tgt = tgt + self.dropout3(tgt2)
131 |         return tgt,attn_weights
132 |
133 | def forward(self, tgt, memory,
134 | tgt_mask: Optional[Tensor] = None,
135 | memory_mask: Optional[Tensor] = None,
136 | tgt_key_padding_mask: Optional[Tensor] = None,
137 | memory_key_padding_mask: Optional[Tensor] = None,
138 | pos: Optional[Tensor] = None,
139 | query_pos: Optional[Tensor] = None,
140 | residual=True):
141 | if self.normalize_before:
142 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
143 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
144 | return self.forward_post(tgt, memory, tgt_mask, memory_mask,
145 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos,residual)
146 |
147 |
148 | def _get_clones(module, N):
149 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
150 |
151 |
152 |
153 | def _get_activation_fn(activation):
154 | """Return an activation function given a string"""
155 | if activation == "relu":
156 | return F.relu
157 | if activation == "gelu":
158 | return F.gelu
159 | if activation == "glu":
160 | return F.glu
161 |     raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ipdb==0.13.7
2 | scipy==1.5.4
3 | six==1.16.0
4 | tensorboardX==2.2
5 | torch==1.8.1
6 | torchvision==0.9.1
7 | tqdm==4.60.0
8 |
--------------------------------------------------------------------------------
/scripts/construct_SUFB.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import random
4 | import time
5 | import torch
6 | import torchvision
7 | import sys
8 | sys.path.append('../')
9 | import numpy as np
10 | import _pickle as cp
11 | import os.path as osp
12 | import torch.nn as nn
13 | import utils.augmentation as A
14 | import torch.utils.data as data
15 | import torch.backends.cudnn as cudnn
16 | import json
17 | import glob
18 |
19 | from tqdm import tqdm
20 | from torchvision import transforms
21 | from models.TQN import TQN
22 |
23 | from data.dataloader import TQN_dataloader,SUFB_collate
24 |
25 |
26 |
27 |
28 |
29 | def worker_init_fn(worker_id):
30 | np.random.seed(np.random.get_state()[1][0] + worker_id)
31 |
32 | def main():
33 | parser = argparse.ArgumentParser()
34 | parser.add_argument('--name', default='initial', type=str)
35 |
36 | ## data setting
37 | parser.add_argument('--dataset', default='gym99', type=str)
38 |
39 | parser.add_argument('--img_dim', default=224, type=int)
40 | parser.add_argument('--clip_len', default=8, type=int, help='number of frames in each video block')
41 | parser.add_argument('--downsample', default=2, type=int, help='frame down sampling rate')
42 | parser.add_argument('--batch_size', default=1, type=int)
43 | parser.add_argument('--resume_file', default='', type=str)
44 | parser.add_argument('--d_model', default=1024, type=int)
45 | parser.add_argument('--dataset_config', default='', type=str)
46 |
47 | parser.add_argument('--all_frames', action='store_true')
48 | parser.add_argument('--seed', default=0, type=int)
49 | parser.add_argument('--out_dir', default='', type=str)
50 |
51 | # device params
52 | parser.add_argument("--gpus", dest="gpu", default="0", type=str)
53 | parser.add_argument('--num_workers', default=16, type=int)
54 |
55 | ## model setting
56 | parser.add_argument("--model",default='s3d',type=str,help='')
57 | parser.add_argument('--resume', default=-1, type=int)
58 | parser.add_argument('--dropout', default=0.2, type=float)
59 |
60 | ## frequency setting
61 | parser.add_argument('--eval_epoch', default=5, type=int)
62 | parser.add_argument('--max_iter', default=20000000, type=int)
63 |
64 |
65 | args = parser.parse_args()
66 | if args.dataset_config is not None:
67 | d = vars(args)
68 | with open(args.dataset_config, "r") as f:
69 | cfg = json.load(f)
70 | d.update(cfg)
71 |
72 | args.max_length = 1e6
73 | torch.manual_seed(args.seed)
74 | np.random.seed(args.seed)
75 | random.seed(args.seed)
76 |
77 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
78 | device = torch.device("cuda")
79 | torch.set_default_tensor_type('torch.FloatTensor')
80 |
81 | ## Set Up Model
82 |
83 | num_classes =int(''.join([s for s in args.dataset if s.isdigit()]))
84 | net = TQN(args,features_out=True)
85 | net = torch.nn.DataParallel(net).to(device)
86 |
87 | ## Load Model Weights
88 |
89 | assert args.resume_file!= ''
90 | checkpoint = torch.load(args.resume_file)
91 | state_dict = checkpoint['model_state_dict']
92 | net.load_state_dict(state_dict,strict=False)
93 |
94 | ## Set Up Dataloader
95 |
96 | transform = transforms.Compose([
97 | A.RandomSizedCrop(size=args.img_dim, consistent=True, clip_len=args.clip_len, h_ratio=0.6,p=0.8),
98 | A.RandomHorizontalFlip(consistent=True, clip_len=args.clip_len),
99 | A.ColorJitter(brightness=0.4, contrast=0.7, saturation=0.7, hue=0.25,
100 | p=1.0, consistent=False, clip_len=args.clip_len),
101 | A.ToTensor(),
102 | A.Normalize(args.dataset)])
103 | transform_test = transforms.Compose([
104 | A.CenterCrop(size=args.img_dim),
105 | A.ToTensor(),
106 | A.Normalize(args.dataset)])
107 |
108 | trainset = TQN_dataloader(args,transform=transform,mode='train')
109 | testset = TQN_dataloader(args,transform=transform_test,mode='val')
110 |
111 | for dataset in [trainset,testset]:
112 | data_loader = data.DataLoader(dataset, args.batch_size,num_workers=args.num_workers,
113 | collate_fn =SUFB_collate,pin_memory=True, worker_init_fn=worker_init_fn,drop_last=False)
114 |
115 |
116 | cudnn.benchmark = True
117 | net.eval()
118 |
119 | with torch.no_grad():
120 | for k, test_samples in tqdm(enumerate(data_loader),total=len(data_loader)):
121 |
122 | v_id, seq, target, _ = test_samples
123 | if v_id is None:
124 | continue
125 | B, K, C, T, H, W =seq.shape # [batch_size, num_clips, num_channels, clip_len, H, W]
126 | out_pkl = osp.join(args.out_dir,v_id[0]+'.pkl')
127 |
128 | if not osp.exists(osp.join(args.out_dir)):
129 | os.mkdir(osp.join(args.out_dir))
130 |
131 | # Clip super long videos to fit it in one/two gpus
132 | if seq.shape[-3] >600:
133 | seq = seq[:,:,int(0.2*K):-int(0.2*K):,:,:]
134 |
135 | # Forward
136 | feas = net((seq,None))
137 | feas = feas.squeeze().view(B,-1,feas.shape[-1])
138 |
139 | # Save individual feature files first
140 | with open(out_pkl, 'wb') as f:
141 | cp.dump(feas.cpu(),f)
142 |
143 |
144 | ## Write All the Feature Files into One File
145 | vid_to_id,features_dict = {}, {}
146 | pkls = glob.glob(osp.join(args.out_dir,'*.pkl'))
147 |
148 | for ind,pkl in enumerate(pkls):
149 | v_id = osp.basename(pkl).replace('.pkl','')
150 | vid_to_id[v_id] = ind
151 | features_dict[ind] = cp.load(open(pkl,'rb'))[0]
152 |
153 | with open(osp.join(args.out_dir,args.dataset+'_all_vid2id.pkl'), 'wb') as f:
154 | cp.dump(vid_to_id,f)
155 |
156 | with open(osp.join(args.out_dir,args.dataset+'_all_features.pkl'), 'wb') as f:
157 | cp.dump(features_dict,f)
158 |
159 |     print('Saved features from ',len(features_dict),' video samples.')
160 |
161 | if __name__ == '__main__':
162 | main()
163 |
164 |
--------------------------------------------------------------------------------
/scripts/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import random
4 | import json
5 | import os
6 | import sys
7 | sys.path.append('../')
8 |
9 | from tqdm import tqdm
10 | from torch import nn
11 | import torch.optim as optim
12 | import torch.utils.data as data
13 | import torch.backends.cudnn as cudnn
14 | import utils.augmentation as A
15 | import os.path as osp
16 |
17 | from torch.utils.data import DataLoader
18 | from torchvision import transforms
19 |
20 | from models.TQN import *
21 | from utils.utils import make_dirs,multihead_acc,calc_topk_accuracy
22 | from utils.plot_utils import *
23 |
24 | from data.dataloader import TQN_dataloader,SUFB_collate
25 |
26 |
27 | def worker_init_fn(worker_id):
28 | np.random.seed(np.random.get_state()[1][0] + worker_id)
29 |
30 | def main():
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument('--name', default='initial', type=str)
33 |
34 | ## data setting
35 | parser.add_argument('--dataset', default='', type=str)
36 | parser.add_argument('--img_dim', default=224, type=int)
37 | parser.add_argument('--clip_len', default=8, type=int, help='number of frames in each video block')
38 | parser.add_argument('--downsample', default=2, type=int, help='frame down sampling rate')
39 | parser.add_argument('--batch_size', default=1, type=int)
40 | parser.add_argument('--root', default='', type=str)
41 | parser.add_argument('--dataset_config', default='', type=str)
42 | parser.add_argument('--feature_file', default='', type=str)
43 | parser.add_argument('--seed', default=0, type=int)
44 |
45 | # device params
46 | parser.add_argument("--gpus", dest="gpu", default="0", type=str)
47 | parser.add_argument('--num_workers', default=16, type=int)
48 |
49 | ## model setting
50 | parser.add_argument("--model",default='s3d',type=str,help='i3d,s3d')
51 | parser.add_argument('--dropout', default=0.5, type=float)
52 |
53 | parser.add_argument('--N', default=4, type=int,help='Number of layers in the temporal decoder')
54 | parser.add_argument('--H', default=4, type=int,help='Number of heads in the temporal decoder')
55 | parser.add_argument('--K', default=2, type=int,help='Number of clips updated per batch')
56 |
57 | parser.add_argument('--d_model', default=1024, type=int)
58 | parser.add_argument('--pretrained_weights_path', default='', type=str)
59 |
60 |
61 | args = parser.parse_args()
62 |
63 |     if args.dataset_config != '':
64 | d = vars(args)
65 | with open(args.dataset_config, "r") as f:
66 | cfg = json.load(f)
67 | d.update(cfg)
68 |
69 | assert args.batch_size == 1
70 |
71 | make_dirs(args)
72 | torch.manual_seed(args.seed)
73 | np.random.seed(args.seed)
74 | random.seed(args.seed)
75 |
76 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
77 | device = torch.device("cuda")
78 | torch.set_default_tensor_type('torch.FloatTensor')
79 |
80 | ## Set Up Model
81 |
82 | net = TQN(args).cuda()
83 | net = torch.nn.parallel.DataParallel(net)
84 |
85 | ## Load Model Weights
86 |
87 | resume_file = osp.join(args.save_folder,'best.pth')
88 | checkpoint = torch.load(resume_file)
89 | net.load_state_dict(checkpoint['model_state_dict'],strict=True)
90 |
91 | ## Set Up Dataloader
92 |
93 | transform_test = transforms.Compose([
94 | A.CenterCrop(size=args.img_dim),
95 | A.ToTensor(),
96 | A.Normalize(args.dataset)])
97 | testset = TQN_dataloader(args,mode='test',
98 | transform=transform_test,
99 | SUFB = False)
100 | test_loader = data.DataLoader(testset, args.batch_size,num_workers=args.num_workers,
101 | pin_memory=True, worker_init_fn=worker_init_fn, shuffle=False,
102 | collate_fn = SUFB_collate,drop_last=True,sampler = None)
103 |
104 |
105 | net.eval()
106 | test_accuracy = [AverageMeter(),AverageMeter()]
107 |
108 | with torch.no_grad():
109 |
110 | for k, test_samples in tqdm(enumerate(test_loader),total=len(test_loader)):
111 |
112 | v_ids, seqs, cls_targets, att_targets = test_samples
113 |
114 | seqs = seqs[0]
115 | B, K, C, T, H, W = seqs.shape
116 | cls_targets = cls_targets.cuda()
117 | att_targets = att_targets.view(-1).cuda()
118 |
119 | preds, cls_preds = net((seqs,None))
120 | preds = torch.softmax(preds, dim=-1).mean(0, keepdim=True)
121 |
122 | cls_preds = torch.softmax(cls_preds, dim=-1).mean(0, keepdim=True)
123 | match_acc = multihead_acc(preds, cls_targets, att_targets, \
124 | testset.class_tokens, Q = args.num_queries)
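125 |             # match_acc: classify by matching the Q per-query attribute predictions against each
126 |             # class's attribute set (testset.class_tokens); cls_acc: top-1 accuracy of the class-token prediction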
125 |
126 | cls_acc = calc_topk_accuracy(cls_preds, cls_targets, (1,))[0]
127 | acc = [torch.stack([cls_acc, match_acc], 0).unsqueeze(0)]
128 |
129 | cls_acc, match_acc = torch.cat(acc, 0).mean(0)
130 |
131 | test_accuracy[0].update(cls_acc.item(), 1)
132 | test_accuracy[1].update(match_acc.item(), 1)
133 |
134 |
135 | test_acc = [i.avg for i in test_accuracy]
136 | print("attribute_match_acc:%.2f"% round(test_acc[1]*100, 2))
137 | print("class_token_acc:%.2f" % round(test_acc[0]*100, 2))
138 |
139 |
140 | if __name__ == '__main__':
141 | main()
142 |
143 |
144 |
--------------------------------------------------------------------------------
/scripts/train_1st_stage.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import random
4 | import json
5 | import os
6 | import sys
7 | sys.path.append('../')
8 |
9 | from torch import nn
10 | import torch.optim as optim
11 | import torch.utils.data as data
12 | import torch.backends.cudnn as cudnn
13 | import utils.augmentation as A
14 | import os.path as osp
15 |
16 | from torch.utils.data import DataLoader
17 | from tensorboardX import SummaryWriter
18 | from torchvision import transforms
19 |
20 | from models.TQN import *
21 | from utils.utils import make_dirs
22 | from utils.plot_utils import *
23 |
24 | from engine.engine import train_one_epoch, eval_one_epoch
25 | from data.dataloader import TQN_dataloader,collate
26 |
27 | def worker_init_fn(worker_id):
28 | np.random.seed(np.random.get_state()[1][0] + worker_id)
29 |
30 | def main():
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument('--name', default='initial', type=str)
33 |
34 | ## data setting
35 | parser.add_argument('--dataset', default='', type=str)
36 | parser.add_argument('--img_dim', default=224, type=int)
37 | parser.add_argument('--clip_len', default=8, type=int, help='Number of frames sampled for each clip')
38 | parser.add_argument('--downsample', default=2, type=int, help='Frame downsampling rate')
39 | parser.add_argument('--batch_size', default=24, type=int)
40 | parser.add_argument('--root', default='', type=str)
41 | parser.add_argument('--dataset_config', default='', type=str)
42 | parser.add_argument('--pretrained', default='k400', type=str)
43 | parser.add_argument('--seed', default=0, type=int)
44 |
45 | # device params
46 | parser.add_argument("--gpus", dest="gpu", default="0", type=str)
47 | parser.add_argument('--num_workers', default=8, type=int)
48 |
49 | ## model setting
50 | parser.add_argument("--model",default='s3d')
51 | parser.add_argument('--resume', default='', type=str)
52 | parser.add_argument('--dropout', default=0.8, type=float)
53 |
54 | parser.add_argument('--N', default=4, type=int,help='Number of layers in the temporal decoder')
55 | parser.add_argument('--H', default=4, type=int,help='Number of heads in the temporal decoder')
56 | parser.add_argument('--d_model', default=1024, type=int)
57 | parser.add_argument('--num_queries', default=0, type=int)
58 | parser.add_argument('--pretrained_weights_path', default='', type=str)
59 |
60 | ## optim setting
61 | parser.add_argument('--lr', default=0.001, type=float)
62 | parser.add_argument('--momentum', default=0.9, type=float)
63 | parser.add_argument('--weight_decay', default=1e-5, type=float)
64 | parser.add_argument('--optim', default='adam', type=str, help='sgd, adam, adadelta')
65 | parser.add_argument('--max_norm', default=400, type=int, help='Norm cutoff to prevent explosion of gradients')
66 | parser.add_argument('--max_epoches', default=100, type=int)
67 |
68 | parser.add_argument('--lr_steps', default=[10000, 200000], type=float, nargs="+",
69 | metavar='LRSteps', help='epochs to decay learning rate by 10')
70 | parser.add_argument('--best_acc', default=0, type=float)
71 |
72 | ## frequency setting
73 | parser.add_argument('--print_iter', default=5, type=int)
74 | parser.add_argument('--eval_epoch', default=1, type=int)
75 | parser.add_argument('--save_epoch', default=5, type=int)
76 | parser.add_argument('--save_folder', default='/users/czhang/data/FineGym/exps/github', type=str)
77 | parser.add_argument('--tbx_folder', default='/users/czhang/data/FineGym/tbx/github', type=str)
78 | parser.add_argument('--max_iter', default=20000000, type=int)
79 |
80 |
81 | args = parser.parse_args()
82 |     if args.dataset_config != '':
83 | d = vars(args)
84 | with open(args.dataset_config, "r") as f:
85 | cfg = json.load(f)
86 | d.update(cfg)
87 |
88 | make_dirs(args)
89 | torch.manual_seed(args.seed)
90 | np.random.seed(args.seed)
91 | random.seed(args.seed)
92 |
93 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
94 | device = torch.device("cuda")
95 | torch.set_default_tensor_type('torch.FloatTensor')
96 |
97 | ## Set Up Model
98 |
99 | net = TQN(args)
100 | net = torch.nn.DataParallel(net).to(device)
101 | num_param = sum(p.numel() for p in net.parameters())
102 |
103 | ## Load Model Weights
104 |
105 | if args.resume != '':
106 | # Resume from a checkpoint
107 | resume_file = osp.join(args.save_folder,str(args.resume)+'.pth')
108 | checkpoint = torch.load(resume_file)
109 | state_dict = checkpoint['model_state_dict']
110 | net.load_state_dict(state_dict,strict=True)
111 | args.best_acc = checkpoint['best_acc']
112 | resume_epoch = checkpoint['epoch']
113 |
114 | else:
115 | # load pretrained weights on K400
116 | state_dict = torch.load(args.pretrained_weights_path)
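117 |         # the model is wrapped in DataParallel above, so prefix the pretrained keys with 'module.'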
117 | new_dict = {}
118 | for k,v in state_dict.items():
119 | k = 'module.'+k
120 | new_dict[k] = v
121 | net.load_state_dict(new_dict,strict=False)
122 | resume_epoch = -1
123 |
124 |
125 | ## Set Up Dataloader
126 |
127 | transform = transforms.Compose([
128 | A.RandomSizedCrop(size=args.img_dim, consistent=True, clip_len=args.clip_len, h_ratio=0.7,p=0.8),
129 | A.RandomHorizontalFlip(consistent=True, clip_len=args.clip_len),
130 | A.ColorJitter(brightness=0.4, contrast=0.7, saturation=0.7, hue=0.25,
131 | p=1.0, consistent=False, clip_len=args.clip_len),
132 | A.ToTensor(),
133 | A.Normalize(dataset=args.dataset)])
134 |
135 | transform_test = transforms.Compose([
136 | A.CenterCrop(size=args.img_dim),
137 | A.ToTensor(),
138 | A.Normalize(dataset=args.dataset)])
139 |
140 | trainset = TQN_dataloader(args,
141 | transform=transform, mode='train',
142 | )
143 | testset = TQN_dataloader(args,
144 | transform=transform_test,mode='val',
145 | )
146 |
147 |
148 | train_loader = data.DataLoader(
149 | trainset, args.batch_size,num_workers=args.num_workers,
150 | pin_memory=True, worker_init_fn=worker_init_fn,shuffle =True,
151 | drop_last=True,collate_fn = collate)
152 | test_loader = data.DataLoader(
153 | testset, args.batch_size,num_workers=args.num_workers,
154 | pin_memory=True, worker_init_fn=worker_init_fn,
155 | collate_fn = collate,drop_last=True)
156 |
157 |
158 | ## Set Up Optimizer
159 |
160 | parameters = net.parameters()
161 | params = []
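162 |     # per-group learning rates: the temporal decoder / attention layers use a 10x smaller
163 |     # learning rate than the rest of the network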
162 | for name, param in net.named_parameters():
163 | if 'attention' in name or 'decoder' in name :
164 | params.append({'params': param, 'lr':args.lr/10})
165 | else:
166 | params.append({'params': param, 'lr':args.lr})
167 |
168 | optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
169 | if args.resume != '':
170 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
171 |
172 | ## Set Up Tensorboard
173 |
174 | writer_val = SummaryWriter(logdir=osp.join(args.tbx_dir,'val'))
175 | writer_train = SummaryWriter(logdir=osp.join(args.tbx_dir, 'train'))
176 |
177 | args.val_plotter = PlotterThread(writer_val)
178 | args.train_plotter = PlotterThread(writer_train)
179 |
180 |
181 | ## Start Training
182 | cudnn.benchmark = True
183 |
184 | for epoch in range(args.max_epoches):
185 | if epoch <= resume_epoch:
186 | continue
187 | train_one_epoch(args,epoch,net,optimizer,trainset,train_loader)
188 |
189 | if epoch % args.eval_epoch == 0:
190 | eval_one_epoch(args,epoch,net,testset,test_loader)
191 |
192 |
193 |
194 | if __name__ == "__main__":
195 | main()
196 |
197 |
--------------------------------------------------------------------------------
/scripts/train_2nd_stage.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import random
4 | import json
5 | import os
6 | import sys
7 | sys.path.append('../')
8 |
9 | from torch import nn
10 | import torch.optim as optim
11 | import torch.utils.data as data
12 | import torch.backends.cudnn as cudnn
13 | import utils.augmentation as A
14 | import os.path as osp
15 |
16 | from torch.utils.data import DataLoader
17 | from tensorboardX import SummaryWriter
18 | from torchvision import transforms
19 |
20 | from models.TQN import *
21 | from utils.utils import make_dirs
22 | from utils.plot_utils import *
23 |
24 | from engine.engine import train_one_epoch, eval_one_epoch
25 | from data.dataloader import TQN_dataloader,SUFB_collate
26 |
27 |
28 |
29 | def worker_init_fn(worker_id):
30 | np.random.seed(np.random.get_state()[1][0] + worker_id)
31 |
32 |
33 | def main():
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument('--name', default='initial', type=str)
36 |
37 | ## data setting
38 | parser.add_argument('--dataset', default='', type=str)
39 | parser.add_argument('--img_dim', default=224, type=int)
40 | parser.add_argument('--clip_len', default=8, type=int, help='number of frames in each video block')
41 | parser.add_argument('--downsample', default=2, type=int, help='frame down sampling rate')
42 | parser.add_argument('--batch_size', default=24, type=int)
43 | parser.add_argument('--root', default='', type=str)
44 | parser.add_argument('--dataset_config', default='', type=str)
45 | parser.add_argument('--feature_file', default='', type=str)
46 | parser.add_argument('--seed', default=0, type=int)
47 |
48 | # device params
49 | parser.add_argument("--gpus", dest="gpu", default="0", type=str)
50 | parser.add_argument('--num_workers', default=16, type=int)
51 |
52 | ## model setting
53 | parser.add_argument("--model",default='s3d',type=str,help='i3d,s3d')
54 | parser.add_argument('--resume', default=-1, type=int)
55 | parser.add_argument('--dropout', default=0.8, type=float)
56 |
57 | parser.add_argument('--N', default=4, type=int,help='Number of layers in the temporal decoder')
58 | parser.add_argument('--H', default=4, type=int,help='Number of heads in the temporal decoder')
59 | parser.add_argument('--K', default=2, type=int,help='Number of clips updated per batch')
60 |
61 | parser.add_argument('--d_model', default=1024, type=int)
62 | parser.add_argument('--pretrained_weights_path', default='', type=str)
63 |
64 | ## optim setting
65 | parser.add_argument('--lr', default=0.001, type=float)
66 | parser.add_argument('--momentum', default=0.9, type=float)
67 | parser.add_argument('--weight_decay', default=1e-5, type=float)
68 | parser.add_argument('--optim', default='adam', type=str, help='sgd, adam, adadelta')
69 | parser.add_argument('--max_norm', default=400, type=int, help='Norm cutoff to prevent explosion of gradients')
70 | parser.add_argument('--max_epoches', default=1000000, type=int)
71 | parser.add_argument('--best_acc', default=0, type=float)
72 |
73 |
74 | ## frequency setting
75 | parser.add_argument('--print_iter', default=5, type=int)
76 | parser.add_argument('--eval_epoch', default=1, type=int)
77 | parser.add_argument('--save_epoch', default=5, type=int)
78 |
79 | parser.add_argument('--max_iter', default=20000000, type=int)
80 | parser.add_argument('--lr_steps', default=[10, 20], type=float, nargs="+",
81 | metavar='LRSteps', help='epochs to decay learning rate by 10')
82 |
83 |
84 | args = parser.parse_args()
85 |     if args.dataset_config != '':
86 | d = vars(args)
87 | with open(args.dataset_config, "r") as f:
88 | cfg = json.load(f)
89 | d.update(cfg)
90 |
91 | make_dirs(args)
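92 |     # vid2id (written by construct_SUFB.py) maps each video name to its integer slot in the feature bank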
92 | vid2id = cp.load(open(osp.join(args.root,args.feature_file).replace('features','vid2id'),'rb'))
93 |
94 | id2vid = {}
95 | for vid in vid2id.keys():
96 | id2vid[vid2id[vid]]=vid
97 |
98 | torch.manual_seed(args.seed)
99 | np.random.seed(args.seed)
100 | random.seed(args.seed)
101 |
102 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
103 | device = torch.device("cuda")
104 | torch.set_default_tensor_type('torch.FloatTensor')
105 |
106 | ## Set Up Model
107 |
108 | net = TQN(args,SUFB=True).cuda()
109 | net = torch.nn.parallel.DataParallel(net)
110 |
111 |
112 | ## Load Model Weights
113 |
114 | if args.resume != -1:
115 | resume_file = osp.join(args.save_folder,str(args.resume)+'.pth')
116 | checkpoint = torch.load(resume_file)
117 | state_dict = checkpoint['model_state_dict']
118 | net.load_state_dict(state_dict,strict=True)
119 | net.module.queue = checkpoint['queue']
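120 |         # the queue is the stochastically updated feature bank (SUFB): it stores precomputed clip
121 |         # features for every video and is refreshed online during training (--K clips updated per batch)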
120 | args.best_acc = checkpoint['best_acc']
121 | resume_epoch = checkpoint['epoch']
122 |
123 |     elif args.pretrained_weights_path != '':
124 |         checkpoint = torch.load(args.pretrained_weights_path)
125 |         net.load_state_dict(checkpoint['model_state_dict'], strict=True)
126 |         net.module.queue = cp.load(open(args.feature_file, 'rb'))
127 |         resume_epoch = -1
128 |         print('=== loaded pretrained weights:', args.pretrained_weights_path, '===')
129 |     else:
130 |         raise ValueError('Second-stage training needs either --resume or --pretrained_weights_path')
131 | ## Set Up Dataloader
132 |
133 | transform = transforms.Compose([
134 | A.RandomSizedCrop(size=args.img_dim, consistent=True, clip_len=args.clip_len, h_ratio=0.6,p=0.8),
135 | A.RandomHorizontalFlip(consistent=True, clip_len=args.clip_len),
136 | A.ColorJitter(brightness=0.4, contrast=0.7, saturation=0.7, hue=0.25,
137 | p=1.0, consistent=False, clip_len=args.clip_len),
138 | A.ToTensor(),
139 | A.Normalize(args.dataset)])
140 | transform_test = transforms.Compose([
141 | A.CenterCrop(size=args.img_dim),
142 | A.ToTensor(),
143 | A.Normalize(args.dataset)])
144 |
145 | trainset = TQN_dataloader(args,
146 | transform=transform, mode='train',
147 | SUFB = True)
148 | testset = TQN_dataloader(args,mode='val',
149 | transform=transform_test,
150 | SUFB = True)
151 |
152 |
153 | train_loader = data.DataLoader(trainset, args.batch_size,num_workers=args.num_workers,
154 | pin_memory=True, worker_init_fn=worker_init_fn, shuffle=True,
155 | drop_last=True,collate_fn = SUFB_collate, sampler= None)
156 | test_loader = data.DataLoader(testset, args.batch_size,num_workers=args.num_workers,
157 | pin_memory=True, worker_init_fn=worker_init_fn, shuffle=False,
158 | collate_fn = SUFB_collate,drop_last=True,sampler = None)
159 |
160 | ## Set Up Optimizer
161 |
162 | parameters = net.parameters()
163 | params = []
164 | print('=> [optimizer] finetune TFM with smaller lr')
165 | for name, param in net.named_parameters():
166 | if ('attention' in name or 'decoder' in name) and int(args.resume)<10:
167 | params.append({'params': param, 'lr':args.lr/10})
168 | else:
169 | params.append({'params': param, 'lr':args.lr})
170 |
171 | optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
172 |
173 |
174 | ## Set Up Tensorboard
175 |
176 | writer_val = SummaryWriter(logdir=osp.join(args.tbx_dir,'val'))
177 | writer_train = SummaryWriter(logdir=osp.join(args.tbx_dir, 'train'))
178 |
179 | args.val_plotter = PlotterThread(writer_val)
180 | args.train_plotter = PlotterThread(writer_train)
181 |
182 |
183 | ## Start Training
184 |
185 | cudnn.benchmark = True
186 | net.train()
187 |
188 | for epoch in range(args.max_epoches):
189 | if epoch <= resume_epoch:
190 | continue
191 | adjust_learning_rate(args,optimizer, epoch)
192 | train_one_epoch(args,epoch,net,optimizer,trainset,train_loader,SUFB=True)
193 |
194 | if epoch % args.eval_epoch == 0:
195 | eval_one_epoch(args,epoch,net,testset,test_loader,SUFB=True)
196 |
197 |
198 |
199 |
200 | def adjust_learning_rate(args, optimizer, epoch):
201 |     """Decay the initial learning rate by 10x at each milestone in args.lr_steps."""
202 |     epoch = epoch - 1
203 |     decay = 0.1 ** (sum(epoch >= np.array(args.lr_steps)))
204 |     lr = args.lr * decay
205 |     weight_decay = args.weight_decay
206 |     print('current epoch:', epoch, 'lr:', lr)
207 |     if epoch >= 10:  # only overwrite the per-group learning rates after the first 10 epochs
208 |         for param_group in optimizer.param_groups:
209 |             param_group['lr'] = lr
210 |             param_group['weight_decay'] = weight_decay
211 |
212 |
213 |
214 | if __name__ == '__main__':
215 | main()
216 |
217 |
--------------------------------------------------------------------------------
/utils/augmentation.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/Lextal/pspnet-pytorch
2 | import random
3 | import numbers
4 | import math
5 | import collections.abc
6 | import torchvision
7 | import statistics
8 | from scipy.special import softmax
9 | from torchvision import transforms
10 | import torchvision.transforms.functional as F
11 | from collections import Counter
12 | from itertools import groupby
13 |
14 | from PIL import ImageOps, Image
15 | import numpy as np
16 | import pickle as cp
17 | import os.path as osp
18 |
19 | class Padding:
20 | def __init__(self, pad):
21 | self.pad = pad
22 |
23 | def __call__(self, img):
24 | return ImageOps.expand(img, border=self.pad, fill=0)
25 |
26 |
27 | class Scale:
28 | def __init__(self, size, interpolation=Image.BICUBIC):
29 |         assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
30 | self.size = size
31 | self.interpolation = interpolation
32 |
33 | def __call__(self, imgmap):
34 | # assert len(imgmap) > 1 # list of images, last one is target (for segmentation tasks only)
35 | img1 = imgmap[0]
36 | if isinstance(self.size, int):
37 | w, h = img1.size
38 | if (w <= h and w == self.size) or (h <= w and h == self.size):
39 | return imgmap
40 | if w < h:
41 | ow = self.size
42 | oh = int(self.size * h / w)
43 | return [i.resize((ow, oh), self.interpolation) for i in imgmap]
44 | else:
45 | oh = self.size
46 | ow = int(self.size * w / h)
47 | return [i.resize((ow, oh), self.interpolation) for i in imgmap]
48 | else:
49 | return [i.resize(self.size, self.interpolation) for i in imgmap]
50 |
51 |
52 | class CenterCrop:
53 | def __init__(self, size, consistent=True):
54 | if isinstance(size, numbers.Number):
55 | self.size = (int(size), int(size))
56 | else:
57 | self.size = size
58 |
59 | def __call__(self, imgmap):
60 | img1 = imgmap[0]
61 | w, h = img1.size
62 | # imgmap = [i.resize((int(w*1.6),int(h*1.6))) for i in imgmap]
63 | # w, h = imgmap[0].size
64 | th, tw = self.size
65 | x1 = int(round((w - tw) / 2.))
66 | y1 = int(round((h - th) / 2.))
67 |
68 |
69 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap]
70 |
71 |
72 | class RandomSizedCrop:
73 | def __init__(self, size, interpolation=Image.BICUBIC, consistent=True, p=1.0, clip_len=0, h_ratio=0.7):
74 | self.size = size
75 | self.interpolation = interpolation
76 | self.consistent = consistent
77 | self.threshold = p
78 | self.clip_len = clip_len
79 | self.h_ratio = h_ratio
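80 |         # p: probability of applying the random crop (otherwise fall back to a center crop);
81 |         # h_ratio: lower bound on the crop height as a fraction of the original frame height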
80 |
81 | def __call__(self, imgmap):
82 | img1 = imgmap[0]
83 | if random.random() < self.threshold: # do RandomSizedCrop
84 | for attempt in range(10):
85 | ori_w,ori_h = img1.size
86 | aspect_ratio = random.uniform(3. / 4, 4. / 3)
87 | h = int(random.uniform(self.h_ratio, 1.0) * ori_h)
88 | w = int(h*aspect_ratio)
89 | if self.consistent:
90 | # if random.random() < 0.5:
91 | # w, h = h, w
92 | if w <= img1.size[0] and h <= img1.size[1]:
93 | mid_x = int(img1.size[0]//2)
94 | mid_h = int(img1.size[1]//2)
95 |
96 | # x1 = random.randint(int(mid_x-ori_w*0.15),int(mid_x+ori_w*0.15)) - w//2
97 | x1 = random.randint(0, img1.size[0] - w)
98 | y1 = random.randint(0, img1.size[1] - h)
99 |
100 | imgmap = [i.crop((x1, y1, x1 + w, y1 + h)) for i in imgmap]
101 | for i in imgmap: assert(i.size == (w, h))
102 |
103 | return [i.resize((self.size, self.size), self.interpolation) for i in imgmap]
104 | else:
105 | result = []
106 |
107 | if random.random() < 0.5:
108 | w, h = h, w
109 |
110 | for idx, i in enumerate(imgmap):
111 | if w <= img1.size[0] and h <= img1.size[1]:
112 | if idx % self.clip_len == 0:
113 | mid_x = int(img1.size[0]//2)
114 |
115 | x1 = random.randint(int(mid_x-ori_w*0.15),int(mid_x+ori_w*0.15)) - w//2
116 | y1 = random.randint(0, img1.size[1] - h)
117 |
118 | result.append(i.crop((x1, y1, x1 + w, y1 + h)))
119 | assert(result[-1].size == (w, h))
120 | else:
121 | result.append(i)
122 |
123 | assert len(result) == len(imgmap)
124 | return [i.resize((self.size, self.size), self.interpolation) for i in result]
125 |
126 | # Fallback
127 | scale = Scale(self.size, interpolation=self.interpolation)
128 | crop = CenterCrop(self.size)
129 | return crop(scale(imgmap))
130 | else: #don't do RandomSizedCrop, do CenterCrop
131 | crop = CenterCrop(self.size)
132 | return crop(imgmap)
133 |
134 |
135 | class RandomHorizontalFlip:
136 | def __init__(self, consistent=True, command=None, clip_len=0):
137 | self.consistent = consistent
138 | if command == 'left':
139 | self.threshold = 0
140 | elif command == 'right':
141 | self.threshold = 1
142 | else:
143 | self.threshold = 0.5
144 | self.clip_len = clip_len
145 | def __call__(self, imgmap):
146 | if self.consistent:
147 | if random.random() < self.threshold:
148 | return [i.transpose(Image.FLIP_LEFT_RIGHT) for i in imgmap]
149 | else:
150 | return imgmap
151 | else:
152 | result = []
153 | for idx, i in enumerate(imgmap):
154 | if idx % self.clip_len == 0: th = random.random()
155 | if th < self.threshold:
156 | result.append(i.transpose(Image.FLIP_LEFT_RIGHT))
157 | else:
158 | result.append(i)
159 | assert len(result) == len(imgmap)
160 | return result
161 |
162 |
163 |
164 |
165 | class ColorJitter(object):
166 | """Randomly change the brightness, contrast and saturation of an image.
167 | Args:
168 | brightness (float or tuple of float (min, max)): How much to jitter brightness.
169 | brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
170 | or the given [min, max]. Should be non negative numbers.
171 | contrast (float or tuple of float (min, max)): How much to jitter contrast.
172 | contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
173 | or the given [min, max]. Should be non negative numbers.
174 | saturation (float or tuple of float (min, max)): How much to jitter saturation.
175 | saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
176 | or the given [min, max]. Should be non negative numbers.
177 | hue (float or tuple of float (min, max)): How much to jitter hue.
178 | hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
179 | Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
180 | """
181 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, consistent=False, p=1.0, clip_len=0):
182 | self.brightness = self._check_input(brightness, 'brightness')
183 | self.contrast = self._check_input(contrast, 'contrast')
184 | self.saturation = self._check_input(saturation, 'saturation')
185 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
186 | clip_first_on_zero=False)
187 | self.consistent = consistent
188 | self.threshold = p
189 | self.clip_len = clip_len
190 |
191 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
192 | if isinstance(value, numbers.Number):
193 | if value < 0:
194 | raise ValueError("If {} is a single number, it must be non negative.".format(name))
195 | value = [center - value, center + value]
196 | if clip_first_on_zero:
197 | value[0] = max(value[0], 0)
198 | elif isinstance(value, (tuple, list)) and len(value) == 2:
199 | if not bound[0] <= value[0] <= value[1] <= bound[1]:
200 | raise ValueError("{} values should be between {}".format(name, bound))
201 | else:
202 |             raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
203 |
204 | # if value is 0 or (1., 1.) for brightness/contrast/saturation
205 | # or (0., 0.) for hue, do nothing
206 | if value[0] == value[1] == center:
207 | value = None
208 | return value
209 |
210 | @staticmethod
211 | def get_params(brightness, contrast, saturation, hue):
212 | """Get a randomized transform to be applied on image.
213 | Arguments are same as that of __init__.
214 | Returns:
215 | Transform which randomly adjusts brightness, contrast and
216 | saturation in a random order.
217 | """
218 | transforms = []
219 |
220 | if brightness is not None:
221 | brightness_factor = random.uniform(brightness[0], brightness[1])
222 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
223 |
224 | if contrast is not None:
225 | contrast_factor = random.uniform(contrast[0], contrast[1])
226 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
227 |
228 | if saturation is not None:
229 | saturation_factor = random.uniform(saturation[0], saturation[1])
230 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
231 |
232 | if hue is not None:
233 | hue_factor = random.uniform(hue[0], hue[1])
234 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_hue(img, hue_factor)))
235 |
236 | random.shuffle(transforms)
237 | transform = torchvision.transforms.Compose(transforms)
238 |
239 |
240 | return transform
241 |
242 | def __call__(self, imgmap):
243 | if random.random() < self.threshold: # do ColorJitter
244 | if self.consistent:
245 | transform = self.get_params(self.brightness, self.contrast,
246 | self.saturation, self.hue)
247 | return [transform(i) for i in imgmap]
248 | else:
249 | if self.clip_len == 0:
250 | return [self.get_params(self.brightness, self.contrast, self.saturation, self.hue)(img) for img in imgmap]
251 | else:
252 | result = []
253 | for idx, img in enumerate(imgmap):
254 | if idx % self.clip_len == 0:
255 | transform = self.get_params(self.brightness, self.contrast,
256 | self.saturation, self.hue)
257 | result.append(transform(img))
258 | return result
259 |
260 | else: # don't do ColorJitter, do nothing
261 | return imgmap
262 |
263 | def __repr__(self):
264 | format_string = self.__class__.__name__ + '('
265 | format_string += 'brightness={0}'.format(self.brightness)
266 | format_string += ', contrast={0}'.format(self.contrast)
267 | format_string += ', saturation={0}'.format(self.saturation)
268 | format_string += ', hue={0})'.format(self.hue)
269 | return format_string
270 |
271 |
272 |
273 | class ToTensor:
274 | def __call__(self, imgmap):
275 | totensor = transforms.ToTensor()
276 | return [totensor(i) for i in imgmap]
277 |
278 | class ToPIL:
279 | def __call__(self, imgmap):
280 | topil = transforms.ToPILImage()
281 | return [topil(i) for i in imgmap]
282 |
283 | class Normalize:
284 | def __init__(self, dataset=None,mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
285 |
286 |         if dataset is not None and 'diving' in dataset:  # Diving48 uses its own channel statistics
287 | self.mean = [0.3381, 0.5108, 0.5785]
288 | self.std = [0.2206, 0.2309, 0.2615]
289 | else:
290 | self.mean = mean
291 | self.std = std
292 |
293 | def __call__(self, imgmap):
294 | normalize = transforms.Normalize(mean=self.mean, std=self.std)
295 | return [normalize(i) for i in imgmap]
296 |
297 |
298 | def pil_loader(path):
299 | with open(path, 'rb') as f:
300 | with Image.open(f) as img:
301 | return img.convert('RGB')
--------------------------------------------------------------------------------
/utils/plot_utils.py:
--------------------------------------------------------------------------------
1 | import os, pickle
2 | import numpy as np
3 | from datetime import datetime
4 | from collections import deque
5 | from threading import Thread
6 | from queue import Queue
7 | class Logger(object):
8 |     '''Write timestamped messages to a plain-text log file.'''
9 | def __init__(self, path):
10 | self.birth_time = datetime.now()
11 | filepath = os.path.join(path, self.birth_time.strftime('%Y-%m-%d-%H:%M:%S')+'.log')
12 | self.filepath = filepath
13 | with open(filepath, 'a') as f:
14 | f.write(self.birth_time.strftime('%Y-%m-%d %H:%M:%S')+'\n')
15 |
16 | def log(self, string):
17 | with open(self.filepath, 'a') as f:
18 | time_stamp = datetime.now() - self.birth_time
19 |             f.write(str(time_stamp).split('.')[0]+'\t'+string+'\n')  # elapsed time since start (strfdelta was not defined here)
20 |
21 |
22 | class AverageMeter(object):
23 | """Computes and stores the average and current value"""
24 | def __init__(self, name='null', fmt=':.4f'):
25 | self.name = name
26 | self.fmt = fmt
27 | self.reset()
28 |
29 | def reset(self):
30 | self.val = 0
31 | self.avg = 0
32 | self.sum = 0
33 | self.count = 0
34 | self.local_history = deque([])
35 | self.local_avg = 0
36 | self.history = []
37 | self.dict = {} # save all data values here
38 | self.save_dict = {} # save mean and std here, for summary table
39 |
40 | def update(self, val, n=1, history=0, step=5):
41 | self.val = val
42 | self.sum += val * n
43 | self.count += n
44 | if n == 0: return
45 | self.avg = self.sum / self.count
46 | if history:
47 | self.history.append(val)
48 | if step > 0:
49 | self.local_history.append(val)
50 | if len(self.local_history) > step:
51 | self.local_history.popleft()
52 | self.local_avg = np.average(self.local_history)
53 |
54 |
55 | def dict_update(self, val, key):
56 | if key in self.dict.keys():
57 | self.dict[key].append(val)
58 | else:
59 | self.dict[key] = [val]
60 |
61 | def print_dict(self, title='IoU', save_data=False):
62 | """Print summary, clear self.dict and save mean+std in self.save_dict"""
63 | total = []
64 | for key in self.dict.keys():
65 | val = self.dict[key]
66 | avg_val = np.average(val)
67 | len_val = len(val)
68 | std_val = np.std(val)
69 |
70 | if key in self.save_dict.keys():
71 | self.save_dict[key].append([avg_val, std_val])
72 | else:
73 | self.save_dict[key] = [[avg_val, std_val]]
74 |
75 | print('Activity:%s, mean %s is %0.4f, std %s is %0.4f, length of data is %d' \
76 | % (key, title, avg_val, title, std_val, len_val))
77 |
78 | total.extend(val)
79 |
80 | self.dict = {}
81 | avg_total = np.average(total)
82 | len_total = len(total)
83 | std_total = np.std(total)
84 | print('\nOverall: mean %s is %0.4f, std %s is %0.4f, length of data is %d \n' \
85 | % (title, avg_total, title, std_total, len_total))
86 |
87 | if save_data:
88 | print('Save %s pickle file' % title)
89 | with open('img/%s.pickle' % title, 'wb') as f:
90 | pickle.dump(self.save_dict, f)
91 |
92 | def __len__(self):
93 | return self.count
94 |
95 | def __str__(self):
96 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
97 | return fmtstr.format(**self.__dict__)
98 |
99 |
100 |
101 |
102 | class PlotterThread():
103 | def __init__(self, writer):
104 | self.writer = writer
105 | self.task_queue = Queue(maxsize=0)
106 | worker = Thread(target=self.do_work, args=(self.task_queue,))
107 |         worker.daemon = True  # setDaemon() is deprecated in newer Python
108 | worker.start()
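109 |         # a daemon worker thread drains the queue and writes to TensorBoard asynchronously,
110 |         # so add_data() never blocks the training loop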
109 |
110 | def do_work(self, q):
111 | while True:
112 | content = q.get()
113 | if content[-1] == 'image':
114 | self.writer.add_image(*content[:-1])
115 | elif content[-1] == 'scalar':
116 | self.writer.add_scalar(*content[:-1])
117 | elif content[-1] == 'gif':
118 | self.writer.add_video(*content[:-1])
119 | else:
120 | raise ValueError
121 | q.task_done()
122 |
123 | def add_data(self, name, value, step, data_type='scalar'):
124 | self.task_queue.put([name, value, step, data_type])
125 |
126 | def __len__(self):
127 | return self.task_queue.qsize()
128 |
129 |
130 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import os.path as osp
4 |
5 | def tfm_mask(seg_per_video,temporal_mutliplier=1):
6 | """
7 |     Key-padding attention mask for padded sequences in the Transformer:
8 |     True marks padded positions that must not be attended to.
9 | """
10 | B = len(seg_per_video)
11 | L = max(seg_per_video) * temporal_mutliplier
12 | mask = torch.ones(B,L,dtype=torch.bool)
13 | for ind,l in enumerate(seg_per_video):
14 | mask[ind,:(l*temporal_mutliplier)] = False
15 |
16 | return mask
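17 | # Example: tfm_mask([3, 5]) gives a (2, 5) bool mask whose first row is [F, F, F, T, T],
18 | # i.e. the two padded positions of the shorter video are masked out.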
17 |
18 |
19 |
20 | def calc_topk_accuracy(output, target, topk=(1,)):
21 | """
22 | Modified from: https://gist.github.com/agermanidis/275b23ad7a10ee89adccf021536bb97e
23 | Given predicted and ground truth labels,
24 | calculate top-k accuracies.
25 | """
26 | maxk = max(topk)
27 | batch_size = target.size(0)
28 |
29 | _, pred = output.topk(maxk, 1, True, True)
30 | pred = pred.t()
31 | correct = pred.eq(target.view(1, -1).expand_as(pred))
32 |
33 | res = []
34 | for k in topk:
35 |         correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape: the transposed slice is non-contiguous
36 | res.append(correct_k.mul_(1 / batch_size))
37 | return res
38 |
39 |
40 |
41 | def multihead_acc(preds,clabel,target,vocab,\
42 | Q=4,return_probs=False):
43 | """
44 | Args:
45 | preds: Predicted logits
46 | clabel: Class labels,
47 | List, [batch_size]
48 | target: Ground Truth attribute labels
49 | List, [batch_size,num_queries]
50 | vocab: The mapping between class index and attributes.
51 | List, [num_classes,num_queries]
52 | Q: Number of queries, Int
53 |
54 | Output:
55 |         prob_acc: accuracy from matching the predicted attributes against the ground-truth
56 |         attributes of every class; the class with the highest similarity is the prediction.
57 | """
58 |
59 | # reshape the preds to (B,num_heads,num_classes)
60 | if len(preds.shape)==2:
61 | BQ,C = preds.shape
62 | B = BQ//Q
63 | preds = preds.view(-1,Q,C)
64 | elif len(preds.shape)==3:
65 | B,Q,C = preds.shape
66 |
67 | target = target.view(-1,Q)
68 | vocab_onehot = one_hot(vocab,C)
69 |
70 | cls_logits =torch.einsum('bhc,ahc->ba', preds, vocab_onehot.cuda())
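71 |     # cls_logits[b, a] sums, over the Q queries, the score the model assigns to class a's
72 |     # ground-truth attribute for that query; the best-matching class is taken as the prediction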
71 | cls_pred = torch.argmax(cls_logits,dim=-1)
72 | prob_acc = (cls_pred == clabel).sum()*1.0 /B
73 |
74 | if return_probs:
75 | return prob_acc,cls_logits
76 | else:
77 | return prob_acc
78 |
79 |
80 |
81 | def one_hot(indices,depth):
82 | """
83 | make one hot vectors from indices
84 | """
85 | y = indices.unsqueeze(-1).long()
86 | y_onehot = torch.zeros(*indices.shape,depth)
87 | if indices.is_cuda:
88 | y_onehot = y_onehot.cuda()
89 | return y_onehot.scatter(-1,y,1)
90 |
91 |
92 |
93 | def make_dirs(args):
94 |
95 |     if not osp.exists(args.save_folder):
96 |         os.mkdir(args.save_folder)
97 |     args.save_folder = osp.join(args.save_folder, args.name)
98 |     if not osp.exists(args.save_folder):
99 |         os.mkdir(args.save_folder)
100 | 
101 |     args.tbx_dir = osp.join(args.tbx_folder, args.name)
102 |     if not osp.exists(args.tbx_folder):
103 |         os.mkdir(args.tbx_folder)
104 | 
105 |     if not osp.exists(args.tbx_dir):
106 |         os.mkdir(args.tbx_dir)
107 | 
108 |     result_dir = osp.join(args.tbx_dir, 'results')
109 |     if not osp.exists(result_dir):
110 |         os.mkdir(result_dir)
111 |
112 |
113 |
114 | def batch_denorm(tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], channel=1):
115 | """
116 |     De-normalize the images for visualization.
117 | """
118 | shape = [1]*tensor.dim(); shape[channel] = 3
119 | dtype = tensor.dtype
120 | mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device).view(shape)
121 | std = torch.as_tensor(std, dtype=dtype, device=tensor.device).view(shape)
122 | output = tensor.mul(std).add(mean)
123 | return output
124 |
--------------------------------------------------------------------------------