├── README.md
├── annotations
│   ├── diving48_id2label.pkl
│   ├── diving48_label2id.pkl
│   ├── diving48_vocab.json
│   ├── gym288_anno.pkl
│   ├── gym288_id2label.pkl
│   ├── gym288_label2id.pkl
│   ├── gym99_anno.pkl
│   ├── gym99_id2label.pkl
│   └── gym99_label2id.pkl
├── build_venv.sh
├── configs
│   ├── Diving48_first_stage.yaml
│   ├── Diving48_second_stage.yaml
│   ├── Gym288_first_stage.yaml
│   ├── Gym288_second_stage.yaml
│   ├── Gym99_first_stage.yaml
│   └── Gym99_second_stage.yaml
├── data
│   ├── data.md
│   └── dataloader.py
├── engine
│   └── engine.py
├── models
│   ├── TQN.py
│   └── transformer.py
├── requirements.txt
├── scripts
│   ├── construct_SUFB.py
│   ├── test.py
│   ├── train_1st_stage.py
│   └── train_2nd_stage.py
└── utils
    ├── augmentation.py
    ├── plot_utils.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # [Temporal Query Networks for Fine-grained Video Understanding](https://www.robots.ox.ac.uk/~vgg/research/tqn/)
2 | 
3 | 📋 This repository contains the implementation of the CVPR 2021 paper [Temporal Query Networks for Fine-grained Video Understanding](https://arxiv.org/pdf/2104.09496.pdf)
4 | 
5 | # Abstract
6 | 
7 | 


10 | 
11 | Our objective in this work is fine-grained classification of actions in untrimmed videos, where the actions may be temporally extended or may span only a few frames of the video. We cast this into a query-response mechanism, where each query addresses a particular question, and has its own response label set.
12 | 
13 | We make the following four contributions: (i) We propose a new model — a Temporal Query Network — which enables the query-response functionality, and a structural understanding of fine-grained actions. It attends to relevant segments for each query with a temporal attention mechanism, and can be trained using only the labels for each query. (ii) We propose a new way — stochastic feature bank update — to train a network on videos of various lengths with the dense sampling required to respond to fine-grained queries. (iii) We compare the TQN to other architectures and text supervision methods, and analyze their pros and cons. Finally, (iv) we evaluate the method extensively on the FineGym and Diving48 benchmarks for fine-grained action classification and surpass the state-of-the-art using only RGB features.
14 | 
15 | # Getting Started
16 | 1. Clone this repository
17 | ```
18 | git clone https://github.com/Chuhanxx/Temporal_Query_Networks.git
19 | ```
20 | 2. Create a conda virtual env and install the requirements
21 | (This implementation requires CUDA and Python >= 3.7)
22 | ```
23 | cd Temporal_Query_Networks
24 | source build_venv.sh
25 | ```
26 | 
27 | # Prepare Data and Weight Initialization
28 | 
29 | Please refer to [data.md](https://github.com/Chuhanxx/Temporal_Query_Networks/blob/master/data/data.md) for data preparation.
30 | 
31 | 
32 | # Training
33 | You can train the model with the following steps, taking the Diving48 dataset as an example:
34 | 
35 | 1. First stage training:
36 | Set the paths in the `Diving48_first_stage.yaml` config file first, and then run:
37 | 
38 | ```
39 | cd scripts
40 | python train_1st_stage.py --name $EXP_NAME --dataset diving48 --dataset_config ../configs/Diving48_first_stage.yaml --gpus 0,1 --batch_size 16
41 | ```
42 | 2. Construct stochastically updated feature banks:
43 | 
44 | ```
45 | python construct_SUFB.py --dataset diving48 --dataset_config ../configs/Diving48_first_stage.yaml \
46 | --gpus 0 --resume_file $PATH_TO_BEST_FILE_FROM_1ST_STAGE --out_dir $DIR_FOR_SAVING_FEATURES
47 | ```
48 | 
49 | 3. Second stage training:
50 | Set the paths in the `Diving48_second_stage.yaml` config file first, and then run:
51 | 
52 | ```
53 | python train_2nd_stage.py --name $EXP_NAME --dataset diving48 \
54 | --dataset_config ../configs/Diving48_second_stage.yaml \
55 | --batch_size 16 --gpus 0,1
56 | ```
57 | 
58 | # Test
59 | 
60 | ```
61 | python test.py --name $EXP_NAME --dataset diving48 --batch_size 1 \
62 | --dataset_config ../configs/Diving48_second_stage.yaml
63 | ```
64 | 
65 | # Citation
66 | 
67 | If you use this code, please cite the following paper:
68 | 
69 | ```
70 | @inproceedings{zhangtqn,
71 | title={Temporal Query Networks for Fine-grained Video Understanding},
72 | author={Chuhan Zhang and Ankush Gupta and Andrew Zisserman},
73 | booktitle={Conference on Computer Vision and Pattern Recognition (CVPR)},
74 | year={2021}
75 | }
76 | ```
77 | 
78 | If you have any questions, please contact czhang@robots.ox.ac.uk.
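The same pipeline applies directly to the FineGym splits. As a sketch, the Gym99 equivalents of the Diving48 commands above would be the following (using the provided `Gym99_*.yaml` configs; `$EXP_NAME`, `$PATH_TO_BEST_FILE_FROM_1ST_STAGE` and `$DIR_FOR_SAVING_FEATURES` are placeholders as above, and `feature_file` in `Gym99_second_stage.yaml` should point at the features written by `construct_SUFB.py`):

```
cd scripts
python train_1st_stage.py --name $EXP_NAME --dataset gym99 --dataset_config ../configs/Gym99_first_stage.yaml --gpus 0,1 --batch_size 16
python construct_SUFB.py --dataset gym99 --dataset_config ../configs/Gym99_first_stage.yaml \
    --gpus 0 --resume_file $PATH_TO_BEST_FILE_FROM_1ST_STAGE --out_dir $DIR_FOR_SAVING_FEATURES
python train_2nd_stage.py --name $EXP_NAME --dataset gym99 \
    --dataset_config ../configs/Gym99_second_stage.yaml \
    --batch_size 16 --gpus 0,1
python test.py --name $EXP_NAME --dataset gym99 --batch_size 1 \
    --dataset_config ../configs/Gym99_second_stage.yaml
```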
79 | -------------------------------------------------------------------------------- /annotations/diving48_id2label.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chuhanxx/Temporal_Query_Networks/ac73e3375753463956f90037d59af5d3d50967e4/annotations/diving48_id2label.pkl -------------------------------------------------------------------------------- /annotations/diving48_label2id.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chuhanxx/Temporal_Query_Networks/ac73e3375753463956f90037d59af5d3d50967e4/annotations/diving48_label2id.pkl -------------------------------------------------------------------------------- /annotations/diving48_vocab.json: -------------------------------------------------------------------------------- 1 | [ 2 | [ 3 | "Back", 4 | "15som", 5 | "05Twis", 6 | "FREE" 7 | ], 8 | [ 9 | "Back", 10 | "15som", 11 | "15Twis", 12 | "FREE" 13 | ], 14 | [ 15 | "Back", 16 | "15som", 17 | "25Twis", 18 | "FREE" 19 | ], 20 | [ 21 | "Back", 22 | "15som", 23 | "NoTwis", 24 | "PIKE" 25 | ], 26 | [ 27 | "Back", 28 | "15som", 29 | "NoTwis", 30 | "TUCK" 31 | ], 32 | [ 33 | "Back", 34 | "25som", 35 | "15Twis", 36 | "PIKE" 37 | ], 38 | [ 39 | "Back", 40 | "25som", 41 | "25Twis", 42 | "PIKE" 43 | ], 44 | [ 45 | "Back", 46 | "25som", 47 | "NoTwis", 48 | "PIKE" 49 | ], 50 | [ 51 | "Back", 52 | "25som", 53 | "NoTwis", 54 | "TUCK" 55 | ], 56 | [ 57 | "Back", 58 | "2som", 59 | "15Twis", 60 | "FREE" 61 | ], 62 | [ 63 | "Back", 64 | "2som", 65 | "25Twis", 66 | "FREE" 67 | ], 68 | [ 69 | "Back", 70 | "35som", 71 | "NoTwis", 72 | "PIKE" 73 | ], 74 | [ 75 | "Back", 76 | "35som", 77 | "NoTwis", 78 | "TUCK" 79 | ], 80 | [ 81 | "Back", 82 | "3som", 83 | "NoTwis", 84 | "PIKE" 85 | ], 86 | [ 87 | "Back", 88 | "3som", 89 | "NoTwis", 90 | "TUCK" 91 | ], 92 | [ 93 | "Back", 94 | "Dive", 95 | "NoTwis", 96 | "PIKE" 97 | ], 98 | [ 99 | "Back", 100 | "Dive", 101 | "NoTwis", 102 | "TUCK" 103 | ], 104 | [ 105 | "Forward", 106 | "15som", 107 | "1Twis", 108 | "FREE" 109 | ], 110 | [ 111 | "Forward", 112 | "15som", 113 | "2Twis", 114 | "FREE" 115 | ], 116 | [ 117 | "Forward", 118 | "15som", 119 | "NoTwis", 120 | "PIKE" 121 | ], 122 | [ 123 | "Forward", 124 | "1som", 125 | "NoTwis", 126 | "PIKE" 127 | ], 128 | [ 129 | "Forward", 130 | "25som", 131 | "1Twis", 132 | "PIKE" 133 | ], 134 | [ 135 | "Forward", 136 | "25som", 137 | "2Twis", 138 | "PIKE" 139 | ], 140 | [ 141 | "Forward", 142 | "25som", 143 | "3Twis", 144 | "PIKE" 145 | ], 146 | [ 147 | "Forward", 148 | "25som", 149 | "NoTwis", 150 | "PIKE" 151 | ], 152 | [ 153 | "Forward", 154 | "25som", 155 | "NoTwis", 156 | "TUCK" 157 | ], 158 | [ 159 | "Forward", 160 | "35som", 161 | "NoTwis", 162 | "PIKE" 163 | ], 164 | [ 165 | "Forward", 166 | "35som", 167 | "NoTwis", 168 | "TUCK" 169 | ], 170 | [ 171 | "Forward", 172 | "45som", 173 | "NoTwis", 174 | "TUCK" 175 | ], 176 | [ 177 | "Forward", 178 | "Dive", 179 | "NoTwis", 180 | "PIKE" 181 | ], 182 | [ 183 | "Forward", 184 | "Dive", 185 | "NoTwis", 186 | "STR" 187 | ], 188 | [ 189 | "Inward", 190 | "15som", 191 | "NoTwis", 192 | "PIKE" 193 | ], 194 | [ 195 | "Inward", 196 | "15som", 197 | "NoTwis", 198 | "TUCK" 199 | ], 200 | [ 201 | "Inward", 202 | "25som", 203 | "NoTwis", 204 | "PIKE" 205 | ], 206 | [ 207 | "Inward", 208 | "25som", 209 | "NoTwis", 210 | "TUCK" 211 | ], 212 | [ 213 | "Inward", 214 | "35som", 215 | "NoTwis", 216 | "TUCK" 217 | ], 218 | [ 219 | "Inward", 220 | 
"Dive", 221 | "NoTwis", 222 | "PIKE" 223 | ], 224 | [ 225 | "Reverse", 226 | "15som", 227 | "05Twis", 228 | "FREE" 229 | ], 230 | [ 231 | "Reverse", 232 | "15som", 233 | "15Twis", 234 | "FREE" 235 | ], 236 | [ 237 | "Reverse", 238 | "15som", 239 | "25Twis", 240 | "FREE" 241 | ], 242 | [ 243 | "Reverse", 244 | "15som", 245 | "35Twis", 246 | "FREE" 247 | ], 248 | [ 249 | "Reverse", 250 | "15som", 251 | "NoTwis", 252 | "PIKE" 253 | ], 254 | [ 255 | "Reverse", 256 | "25som", 257 | "15Twis", 258 | "PIKE" 259 | ], 260 | [ 261 | "Reverse", 262 | "25som", 263 | "NoTwis", 264 | "PIKE" 265 | ], 266 | [ 267 | "Reverse", 268 | "25som", 269 | "NoTwis", 270 | "TUCK" 271 | ], 272 | [ 273 | "Reverse", 274 | "35som", 275 | "NoTwis", 276 | "TUCK" 277 | ], 278 | [ 279 | "Reverse", 280 | "Dive", 281 | "NoTwis", 282 | "PIKE" 283 | ], 284 | [ 285 | "Reverse", 286 | "Dive", 287 | "NoTwis", 288 | "TUCK" 289 | ] 290 | ] -------------------------------------------------------------------------------- /annotations/gym288_anno.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chuhanxx/Temporal_Query_Networks/ac73e3375753463956f90037d59af5d3d50967e4/annotations/gym288_anno.pkl -------------------------------------------------------------------------------- /annotations/gym288_id2label.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chuhanxx/Temporal_Query_Networks/ac73e3375753463956f90037d59af5d3d50967e4/annotations/gym288_id2label.pkl -------------------------------------------------------------------------------- /annotations/gym288_label2id.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chuhanxx/Temporal_Query_Networks/ac73e3375753463956f90037d59af5d3d50967e4/annotations/gym288_label2id.pkl -------------------------------------------------------------------------------- /annotations/gym99_anno.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chuhanxx/Temporal_Query_Networks/ac73e3375753463956f90037d59af5d3d50967e4/annotations/gym99_anno.pkl -------------------------------------------------------------------------------- /annotations/gym99_id2label.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chuhanxx/Temporal_Query_Networks/ac73e3375753463956f90037d59af5d3d50967e4/annotations/gym99_id2label.pkl -------------------------------------------------------------------------------- /annotations/gym99_label2id.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chuhanxx/Temporal_Query_Networks/ac73e3375753463956f90037d59af5d3d50967e4/annotations/gym99_label2id.pkl -------------------------------------------------------------------------------- /build_venv.sh: -------------------------------------------------------------------------------- 1 | export CONDA_ENV_NAME=tqn 2 | echo $CONDA_ENV_NAME 3 | 4 | conda create -n $CONDA_ENV_NAME python=3.7 5 | 6 | eval "$(conda shell.bash hook)" 7 | conda activate $CONDA_ENV_NAME 8 | 9 | pip install -r requirements.txt -------------------------------------------------------------------------------- /configs/Diving48_first_stage.yaml: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "diving48", 3 | "num_classes": 
48, 4 | "num_queries": 5, 5 | "attribute_set_size": 25, 6 | "max_length": 128, 7 | "downsample": 2, 8 | "root": "../diving48/", 9 | "save_folder": "../exps", 10 | "tbx_folder": "../tbx", 11 | "pretrained_weights_path": "../S3D_K400.pth.tar" 12 | } 13 | -------------------------------------------------------------------------------- /configs/Diving48_second_stage.yaml: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "diving48", 3 | "num_classes": 48, 4 | "num_queries": 5, 5 | "attribute_set_size": 25, 6 | "max_length": 100000000, 7 | "K": 10, 8 | "downsample": 2, 9 | "root": "../diving48/", 10 | "save_folder": "../exps", 11 | "tbx_folder": "../tbx", 12 | "feature_file": "../features/diving48_all_features.pkl", 13 | "pretrained_weights_path": "" 14 | } 15 | -------------------------------------------------------------------------------- /configs/Gym288_first_stage.yaml: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "gym288", 3 | "num_classes": 288, 4 | "num_queries": 13, 5 | "attribute_set_size": 98, 6 | "max_length": 48, 7 | "downsample": 1, 8 | "root": "../FineGym/", 9 | "save_folder": "../exps", 10 | "tbx_folder": "../tbx", 11 | "pretrained_weights_path": "../S3D_K400.pth.tar" 12 | } 13 | -------------------------------------------------------------------------------- /configs/Gym288_second_stage.yaml: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "gym288", 3 | "num_classes": 288, 4 | "num_queries": 13, 5 | "attribute_set_size": 98, 6 | "max_length": 100000000, 7 | "K": 6, 8 | "downsample": 1, 9 | "root": "../FineGym/", 10 | "save_folder": "../exps", 11 | "tbx_folder": "../tbx", 12 | "feature_file": "../features/gym288_all_features.pkl", 13 | "pretrained_weights_path": "" 14 | } 15 | -------------------------------------------------------------------------------- /configs/Gym99_first_stage.yaml: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "gym99", 3 | "num_classes": 99, 4 | "num_queries": 12, 5 | "attribute_set_size": 66, 6 | "max_length": 48, 7 | "downsample": 1, 8 | "root": "../FineGym/", 9 | "save_folder": "../exps", 10 | "tbx_folder": "../tbx", 11 | "pretrained_weights_path": "../S3D_K400.pth.tar" 12 | } 13 | -------------------------------------------------------------------------------- /configs/Gym99_second_stage.yaml: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "gym99", 3 | "num_classes": 99, 4 | "num_queries": 12, 5 | "attribute_set_size": 66, 6 | "max_length": 100000000, 7 | "K": 6, 8 | "downsample": 1, 9 | "root": "../FineGym/", 10 | "save_folder": "../exps", 11 | "tbx_folder": "../tbx", 12 | "feature_file": "../features/gym99_all_features.pkl", 13 | "pretrained_weights_path": "" 14 | } 15 | -------------------------------------------------------------------------------- /data/data.md: -------------------------------------------------------------------------------- 1 | # Diving48 2 | Please download RGB data and annotations (cleaned Version 2 updated on 10/30/2020) from [the Diving48 webpage](http://www.svcl.ucsd.edu/projects/resound/dataset.html). 3 | 4 | After downloading the data to a `root` path named diving48, make sure that the folder tree looks like: 5 | 6 | diving48 7 | ├── frames 8 | │ └── OFxuiqI5G44_00247 9 | │ └── image_000001.jpg 10 | │ └── image_000002.jpg 11 | │ └── ...... 
12 | │   └── OFxuiqI5G44_00248
13 | │       └── image_000001.jpg
14 | │       └── image_000002.jpg
15 | │       └── .....
16 | ├── Diving48_V2_train.json
17 | └── Diving48_V2_test.json
18 | 
19 | 
20 | Then set the `root` path in the `configs/*.yaml` files to the path to your Diving48 folder.
21 | 
22 | 
23 | # FineGym
24 | The [official FineGym dataset webpage](https://sdolivia.github.io/FineGym/) provides the URLs of the original YouTube videos for downloading.
25 | 
26 | The videos are about 1 hour long and need to be cropped into segments using the annotations provided. Due to copyright concerns, we are not able to provide the cropped video segments/extracted frames for direct downloading. Please follow the instructions on the official webpage to conduct the pre-processing.
27 | 
28 | After cropping the video segments and extracting the video frames, please create a `root` folder named `FineGym` and put the processed data into it, so that the folder tree looks like:
29 | 
30 | FineGym
31 | ├── frames
32 | │   └── Z2T9B4qExzk_E_007618_007687_A_0020_0021
33 | │       └── image_000001.jpg
34 | │       └── image_000002.jpg
35 | │       └── image_000003.jpg
36 | │       └── ....
37 | │   └── zNL3kn3UBmg_E_008111_008200_A_0046_0048
38 | │       └── image_000001.jpg
39 | │       └── image_000002.jpg
40 | │       └── image_000003.jpg
41 | │       └── ....
42 | └── scripts
43 |     └── gym99_train_element_v1.1.txt
44 |     └── gym99_val_element.txt
45 |     └── gym288_train_element_v1.1.txt
46 |     └── gym288_val_element.txt
47 | 
48 | 
49 | Then set the `root` path in the `configs/*.yaml` files to the path to your FineGym folder.
50 | 
51 | # Initialization of Weights
52 | 
53 | S3D weights pretrained on Kinetics400 can be downloaded [here](https://www.robots.ox.ac.uk/~vgg/research/tqn/K400-weights/S3D_K400.pth.tar) (~30.3MB).
54 | 
55 | Please set the `pretrained_weights_path` in the corresponding `configs/*_first_stage.yaml` files to the path where the weights are saved.
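Before launching the first-stage training, it can be useful to check that the `root` folder matches the layout above. Below is a minimal sanity-check sketch for Diving48 (a hypothetical helper, not part of this repository; adjust `root` to your own path):

```
import json
import os.path as osp
from glob import glob

root = '../diving48'  # must match "root" in configs/Diving48_*.yaml

# The cleaned V2 annotation files are expected directly under `root`.
for split in ['train', 'test']:
    with open(osp.join(root, 'Diving48_V2_%s.json' % split)) as f:
        print('%s: %d annotated segments' % (split, len(json.load(f))))

# Frames are expected under frames/<video_id>/ as image_000001.jpg, image_000002.jpg, ...
frame_dirs = sorted(glob(osp.join(root, 'frames', '*')))
print('%d frame folders found' % len(frame_dirs))
if frame_dirs:
    n = len(glob(osp.join(frame_dirs[0], 'image_*.jpg')))
    print('e.g. %s has %d extracted frames' % (osp.basename(frame_dirs[0]), n))
```

The FineGym layout can be checked in the same way by pointing `root` at the `FineGym` folder and reading the `scripts/*_element*.txt` lists instead of the JSON annotations.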
56 | -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | import os, sys, glob 2 | 3 | # import pickle 4 | import torch 5 | import random 6 | import math 7 | import time 8 | import json 9 | 10 | import torchvision 11 | import numpy as np 12 | import _pickle as cp 13 | import os.path as osp 14 | 15 | from PIL import Image 16 | from torch.nn.utils.rnn import pad_sequence 17 | 18 | 19 | class TQN_dataloader(object): 20 | def __init__(self, args, mode='train',transform=None,SUFB=False): 21 | 22 | 23 | self.root = args.root 24 | self.SUFB = SUFB 25 | self.mode = mode 26 | self.dataset = args.dataset 27 | 28 | self.transform = transform 29 | self.clip_len = args.clip_len 30 | self.downsample = args.downsample 31 | self.max_length = args.max_length 32 | 33 | self.label2id = cp.load(open(osp.join('../annotations/',args.dataset+'_label2id.pkl'),'rb')) 34 | self.id2label = cp.load(open(osp.join('../annotations/',args.dataset+'_id2label.pkl'),'rb')) 35 | 36 | if 'diving' in args.dataset: 37 | 38 | if mode=='train': 39 | self.gts = json.load(open(osp.join(self.root,'Diving48_V2_train.json'),'rb')) 40 | else: 41 | self.gts = json.load(open(osp.join(self.root,'Diving48_V2_test.json'),'rb')) 42 | 43 | self.vocab = json.load(open('../annotations/diving48_vocab.json','rb')) 44 | 45 | class_tokens = [] 46 | for a in self.vocab: 47 | gt_class = torch.tensor([self.label2id[i] for i in a]) 48 | class_tokens.append(gt_class) 49 | 50 | self.class_tokens = torch.stack(class_tokens,0) 51 | 52 | elif 'gym' in args.dataset: 53 | 54 | if mode =='train': 55 | self.gts = open(osp.join(self.root,'scripts',args.dataset+'_train_element_v1.1.txt'),'r').readlines() 56 | else: 57 | self.gts = open(osp.join(self.root,'scripts',args.dataset+'_val_element.txt'),'r').readlines() 58 | 59 | self.class_tokens = torch.stack([torch.tensor(i) for i in [*self.label2id.values()]],0) 60 | 61 | self.elements = self.preprocess(args.dataset) 62 | 63 | if self.SUFB: 64 | # Use the Stochastically Updated Feature Bank 65 | self.K = args.K 66 | self.vid2id = cp.load(open(args.feature_file.replace('features','vid2id'),'rb')) 67 | 68 | 69 | def __getitem__(self, index): 70 | 71 | gt = self.elements[index] 72 | 73 | if 'diving' in self.dataset: 74 | v_id = gt['vid_name'] 75 | clabel = gt['label'] 76 | frame_path = osp.join(self.root,'frames',v_id) 77 | total_frames = gt['end_frame'] - gt['start_frame'] 78 | tokens = torch.tensor([self.label2id[i] for \ 79 | i in self.vocab[clabel]]) 80 | 81 | elif 'gym' in self.dataset: 82 | v_id,clabel,cname = gt 83 | frame_path = osp.join(self.root,'frames',v_id) 84 | total_frames = len(os.listdir(frame_path)) 85 | tokens = torch.tensor(self.label2id[int(clabel)]) 86 | 87 | downsample = self.set_downsample_rate(total_frames) 88 | 89 | if total_frames <=2: 90 | # skip broken samples 91 | return None,None,None,None 92 | 93 | elif self.mode != 'test': 94 | frames,ptr = self.sample_frames(total_frames,downsample) 95 | if len(frames) ==0: 96 | print(v_id,downsample,frames) 97 | seq = self.load_images(frame_path,frames) 98 | 99 | elif self.mode =='test': 100 | frames_list = self.sample_frames_test(total_frames,downsample) 101 | 102 | seq_list =[] 103 | for frames in frames_list: 104 | seq = self.load_images(frame_path,frames) 105 | seq_list.append(seq) 106 | 107 | # align and stack seqs in the lists 108 | min_chunks = min([s.shape[0] for s in seq_list]) 109 | seq_list = 
[s[:min_chunks,:] for s in seq_list] 110 | seq = torch.stack(seq_list,dim=0) 111 | 112 | clabel = torch.tensor(int(clabel)) 113 | 114 | if self.SUFB: 115 | 116 | v_id = self.vid2id[v_id] 117 | assert seq.shape[0] == self.K 118 | return v_id, seq, clabel, ptr, tokens 119 | 120 | return v_id, seq, clabel ,tokens 121 | 122 | 123 | def load_images(self,frame_path,frames): 124 | # load images and apply transformation 125 | seq_names = [os.path.join(frame_path, 'image_%06d.jpg' % (i+1)) for i in frames] 126 | seqs = [pil_loader(i) for i in seq_names] 127 | seqs = self.transform(seqs) 128 | seq = torch.stack(seqs, 1) 129 | 130 | C,T,H,W = seq.shape # [NUM_CLIPS, C, CLIP_LEN, H, W] 131 | seq = seq.view(C,-1,self.clip_len,H,W).transpose(0,1) 132 | return seq 133 | 134 | 135 | 136 | def sample_frames(self,total_frames,downsample): 137 | first_f = np.random.choice(np.arange(downsample+1)) 138 | frames = np.arange(first_f,total_frames,downsample).tolist() 139 | 140 | if self.SUFB: 141 | # randomly choose a start point in the video to sample K clips 142 | n_clips = int(np.ceil(len(frames) / self.clip_len)) 143 | ptr = np.random.choice(max(1,n_clips - self.K + 1)) 144 | start = ptr * self.clip_len 145 | end = min([len(frames),(ptr + self.K) * self.clip_len]) 146 | frames = frames[start:end] 147 | 148 | if self.mode == 'train': 149 | for _ in range(int(0.05*len(frames))+1): 150 | frames.remove(random.choice(frames)) 151 | 152 | # pad the seq with the last frame to make the number of frames 153 | # sampled equal to K * clip_len, 154 | # where K in the number of clips computed online 155 | # in each iteration in the SUFB 156 | frames = self.pad_seq(frames) 157 | 158 | else: 159 | # temporal jittering 160 | if self.mode == 'train': 161 | for _ in range(int(0.01*total_frames) + 1): 162 | frames.remove(random.choice(frames)) 163 | 164 | # pad the seq with the last frame if the number of frames 165 | # sampled is not divisiable by clip_len 166 | frames = self.pad_seq(frames) 167 | ptr = None 168 | 169 | return frames,ptr 170 | 171 | def sample_frames_test(self,total_frames,downsample): 172 | # temporal jittering for testing 173 | frames = list(np.arange(0,total_frames,downsample)) 174 | frames0 = self.pad_seq(frames) 175 | frames1 = self.pad_seq(self.drop_frames(frames)) 176 | 177 | return [frames0,frames1] 178 | 179 | 180 | def pad_seq(self,frames): 181 | 182 | if not isinstance(frames,list): 183 | frames = frames.tolist() 184 | 185 | if self.SUFB: 186 | diff_T = self.clip_len * self.K - len(frames) 187 | else: 188 | hanging_T = len(frames) % self.clip_len 189 | diff_T = 0 190 | if hanging_T !=0: 191 | diff_T = self.clip_len - hanging_T 192 | 193 | for i in range(diff_T): 194 | frames.append(frames[-1]) 195 | return frames 196 | 197 | 198 | def preprocess(self,dataset): 199 | # Filter the videos by length for the 1st stage training 200 | elements= [] 201 | if 'diving' in dataset: 202 | for gt in self.gts: 203 | v_id, clabel, start_frame, end_frame = [*gt.values()] 204 | num_frames = start_frame - end_frame 205 | if num_frames < self.max_length: 206 | elements.append(gt) 207 | 208 | elif 'gym' in dataset: 209 | self.dict = cp.load(open(osp.join('../annotations',self.dataset+'_anno.pkl'),'rb')) 210 | for gt in self.gts: 211 | v_id,clabel = gt.split(' ') 212 | num_frames = int(self.dict[v_id]['num_frames']) 213 | cname = self.dict[v_id]['cname'] 214 | if num_frames < self.max_length: 215 | elements.append((v_id,clabel,cname)) 216 | 217 | return elements 218 | 219 | 220 | def drop_frames(self,frames): 221 | 
total_frames = len(frames) 222 | new = frames.copy() 223 | for _ in range(int(0.02*total_frames)+1): 224 | new.remove(random.choice(new)) 225 | return new 226 | 227 | def set_downsample_rate(self,total_frames): 228 | downsample = self.downsample 229 | while total_frames - downsample * self.clip_len < 1 and downsample > 1 : 230 | downsample -=1 231 | return downsample 232 | 233 | def __len__(self): 234 | return len(self.elements) 235 | 236 | 237 | def pil_loader(path): 238 | with open(path, 'rb') as f: 239 | with Image.open(f) as img: 240 | return img.convert('RGB') 241 | 242 | 243 | def SUFB_collate(batch): 244 | 245 | ids = [b[0] for b in batch if b[0] is not None] 246 | if ids ==[]: 247 | return None,None,None,None 248 | else: 249 | seqs = [b[1] for b in batch if b[1] is not None] 250 | labels = [b[2] for b in batch if b[2] is not None] 251 | tokens = [b[-1] for b in batch if b[-1] is not None] 252 | 253 | seqs = torch.stack(seqs,dim=0) 254 | labels=torch.stack(labels,dim=0) 255 | tokens = pad_sequence(tokens,batch_first =True) 256 | 257 | if len(batch)>3: 258 | # train or val mode 259 | ptrs = torch.tensor([b[3] for b in batch if b[3] is not None]) 260 | return torch.tensor(ids),seqs,labels,ptrs,tokens 261 | else: 262 | # test mode 263 | return ids,seqs,labels,tokens 264 | 265 | 266 | def collate(batch): 267 | ids = [b[0] for b in batch if b[0] is not None] 268 | seq = [b[1] for b in batch if b[1] is not None] 269 | label = [b[2] for b in batch if b[2] is not None] 270 | tokens = [b[-1] for b in batch if b[-1] is not None] 271 | 272 | if len(seq) ==0: 273 | return None,None,None,None,None 274 | else: 275 | Ks = [s.shape[0] for s in seq] 276 | seq = pad_sequence(seq,batch_first=True) 277 | label=torch.stack(label,dim=0) 278 | tokens = pad_sequence(tokens,batch_first =True) 279 | return ids,seq,label,Ks,tokens 280 | -------------------------------------------------------------------------------- /engine/engine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import random 4 | from tqdm import tqdm 5 | import numpy as np 6 | from utils.utils import * 7 | from utils.plot_utils import * 8 | from torch import nn 9 | 10 | def train_one_epoch(args,epoch,net,optimizer,trainset,train_loader,SUFB = False): 11 | np.random.seed(epoch) 12 | random.seed(epoch) 13 | net.train() 14 | 15 | data_time = AverageMeter() 16 | batch_time = AverageMeter() 17 | losses = [AverageMeter()] 18 | accuracy = [AverageMeter(),AverageMeter()] 19 | criterion = nn.CrossEntropyLoss(reduction='mean') 20 | 21 | t0 = time.time() 22 | 23 | for j, batch_samples in enumerate(train_loader): 24 | data_time.update(time.time() - t0) 25 | 26 | 27 | # cls_targets: action class labels 28 | # att_targets: attribute labels 29 | if not SUFB: 30 | v_ids, seq, cls_targets, n_clips_per_video, att_targets = batch_samples 31 | if seq is None: 32 | continue 33 | mask = tfm_mask(n_clips_per_video) 34 | preds,cls_preds = net((seq,mask)) 35 | else: 36 | # ptrs: clip pointers, where the online sampled clips start 37 | v_ids, seq, cls_targets, ptrs, att_targets = batch_samples 38 | preds,cls_preds = net((seq,v_ids,ptrs)) 39 | 40 | cls_targets = cls_targets.cuda() 41 | match_acc = multihead_acc(preds, cls_targets, att_targets, \ 42 | trainset.class_tokens, Q = args.num_queries) 43 | 44 | preds = preds.reshape(-1, args.attribute_set_size) 45 | att_targets = att_targets.view(-1).cuda() 46 | cls_acc = calc_topk_accuracy(cls_preds, cls_targets, (1,))[0] 47 | 48 | acc = 
[torch.stack([cls_acc, match_acc], 0).unsqueeze(0)] 49 | cls_acc, match_acc = torch.cat(acc, 0).mean(0) 50 | 51 | loss = criterion(preds, att_targets) 52 | loss += criterion(cls_preds, cls_targets) 53 | 54 | accuracy[0].update(match_acc.item(), args.batch_size) 55 | accuracy[1].update(cls_acc.item(), args.batch_size) 56 | losses[0].update(loss.item(), args.batch_size) 57 | 58 | optimizer.zero_grad() 59 | loss.backward() 60 | torch.nn.utils.clip_grad_norm_(net.parameters(), args.max_norm) 61 | optimizer.step() 62 | 63 | torch.cuda.empty_cache() 64 | batch_time.update(time.time() - t0) 65 | t0 = time.time() 66 | 67 | if j % (args.print_iter) == 0: 68 | t1 = time.time() 69 | print('Epoch: [{0}][{1}/{2}]\t' 70 | 'Loss {loss[0].val:.4f} Acc: {acc[0].val:.4f}\t' 71 | 'T-data:{dt.val:.2f} T-batch:{bt.val:.2f}\t'.format( 72 | epoch, j, len(train_loader), 73 | loss=losses, acc=accuracy, dt=data_time, bt=batch_time)) 74 | 75 | args.train_plotter.add_data('local/loss', losses[0].local_avg, epoch*len(train_loader)+j) 76 | args.train_plotter.add_data('local/match_acc', accuracy[0].local_avg,epoch*len(train_loader)+j) 77 | args.train_plotter.add_data('local/cls_acc', accuracy[1].local_avg, epoch*len(train_loader)+j) 78 | torch.cuda.empty_cache() 79 | 80 | if epoch % args.save_epoch == 0: 81 | print('Saving state, epoch: %d iter:%d'%(epoch, j)) 82 | save_ckpt(net,optimizer,args.best_acc,epoch,args.save_folder,str(epoch),SUFB) 83 | 84 | save_ckpt(net,optimizer,args.best_acc,epoch,args.save_folder,'latest',SUFB) 85 | 86 | train_acc = [i.avg for i in accuracy] 87 | args.train_plotter.add_data('global/loss', [i.avg for i in losses], epoch) 88 | args.train_plotter.add_data('global/match_acc', accuracy[0].local_avg, epoch) 89 | args.train_plotter.add_data('global/cls_acc', accuracy[1].local_avg, epoch) 90 | 91 | 92 | 93 | 94 | def eval_one_epoch(args,epoch,net,testset,test_loader,SUFB = False): 95 | net.eval() 96 | test_accuracy = [AverageMeter(),AverageMeter()] 97 | np.random.seed(epoch+1) 98 | random.seed(epoch+1) 99 | 100 | with torch.no_grad(): 101 | for k, batch_samples in tqdm(enumerate(test_loader),total=len(test_loader)): 102 | 103 | # cls_targets: action class labels 104 | # att_targets: attribute labels 105 | if not SUFB: 106 | v_ids,seq,cls_targets,n_clips_per_video,att_targets = batch_samples 107 | if seq is None: 108 | continue 109 | mask = tfm_mask(n_clips_per_video) 110 | preds,cls_preds = net((seq,mask)) 111 | else: 112 | 113 | # ptrs: clip pointers, where the online sampled clips start 114 | v_ids,seq,cls_targets,ptrs,att_targets = batch_samples 115 | preds,cls_preds = net((seq,v_ids,ptrs)) 116 | 117 | cls_targets = cls_targets.cuda() 118 | match_acc = multihead_acc(preds,cls_targets, att_targets, \ 119 | testset.class_tokens, Q=args.num_queries) 120 | 121 | preds = preds.reshape(-1,args.attribute_set_size) 122 | att_targets = att_targets.view(-1).cuda() 123 | cls_acc = calc_topk_accuracy(cls_preds, cls_targets, (1,))[0] 124 | 125 | acc = [torch.stack([cls_acc, match_acc], 0).unsqueeze(0)] 126 | cls_acc, match_acc = torch.cat(acc, 0).mean(0) 127 | 128 | test_accuracy[0].update(cls_acc.item(), args.batch_size) 129 | test_accuracy[1].update(match_acc.item(), args.batch_size) 130 | 131 | torch.cuda.empty_cache() 132 | 133 | test_acc = [i.avg for i in test_accuracy] 134 | args.val_plotter.add_data('global/cls_acc',test_acc[0], epoch) 135 | args.val_plotter.add_data('global/match_acc',test_acc[1], epoch) 136 | 137 | 138 | if test_acc[1] > args.best_acc: 139 | args.best_acc = test_acc[1] 140 | 
torch.save({'model_state_dict': net.state_dict(),\ 141 | 'best_acc':test_acc[1]},\ 142 | args.save_folder + '/' + 'best.pth') 143 | 144 | 145 | 146 | def save_ckpt(net,optimizer,best_acc,epoch,save_folder,name,SUFB): 147 | if SUFB: 148 | torch.save({'model_state_dict': net.state_dict(), 149 | 'optimizer_state_dict': optimizer.state_dict(), 150 | 'queue':net.module.queue, 151 | 'best_acc':best_acc, 152 | 'epoch':epoch}, 153 | save_folder + '/' + name+'.pth') 154 | 155 | else: 156 | torch.save({'model_state_dict': net.state_dict(), 157 | 'optimizer_state_dict': optimizer.state_dict(), 158 | 'best_acc':best_acc, 159 | 'epoch':epoch}, 160 | save_folder + '/' + name+'.pth') 161 | -------------------------------------------------------------------------------- /models/TQN.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/tensorflow/models/blob/master/research/slim/nets/s3dg.py 2 | import torch.nn as nn 3 | import torch 4 | import math 5 | import numpy as np 6 | from torch.nn.utils.rnn import pad_sequence 7 | import torch.nn.functional as F 8 | from .transformer import * 9 | from utils.utils import tfm_mask 10 | 11 | 12 | class BasicConv3d(nn.Module): 13 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0,LayerNorm=False): 14 | super(BasicConv3d, self).__init__() 15 | self.conv = nn.Conv3d(in_planes, out_planes, 16 | kernel_size=kernel_size, stride=stride, 17 | padding=padding, bias=False) 18 | 19 | self.LayerNorm = LayerNorm 20 | if not self.LayerNorm: 21 | self.bn = nn.BatchNorm3d(out_planes) 22 | self.bn.weight.data.fill_(1) 23 | self.bn.bias.data.zero_() 24 | self.relu = nn.ReLU(inplace=True) 25 | self.conv.weight.data.fill_(1) 26 | 27 | def forward(self, x): 28 | x = self.conv(x) 29 | if not self.LayerNorm: 30 | x = self.bn(x) 31 | x = self.relu(x) 32 | return x 33 | 34 | 35 | class STConv3d(nn.Module): 36 | def __init__(self,in_planes,out_planes,kernel_size,stride,padding=0): 37 | super(STConv3d, self).__init__() 38 | self.conv1 = nn.Conv3d(in_planes, out_planes, kernel_size=(1,kernel_size,kernel_size), 39 | stride=(1,stride,stride),padding=(0,padding,padding), bias=False) 40 | self.conv2 = nn.Conv3d(out_planes,out_planes,kernel_size=(kernel_size,1,1), 41 | stride=(stride,1,1),padding=(padding,0,0), bias=False) 42 | 43 | self.bn1=nn.BatchNorm3d(out_planes) 44 | self.bn2=nn.BatchNorm3d(out_planes) 45 | self.relu = nn.ReLU(inplace=True) 46 | 47 | # init 48 | self.conv1.weight.data.normal_(mean=0, std=0.01) 49 | self.conv2.weight.data.normal_(mean=0, std=0.01) 50 | 51 | self.bn1.weight.data.fill_(1) 52 | self.bn1.bias.data.zero_() 53 | self.bn2.weight.data.fill_(1) 54 | self.bn2.bias.data.zero_() 55 | 56 | def forward(self,x): 57 | x=self.conv1(x) 58 | x=self.bn1(x) 59 | x=self.relu(x) 60 | x=self.conv2(x) 61 | x=self.bn2(x) 62 | x=self.relu(x) 63 | return x 64 | 65 | 66 | class SelfGating(nn.Module): 67 | def __init__(self, input_dim): 68 | super(SelfGating, self).__init__() 69 | self.fc = nn.Linear(input_dim, input_dim) 70 | 71 | def forward(self, input_tensor): 72 | """Feature gating as used in S3D-G""" 73 | spatiotemporal_average = torch.mean(input_tensor, dim=[2, 3, 4]) 74 | weights = self.fc(spatiotemporal_average) 75 | weights = torch.sigmoid(weights) 76 | return weights[:, :, None, None, None] * input_tensor 77 | 78 | 79 | class SepInception(nn.Module): 80 | def __init__(self, in_planes, out_planes, gating=False,LayerNorm=False): 81 | super(SepInception, self).__init__() 82 | 83 | assert 
len(out_planes) == 6 84 | assert isinstance(out_planes, list) 85 | 86 | [num_out_0_0a, 87 | num_out_1_0a, num_out_1_0b, 88 | num_out_2_0a, num_out_2_0b, 89 | num_out_3_0b] = out_planes 90 | 91 | self.branch0 = nn.Sequential( 92 | BasicConv3d(in_planes, num_out_0_0a, kernel_size=1, stride=1), 93 | ) 94 | self.branch1 = nn.Sequential( 95 | BasicConv3d(in_planes, num_out_1_0a, kernel_size=1, stride=1), 96 | STConv3d(num_out_1_0a, num_out_1_0b, kernel_size=3, stride=1, padding=1), 97 | ) 98 | self.branch2 = nn.Sequential( 99 | BasicConv3d(in_planes, num_out_2_0a, kernel_size=1, stride=1), 100 | STConv3d(num_out_2_0a, num_out_2_0b, kernel_size=3, stride=1, padding=1), 101 | ) 102 | self.branch3 = nn.Sequential( 103 | nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1), 104 | BasicConv3d(in_planes, num_out_3_0b, kernel_size=1, stride=1,LayerNorm=LayerNorm), 105 | ) 106 | 107 | self.out_channels = sum([num_out_0_0a, num_out_1_0b, num_out_2_0b, num_out_3_0b]) 108 | 109 | self.gating = gating 110 | if gating: 111 | self.gating_b0 = SelfGating(num_out_0_0a) 112 | self.gating_b1 = SelfGating(num_out_1_0b) 113 | self.gating_b2 = SelfGating(num_out_2_0b) 114 | self.gating_b3 = SelfGating(num_out_3_0b) 115 | 116 | def forward(self, x): 117 | if isinstance(x,tuple): 118 | x = x[0] 119 | 120 | x0 = self.branch0(x) 121 | x1 = self.branch1(x) 122 | x2 = self.branch2(x) 123 | x3 = self.branch3(x) 124 | if self.gating: 125 | x0 = self.gating_b0(x0) 126 | x1 = self.gating_b1(x1) 127 | x2 = self.gating_b2(x2) 128 | x3 = self.gating_b3(x3) 129 | out = torch.cat((x0, x1, x2, x3), 1) 130 | return out 131 | 132 | 133 | 134 | 135 | class TQN(nn.Module): 136 | 137 | def __init__(self, args,first_channel=3,features_out =False,gating=False,SUFB=False,mode='train'): 138 | super(TQN, self).__init__() 139 | 140 | self.gating = gating 141 | self.features_out = features_out 142 | self.d_model = args.d_model 143 | self.SUFB = SUFB 144 | self.mode = mode 145 | 146 | if SUFB: 147 | self.K =args.K 148 | 149 | ################################### 150 | '''S3D''' 151 | ################################### 152 | 153 | self.Conv_1a = STConv3d(first_channel, 64, kernel_size=7, stride=2, padding=3) 154 | self.block1 = nn.Sequential(self.Conv_1a) # (64, 32, 112, 112) 155 | 156 | self.MaxPool_2a = nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)) 157 | self.Conv_2b = BasicConv3d(64, 64, kernel_size=1, stride=1) 158 | self.Conv_2c = STConv3d(64, 192, kernel_size=3, stride=1, padding=1) 159 | 160 | self.block2 = nn.Sequential( 161 | self.MaxPool_2a, # (64, 32, 56, 56) 162 | self.Conv_2b, # (64, 32, 56, 56) 163 | self.Conv_2c) # (192, 32, 56, 56) 164 | 165 | 166 | self.MaxPool_3a = nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)) 167 | self.Mixed_3b = SepInception(in_planes=192, out_planes=[64, 96, 128, 16, 32, 32], gating=gating) 168 | self.Mixed_3c = SepInception(in_planes=256, out_planes=[128, 128, 192, 32, 96, 64], gating=gating) 169 | 170 | self.block3 = nn.Sequential( 171 | self.MaxPool_3a, # (192, 32 , 28, 28) 172 | self.Mixed_3b, # (256, 32, 28, 28) 173 | self.Mixed_3c) # (480, 32, 28, 28) 174 | 175 | self.MaxPool_4a = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)) 176 | self.Mixed_4b = SepInception(in_planes=480, out_planes=[192, 96, 208, 16, 48, 64], gating=gating) 177 | self.Mixed_4c = SepInception(in_planes=512, out_planes=[160, 112, 224, 24, 64, 64], gating=gating) 178 | self.Mixed_4d = SepInception(in_planes=512, out_planes=[128, 128, 256, 24, 64, 64], 
gating=gating) 179 | self.Mixed_4e = SepInception(in_planes=512, out_planes=[112, 144, 288, 32, 64, 64], gating=gating) 180 | self.Mixed_4f = SepInception(in_planes=528, out_planes=[256, 160, 320, 32, 128, 128], gating=gating) 181 | 182 | self.block4 = nn.Sequential( 183 | self.MaxPool_4a, # (480, 16, 14, 14) 184 | self.Mixed_4b, # (512, 16, 14, 14) 185 | self.Mixed_4c, # (512, 16, 14, 14) 186 | self.Mixed_4d, # (512, 16, 14, 14) 187 | self.Mixed_4e, # (528, 16, 14, 14) 188 | self.Mixed_4f) # (832, 16, 14, 14) 189 | 190 | self.MaxPool_5a = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0)) 191 | self.Mixed_5b = SepInception(in_planes=832, out_planes=[256, 160, 320, 32, 128, 128], gating=gating) 192 | self.Mixed_5c = SepInception(in_planes=832, out_planes=[384, 192, 384, 48, 128, 128], gating=gating) 193 | 194 | self.block5 = nn.Sequential( 195 | self.MaxPool_5a, # (832, 8, 7, 7) 196 | self.Mixed_5b, # (832, 8, 7, 7) 197 | self.Mixed_5c) # (1024, 8, 7, 7) 198 | 199 | self.AvgPool_0a = nn.AvgPool3d(kernel_size=(1, 7, 7), stride=1) 200 | 201 | 202 | 203 | ################################### 204 | ''' Query Decoder''' 205 | ################################### 206 | 207 | if not self.features_out: 208 | 209 | # Decoder Layers 210 | self.H = args.H 211 | decoder_layer = TransformerDecoderLayer(self.d_model, args.H, 1024, 212 | 0.1, 'relu',normalize_before=True) 213 | decoder_norm = nn.LayerNorm(self.d_model) 214 | self.decoder = TransformerDecoder(decoder_layer, args.N, decoder_norm, 215 | return_intermediate=False) 216 | 217 | # Learnable Queries 218 | self.query_embed = nn.Embedding(args.num_queries,self.d_model) 219 | self.dropout_feas = nn.Dropout(args.dropout) 220 | 221 | # Attribute classifier 222 | self.classifier = nn.Linear(self.d_model,args.attribute_set_size) 223 | 224 | # Class classifier 225 | self.cls_classifier = nn.Linear(self.d_model,args.num_classes) 226 | 227 | 228 | self.apply(self._init_weights) 229 | 230 | 231 | 232 | 233 | def forward(self, input): 234 | 235 | ''' Reshape Input Sequences ''' 236 | if not self.SUFB: 237 | x, mask = input 238 | if len(x.shape) ==5: 239 | # the First stage training 240 | BK, C, T, H, W =x.shape 241 | seg_per_video = mask.shape[-1] - mask.sum(1) 242 | 243 | else: 244 | # Feature extraction mode for full video sequence 245 | B, K, C, T, H, W = x.shape 246 | x = x.reshape(B*K,C,T,H,W) 247 | seg_per_video = None 248 | 249 | else: 250 | # Training with a Stochastically Updated Feature Bank 251 | x, vids, ptrs = input 252 | B, K, C, T, H, W = x.shape 253 | x = x.reshape(B*K,C,T,H,W) 254 | seg_per_video = None 255 | 256 | 257 | ''' Visual Backbone ''' 258 | x = self.block1(x) 259 | x = self.block2(x) 260 | x = self.block3(x) 261 | x = self.block4(x) 262 | x = self.block5(x) 263 | 264 | features = self.AvgPool_0a(x).squeeze() 265 | 266 | if self.SUFB: 267 | features,Ts,mask = self.fill_SUFB(features,vids,ptrs) 268 | 269 | if self.features_out: 270 | return features 271 | 272 | else: 273 | ''' Query Decoder ''' 274 | if seg_per_video is not None: 275 | # first stage training 276 | features = self.reshape_features(features.squeeze(), 277 | seg_per_video) 278 | B = len(seg_per_video) 279 | K = int(BK // B) 280 | 281 | elif not self.SUFB: 282 | features = features.reshape(B,K,-1) 283 | 284 | if mask is not None: 285 | mask = mask.view(B,-1) 286 | 287 | features = features.transpose(0,1) 288 | query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, B, 1) 289 | features = self.decoder(query_embed, features, 290 | 
memory_key_padding_mask=mask, pos=None, query_pos=None) 291 | 292 | out = self.dropout_feas(features) # [T,B,C] 293 | x= self.classifier(out[:-1]).transpose(0,1) 294 | x_cls = self.cls_classifier(out[-1]) 295 | 296 | return x, x_cls 297 | 298 | 299 | def reshape_features(self,features,seg_per_video): 300 | reshaped_features = [] 301 | counter = 0 302 | for n_seg in seg_per_video: 303 | reshaped_features.append(features[counter:counter+n_seg]) 304 | counter += n_seg 305 | return pad_sequence(reshaped_features,batch_first=True) 306 | 307 | 308 | def fill_SUFB(self,features,vids,ptrs): 309 | fea_dim = features.shape[-1] 310 | 311 | if self.mode =='train': 312 | # Update newly computed features in the SUFB, 313 | # And read all the features from the SUFB 314 | full_features = [] 315 | features = features.view(-1,self.K,fea_dim) 316 | features_split = torch.split(features, 1, dim=0) 317 | 318 | for f, vid, ptr in zip(features_split, vids, ptrs): 319 | vid = vid.item() 320 | end = min([len(self.queue[vid]), ptr + self.K]) 321 | 322 | self.queue[vid][ptr:end] = f[0,:(end-ptr),:] 323 | full_features.append(self.queue[vid]) 324 | self.queue[vid] = self.queue[vid].detach() 325 | 326 | 327 | Ts = [f.shape[0] for f in full_features] 328 | mask = tfm_mask(Ts).cuda() 329 | features = pad_sequence(full_features,batch_first=True).cuda() 330 | 331 | 332 | elif self.mode == 'test': 333 | # Test mode, compute all features online 334 | features = features.view(B,-1,fea_dim).cuda() 335 | Ts = [features[i].shape[0] for i in range(B)] 336 | mask = tfm_mask(Ts).cuda() 337 | 338 | return features,Ts,mask 339 | 340 | 341 | @staticmethod 342 | def _init_weights(module): 343 | r"""Initialize weights like BERT - N(0.0, 0.02), bias = 0.""" 344 | 345 | if isinstance(module, nn.Linear): 346 | module.weight.data.normal_(mean=0.0, std=0.02) 347 | 348 | elif isinstance(module, nn.MultiheadAttention): 349 | module.in_proj_weight.data.normal_(mean=0.0, std=0.02) 350 | module.out_proj.weight.data.normal_(mean=0.0, std=0.02) 351 | 352 | elif isinstance(module, nn.Embedding): 353 | module.weight.data.normal_(mean=0.0, std=0.02) 354 | if module.padding_idx is not None: 355 | module.weight.data[module.padding_idx].zero_() 356 | 357 | 358 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code modified from DETR tranformer: 3 | https://github.com/facebookresearch/detr 4 | Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 5 | 6 | """ 7 | 8 | import copy 9 | from typing import Optional, List 10 | import pickle as cp 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch import nn, Tensor 15 | 16 | 17 | class TransformerDecoder(nn.Module): 18 | 19 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 20 | super().__init__() 21 | self.layers = _get_clones(decoder_layer, num_layers) 22 | self.num_layers = num_layers 23 | self.norm = norm 24 | self.return_intermediate = return_intermediate 25 | 26 | def forward(self, tgt, memory, 27 | tgt_mask: Optional[Tensor] = None, 28 | memory_mask: Optional[Tensor] = None, 29 | tgt_key_padding_mask: Optional[Tensor] = None, 30 | memory_key_padding_mask: Optional[Tensor] = None, 31 | pos: Optional[Tensor] = None, 32 | query_pos: Optional[Tensor] = None): 33 | output = tgt 34 | T,B,C = memory.shape 35 | intermediate = [] 36 | 37 | for n,layer in enumerate(self.layers): 38 | 39 | residual=True 40 | output,ws = layer(output, memory, tgt_mask=tgt_mask, 41 | memory_mask=memory_mask, 42 | tgt_key_padding_mask=tgt_key_padding_mask, 43 | memory_key_padding_mask=memory_key_padding_mask, 44 | pos=pos, query_pos=query_pos,residual=residual) 45 | 46 | if self.return_intermediate: 47 | intermediate.append(self.norm(output)) 48 | if self.norm is not None: 49 | output = self.norm(output) 50 | if self.return_intermediate: 51 | intermediate.pop() 52 | intermediate.append(output) 53 | 54 | if self.return_intermediate: 55 | return torch.stack(intermediate) 56 | return output 57 | 58 | 59 | 60 | class TransformerDecoderLayer(nn.Module): 61 | 62 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 63 | activation="relu", normalize_before=False): 64 | super().__init__() 65 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 66 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 67 | # Implementation of Feedforward model 68 | self.linear1 = nn.Linear(d_model, dim_feedforward) 69 | self.dropout = nn.Dropout(dropout) 70 | self.linear2 = nn.Linear(dim_feedforward, d_model) 71 | 72 | self.norm1 = nn.LayerNorm(d_model) 73 | self.norm2 = nn.LayerNorm(d_model) 74 | self.norm3 = nn.LayerNorm(d_model) 75 | self.dropout1 = nn.Dropout(dropout) 76 | self.dropout2 = nn.Dropout(dropout) 77 | self.dropout3 = nn.Dropout(dropout) 78 | 79 | self.activation = _get_activation_fn(activation) 80 | self.normalize_before = normalize_before 81 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 82 | return tensor if pos is None else tensor + pos 83 | 84 | def forward_post(self, tgt, memory, 85 | tgt_mask: Optional[Tensor] = None, 86 | memory_mask: Optional[Tensor] = None, 87 | tgt_key_padding_mask: Optional[Tensor] = None, 88 | memory_key_padding_mask: Optional[Tensor] = None, 89 | pos: Optional[Tensor] = None, 90 | query_pos: Optional[Tensor] = None, 91 | residual=True): 92 | q = k = self.with_pos_embed(tgt, query_pos) 93 | tgt2,ws = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 94 | key_padding_mask=tgt_key_padding_mask) 95 | tgt = self.norm1(tgt) 96 | tgt2,ws = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 97 | key=self.with_pos_embed(memory, pos), 98 | value=memory, attn_mask=memory_mask, 99 | key_padding_mask=memory_key_padding_mask) 100 | 101 | 102 | # attn_weights [B,NUM_Q,T] 103 | tgt = tgt + self.dropout2(tgt2) 104 | tgt = self.norm2(tgt) 105 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 106 | tgt = tgt + 
self.dropout3(tgt2) 107 | tgt = self.norm3(tgt) 108 | return tgt,ws 109 | 110 | def forward_pre(self, tgt, memory, 111 | tgt_mask: Optional[Tensor] = None, 112 | memory_mask: Optional[Tensor] = None, 113 | tgt_key_padding_mask: Optional[Tensor] = None, 114 | memory_key_padding_mask: Optional[Tensor] = None, 115 | pos: Optional[Tensor] = None, 116 | query_pos: Optional[Tensor] = None): 117 | tgt2 = self.norm1(tgt) 118 | q = k = self.with_pos_embed(tgt2, query_pos) 119 | tgt2,ws = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 120 | key_padding_mask=tgt_key_padding_mask) 121 | tgt = tgt + self.dropout1(tgt2) 122 | tgt2 = self.norm2(tgt) 123 | tgt2,attn_weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 124 | key=self.with_pos_embed(memory, pos), 125 | value=memory, attn_mask=memory_mask, 126 | key_padding_mask=memory_key_padding_mask) 127 | tgt = tgt + self.dropout2(tgt2) 128 | tgt2 = self.norm3(tgt) 129 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 130 | tgt = tgt + self.dropout3(tgt2) 131 | return tgt,ws 132 | 133 | def forward(self, tgt, memory, 134 | tgt_mask: Optional[Tensor] = None, 135 | memory_mask: Optional[Tensor] = None, 136 | tgt_key_padding_mask: Optional[Tensor] = None, 137 | memory_key_padding_mask: Optional[Tensor] = None, 138 | pos: Optional[Tensor] = None, 139 | query_pos: Optional[Tensor] = None, 140 | residual=True): 141 | if self.normalize_before: 142 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 143 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 144 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 145 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos,residual) 146 | 147 | 148 | def _get_clones(module, N): 149 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 150 | 151 | 152 | 153 | def _get_activation_fn(activation): 154 | """Return an activation function given a string""" 155 | if activation == "relu": 156 | return F.relu 157 | if activation == "gelu": 158 | return F.gelu 159 | if activation == "glu": 160 | return F.glu 161 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ipdb==0.13.7 2 | scipy==1.5.4 3 | six==1.16.0 4 | tensorboardX==2.2 5 | torch==1.8.1 6 | torchvision==0.9.1 7 | tqdm==4.60.0 8 | -------------------------------------------------------------------------------- /scripts/construct_SUFB.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | import time 5 | import torch 6 | import torchvision 7 | import sys 8 | sys.path.append('../') 9 | import numpy as np 10 | import _pickle as cp 11 | import os.path as osp 12 | import torch.nn as nn 13 | import utils.augmentation as A 14 | import torch.utils.data as data 15 | import torch.backends.cudnn as cudnn 16 | import json 17 | import glob 18 | 19 | from tqdm import tqdm 20 | from torchvision import transforms 21 | from models.TQN import TQN 22 | 23 | from data.dataloader import TQN_dataloader,SUFB_collate 24 | 25 | 26 | 27 | 28 | 29 | def worker_init_fn(worker_id): 30 | np.random.seed(np.random.get_state()[1][0] + worker_id) 31 | 32 | def main(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--name', default='initial', type=str) 35 | 36 | ## data setting 37 | 
parser.add_argument('--dataset', default='gym99', type=str) 38 | 39 | parser.add_argument('--img_dim', default=224, type=int) 40 | parser.add_argument('--clip_len', default=8, type=int, help='number of frames in each video block') 41 | parser.add_argument('--downsample', default=2, type=int, help='frame down sampling rate') 42 | parser.add_argument('--batch_size', default=1, type=int) 43 | parser.add_argument('--resume_file', default='', type=str) 44 | parser.add_argument('--d_model', default=1024, type=int) 45 | parser.add_argument('--dataset_config', default='', type=str) 46 | 47 | parser.add_argument('--all_frames', action='store_true') 48 | parser.add_argument('--seed', default=0, type=int) 49 | parser.add_argument('--out_dir', default='', type=str) 50 | 51 | # device params 52 | parser.add_argument("--gpus", dest="gpu", default="0", type=str) 53 | parser.add_argument('--num_workers', default=16, type=int) 54 | 55 | ## model setting 56 | parser.add_argument("--model",default='s3d',type=str,help='') 57 | parser.add_argument('--resume', default=-1, type=int) 58 | parser.add_argument('--dropout', default=0.2, type=float) 59 | 60 | ## frequency setting 61 | parser.add_argument('--eval_epoch', default=5, type=int) 62 | parser.add_argument('--max_iter', default=20000000, type=int) 63 | 64 | 65 | args = parser.parse_args() 66 | if args.dataset_config is not None: 67 | d = vars(args) 68 | with open(args.dataset_config, "r") as f: 69 | cfg = json.load(f) 70 | d.update(cfg) 71 | 72 | args.max_length = 1e6 73 | torch.manual_seed(args.seed) 74 | np.random.seed(args.seed) 75 | random.seed(args.seed) 76 | 77 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu 78 | device = torch.device("cuda") 79 | torch.set_default_tensor_type('torch.FloatTensor') 80 | 81 | ## Set Up Model 82 | 83 | num_classes =int(''.join([s for s in args.dataset if s.isdigit()])) 84 | net = TQN(args,features_out=True) 85 | net = torch.nn.DataParallel(net).to(device) 86 | 87 | ## Load Model Weights 88 | 89 | assert args.resume_file!= '' 90 | checkpoint = torch.load(args.resume_file) 91 | state_dict = checkpoint['model_state_dict'] 92 | net.load_state_dict(state_dict,strict=False) 93 | 94 | ## Set Up Dataloader 95 | 96 | transform = transforms.Compose([ 97 | A.RandomSizedCrop(size=args.img_dim, consistent=True, clip_len=args.clip_len, h_ratio=0.6,p=0.8), 98 | A.RandomHorizontalFlip(consistent=True, clip_len=args.clip_len), 99 | A.ColorJitter(brightness=0.4, contrast=0.7, saturation=0.7, hue=0.25, 100 | p=1.0, consistent=False, clip_len=args.clip_len), 101 | A.ToTensor(), 102 | A.Normalize(args.dataset)]) 103 | transform_test = transforms.Compose([ 104 | A.CenterCrop(size=args.img_dim), 105 | A.ToTensor(), 106 | A.Normalize(args.dataset)]) 107 | 108 | trainset = TQN_dataloader(args,transform=transform,mode='train') 109 | testset = TQN_dataloader(args,transform=transform_test,mode='val') 110 | 111 | for dataset in [trainset,testset]: 112 | data_loader = data.DataLoader(dataset, args.batch_size,num_workers=args.num_workers, 113 | collate_fn =SUFB_collate,pin_memory=True, worker_init_fn=worker_init_fn,drop_last=False) 114 | 115 | 116 | cudnn.benchmark = True 117 | net.eval() 118 | 119 | with torch.no_grad(): 120 | for k, test_samples in tqdm(enumerate(data_loader),total=len(data_loader)): 121 | 122 | v_id, seq, target, _ = test_samples 123 | if v_id is None: 124 | continue 125 | B, K, C, T, H, W =seq.shape # [batch_size, num_clips, num_channels, clip_len, H, W] 126 | out_pkl = osp.join(args.out_dir,v_id[0]+'.pkl') 127 | 128 | if not 
osp.exists(osp.join(args.out_dir)): 129 | os.mkdir(osp.join(args.out_dir)) 130 | 131 | # Clip super long videos to fit it in one/two gpus 132 | if seq.shape[-3] >600: 133 | seq = seq[:,:,int(0.2*K):-int(0.2*K):,:,:] 134 | 135 | # Forward 136 | feas = net((seq,None)) 137 | feas = feas.squeeze().view(B,-1,feas.shape[-1]) 138 | 139 | # Save individual feature files first 140 | with open(out_pkl, 'wb') as f: 141 | cp.dump(feas.cpu(),f) 142 | 143 | 144 | ## Write All the Feature Files into One File 145 | vid_to_id,features_dict = {}, {} 146 | pkls = glob.glob(osp.join(args.out_dir,'*.pkl')) 147 | 148 | for ind,pkl in enumerate(pkls): 149 | v_id = osp.basename(pkl).replace('.pkl','') 150 | vid_to_id[v_id] = ind 151 | features_dict[ind] = cp.load(open(pkl,'rb'))[0] 152 | 153 | with open(osp.join(args.out_dir,args.dataset+'_all_vid2id.pkl'), 'wb') as f: 154 | cp.dump(vid_to_id,f) 155 | 156 | with open(osp.join(args.out_dir,args.dataset+'_all_features.pkl'), 'wb') as f: 157 | cp.dump(features_dict,f) 158 | 159 | print('Saved featrues from ',len(features_dict),' video samples.') 160 | 161 | if __name__ == '__main__': 162 | main() 163 | 164 | -------------------------------------------------------------------------------- /scripts/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import random 4 | import json 5 | import os 6 | import sys 7 | sys.path.append('../') 8 | 9 | from tqdm import tqdm 10 | from torch import nn 11 | import torch.optim as optim 12 | import torch.utils.data as data 13 | import torch.backends.cudnn as cudnn 14 | import utils.augmentation as A 15 | import os.path as osp 16 | 17 | from torch.utils.data import DataLoader 18 | from torchvision import transforms 19 | 20 | from models.TQN import * 21 | from utils.utils import make_dirs,multihead_acc,calc_topk_accuracy 22 | from utils.plot_utils import * 23 | 24 | from data.dataloader import TQN_dataloader,SUFB_collate 25 | 26 | 27 | def worker_init_fn(worker_id): 28 | np.random.seed(np.random.get_state()[1][0] + worker_id) 29 | 30 | def main(): 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--name', default='initial', type=str) 33 | 34 | ## data setting 35 | parser.add_argument('--dataset', default='', type=str) 36 | parser.add_argument('--img_dim', default=224, type=int) 37 | parser.add_argument('--clip_len', default=8, type=int, help='number of frames in each video block') 38 | parser.add_argument('--downsample', default=2, type=int, help='frame down sampling rate') 39 | parser.add_argument('--batch_size', default=1, type=int) 40 | parser.add_argument('--root', default='', type=str) 41 | parser.add_argument('--dataset_config', default='', type=str) 42 | parser.add_argument('--feature_file', default='', type=str) 43 | parser.add_argument('--seed', default=0, type=int) 44 | 45 | # device params 46 | parser.add_argument("--gpus", dest="gpu", default="0", type=str) 47 | parser.add_argument('--num_workers', default=16, type=int) 48 | 49 | ## model setting 50 | parser.add_argument("--model",default='s3d',type=str,help='i3d,s3d') 51 | parser.add_argument('--dropout', default=0.5, type=float) 52 | 53 | parser.add_argument('--N', default=4, type=int,help='Number of layers in the temporal decoder') 54 | parser.add_argument('--H', default=4, type=int,help='Number of heads in the temporal decoder') 55 | parser.add_argument('--K', default=2, type=int,help='Number of clips updated per batch') 56 | 57 | parser.add_argument('--d_model', default=1024, 
type=int) 58 | parser.add_argument('--pretrained_weights_path', default='', type=str) 59 | 60 | 61 | args = parser.parse_args() 62 | 63 | if args.dataset_config is not None: 64 | d = vars(args) 65 | with open(args.dataset_config, "r") as f: 66 | cfg = json.load(f) 67 | d.update(cfg) 68 | 69 | assert args.batch_size == 1 70 | 71 | make_dirs(args) 72 | torch.manual_seed(args.seed) 73 | np.random.seed(args.seed) 74 | random.seed(args.seed) 75 | 76 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu 77 | device = torch.device("cuda") 78 | torch.set_default_tensor_type('torch.FloatTensor') 79 | 80 | ## Set Up Model 81 | 82 | net = TQN(args).cuda() 83 | net = torch.nn.parallel.DataParallel(net) 84 | 85 | ## Load Model Weights 86 | 87 | resume_file = osp.join(args.save_folder,'best.pth') 88 | checkpoint = torch.load(resume_file) 89 | net.load_state_dict(checkpoint['model_state_dict'],strict=True) 90 | 91 | ## Set Up Dataloader 92 | 93 | transform_test = transforms.Compose([ 94 | A.CenterCrop(size=args.img_dim), 95 | A.ToTensor(), 96 | A.Normalize(args.dataset)]) 97 | testset = TQN_dataloader(args,mode='test', 98 | transform=transform_test, 99 | SUFB = False) 100 | test_loader = data.DataLoader(testset, args.batch_size,num_workers=args.num_workers, 101 | pin_memory=True, worker_init_fn=worker_init_fn, shuffle=False, 102 | collate_fn = SUFB_collate,drop_last=True,sampler = None) 103 | 104 | 105 | net.eval() 106 | test_accuracy = [AverageMeter(),AverageMeter()] 107 | 108 | with torch.no_grad(): 109 | 110 | for k, test_samples in tqdm(enumerate(test_loader),total=len(test_loader)): 111 | 112 | v_ids, seqs, cls_targets, att_targets = test_samples 113 | 114 | seqs = seqs[0] 115 | B, K, C, T, H, W = seqs.shape 116 | cls_targets = cls_targets.cuda() 117 | att_targets = att_targets.view(-1).cuda() 118 | 119 | preds, cls_preds = net((seqs,None)) 120 | preds = torch.softmax(preds, dim=-1).mean(0, keepdim=True) 121 | 122 | cls_preds = torch.softmax(cls_preds, dim=-1).mean(0, keepdim=True) 123 | match_acc = multihead_acc(preds, cls_targets, att_targets, \ 124 | testset.class_tokens, Q = args.num_queries) 125 | 126 | cls_acc = calc_topk_accuracy(cls_preds, cls_targets, (1,))[0] 127 | acc = [torch.stack([cls_acc, match_acc], 0).unsqueeze(0)] 128 | 129 | cls_acc, match_acc = torch.cat(acc, 0).mean(0) 130 | 131 | test_accuracy[0].update(cls_acc.item(), 1) 132 | test_accuracy[1].update(match_acc.item(), 1) 133 | 134 | 135 | test_acc = [i.avg for i in test_accuracy] 136 | print("attribute_match_acc:%.2f"% round(test_acc[1]*100, 2)) 137 | print("class_token_acc:%.2f" % round(test_acc[0]*100, 2)) 138 | 139 | 140 | if __name__ == '__main__': 141 | main() 142 | 143 | 144 | -------------------------------------------------------------------------------- /scripts/train_1st_stage.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import random 4 | import json 5 | import os 6 | import sys 7 | sys.path.append('../') 8 | 9 | from torch import nn 10 | import torch.optim as optim 11 | import torch.utils.data as data 12 | import torch.backends.cudnn as cudnn 13 | import utils.augmentation as A 14 | import os.path as osp 15 | 16 | from torch.utils.data import DataLoader 17 | from tensorboardX import SummaryWriter 18 | from torchvision import transforms 19 | 20 | from models.TQN import * 21 | from utils.utils import make_dirs 22 | from utils.plot_utils import * 23 | 24 | from engine.engine import train_one_epoch, eval_one_epoch 25 | from data.dataloader import 
TQN_dataloader,collate 26 | 27 | def worker_init_fn(worker_id): 28 | np.random.seed(np.random.get_state()[1][0] + worker_id) 29 | 30 | def main(): 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--name', default='initial', type=str) 33 | 34 | ## data setting 35 | parser.add_argument('--dataset', default='', type=str) 36 | parser.add_argument('--img_dim', default=224, type=int) 37 | parser.add_argument('--clip_len', default=8, type=int, help='Number of frames sampled for each clip') 38 | parser.add_argument('--downsample', default=2, type=int, help='Frame downsampling rate') 39 | parser.add_argument('--batch_size', default=24, type=int) 40 | parser.add_argument('--root', default='', type=str) 41 | parser.add_argument('--dataset_config', default='', type=str) 42 | parser.add_argument('--pretrained', default='k400', type=str) 43 | parser.add_argument('--seed', default=0, type=int) 44 | 45 | # device params 46 | parser.add_argument("--gpus", dest="gpu", default="0", type=str) 47 | parser.add_argument('--num_workers', default=8, type=int) 48 | 49 | ## model setting 50 | parser.add_argument("--model",default='s3d') 51 | parser.add_argument('--resume', default='', type=str) 52 | parser.add_argument('--dropout', default=0.8, type=float) 53 | 54 | parser.add_argument('--N', default=4, type=int,help='Number of layers in the temporal decoder') 55 | parser.add_argument('--H', default=4, type=int,help='Number of heads in the temporal decoder') 56 | parser.add_argument('--d_model', default=1024, type=int) 57 | parser.add_argument('--num_queries', default=0, type=int) 58 | parser.add_argument('--pretrained_weights_path', default='', type=str) 59 | 60 | ## optim setting 61 | parser.add_argument('--lr', default=0.001, type=float) 62 | parser.add_argument('--momentum', default=0.9, type=float) 63 | parser.add_argument('--weight_decay', default=1e-5, type=float) 64 | parser.add_argument('--optim', default='adam', type=str, help='sgd, adam, adadelta') 65 | parser.add_argument('--max_norm', default=400, type=int, help='Norm cutoff to prevent explosion of gradients') 66 | parser.add_argument('--max_epoches', default=100, type=int) 67 | 68 | parser.add_argument('--lr_steps', default=[10000, 200000], type=float, nargs="+", 69 | metavar='LRSteps', help='epochs to decay learning rate by 10') 70 | parser.add_argument('--best_acc', default=0, type=float) 71 | 72 | ## frequency setting 73 | parser.add_argument('--print_iter', default=5, type=int) 74 | parser.add_argument('--eval_epoch', default=1, type=int) 75 | parser.add_argument('--save_epoch', default=5, type=int) 76 | parser.add_argument('--save_folder', default='/users/czhang/data/FineGym/exps/github', type=str) 77 | parser.add_argument('--tbx_folder', default='/users/czhang/data/FineGym/tbx/github', type=str) 78 | parser.add_argument('--max_iter', default=20000000, type=int) 79 | 80 | 81 | args = parser.parse_args() 82 | if args.dataset_config is not None: 83 | d = vars(args) 84 | with open(args.dataset_config, "r") as f: 85 | cfg = json.load(f) 86 | d.update(cfg) 87 | 88 | make_dirs(args) 89 | torch.manual_seed(args.seed) 90 | np.random.seed(args.seed) 91 | random.seed(args.seed) 92 | 93 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu 94 | device = torch.device("cuda") 95 | torch.set_default_tensor_type('torch.FloatTensor') 96 | 97 | ## Set Up Model 98 | 99 | net = TQN(args) 100 | net = torch.nn.DataParallel(net).to(device) 101 | num_param = sum(p.numel() for p in net.parameters()) 102 | 103 | ## Load Model Weights 104 | 105 | if args.resume != 
'': 106 | # Resume from a checkpoint 107 | resume_file = osp.join(args.save_folder,str(args.resume)+'.pth') 108 | checkpoint = torch.load(resume_file) 109 | state_dict = checkpoint['model_state_dict'] 110 | net.load_state_dict(state_dict,strict=True) 111 | args.best_acc = checkpoint['best_acc'] 112 | resume_epoch = checkpoint['epoch'] 113 | 114 | else: 115 | # load pretrained weights on K400 116 | state_dict = torch.load(args.pretrained_weights_path) 117 | new_dict = {} 118 | for k,v in state_dict.items(): 119 | k = 'module.'+k 120 | new_dict[k] = v 121 | net.load_state_dict(new_dict,strict=False) 122 | resume_epoch = -1 123 | 124 | 125 | ## Set Up Dataloader 126 | 127 | transform = transforms.Compose([ 128 | A.RandomSizedCrop(size=args.img_dim, consistent=True, clip_len=args.clip_len, h_ratio=0.7,p=0.8), 129 | A.RandomHorizontalFlip(consistent=True, clip_len=args.clip_len), 130 | A.ColorJitter(brightness=0.4, contrast=0.7, saturation=0.7, hue=0.25, 131 | p=1.0, consistent=False, clip_len=args.clip_len), 132 | A.ToTensor(), 133 | A.Normalize(dataset=args.dataset)]) 134 | 135 | transform_test = transforms.Compose([ 136 | A.CenterCrop(size=args.img_dim), 137 | A.ToTensor(), 138 | A.Normalize(dataset=args.dataset)]) 139 | 140 | trainset = TQN_dataloader(args, 141 | transform=transform, mode='train', 142 | ) 143 | testset = TQN_dataloader(args, 144 | transform=transform_test,mode='val', 145 | ) 146 | 147 | 148 | train_loader = data.DataLoader( 149 | trainset, args.batch_size,num_workers=args.num_workers, 150 | pin_memory=True, worker_init_fn=worker_init_fn,shuffle =True, 151 | drop_last=True,collate_fn = collate) 152 | test_loader = data.DataLoader( 153 | testset, args.batch_size,num_workers=args.num_workers, 154 | pin_memory=True, worker_init_fn=worker_init_fn, 155 | collate_fn = collate,drop_last=True) 156 | 157 | 158 | ## Set Up Optimizer 159 | 160 | parameters = net.parameters() 161 | params = [] 162 | for name, param in net.named_parameters(): 163 | if 'attention' in name or 'decoder' in name : 164 | params.append({'params': param, 'lr':args.lr/10}) 165 | else: 166 | params.append({'params': param, 'lr':args.lr}) 167 | 168 | optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) 169 | if args.resume != '': 170 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 171 | 172 | ## Set Up Tensorboard 173 | 174 | writer_val = SummaryWriter(logdir=osp.join(args.tbx_dir,'val')) 175 | writer_train = SummaryWriter(logdir=osp.join(args.tbx_dir, 'train')) 176 | 177 | args.val_plotter = PlotterThread(writer_val) 178 | args.train_plotter = PlotterThread(writer_train) 179 | 180 | 181 | ## Start Training 182 | cudnn.benchmark = True 183 | 184 | for epoch in range(args.max_epoches): 185 | if epoch <= resume_epoch: 186 | continue 187 | train_one_epoch(args,epoch,net,optimizer,trainset,train_loader) 188 | 189 | if epoch % args.eval_epoch == 0: 190 | eval_one_epoch(args,epoch,net,testset,test_loader) 191 | 192 | 193 | 194 | if __name__ == "__main__": 195 | main() 196 | 197 | -------------------------------------------------------------------------------- /scripts/train_2nd_stage.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import random 4 | import json 5 | import os 6 | import sys 7 | sys.path.append('../') 8 | 9 | from torch import nn 10 | import torch.optim as optim 11 | import torch.utils.data as data 12 | import torch.backends.cudnn as cudnn 13 | import utils.augmentation as A 14 | import 
os.path as osp 15 | 16 | from torch.utils.data import DataLoader 17 | from tensorboardX import SummaryWriter 18 | from torchvision import transforms 19 | 20 | from models.TQN import * 21 | from utils.utils import make_dirs 22 | from utils.plot_utils import * 23 | 24 | from engine.engine import train_one_epoch, eval_one_epoch 25 | from data.dataloader import TQN_dataloader,SUFB_collate 26 | 27 | 28 | 29 | def worker_init_fn(worker_id): 30 | np.random.seed(np.random.get_state()[1][0] + worker_id) 31 | 32 | 33 | def main(): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--name', default='initial', type=str) 36 | 37 | ## data setting 38 | parser.add_argument('--dataset', default='', type=str) 39 | parser.add_argument('--img_dim', default=224, type=int) 40 | parser.add_argument('--clip_len', default=8, type=int, help='number of frames in each video block') 41 | parser.add_argument('--downsample', default=2, type=int, help='frame down sampling rate') 42 | parser.add_argument('--batch_size', default=24, type=int) 43 | parser.add_argument('--root', default='', type=str) 44 | parser.add_argument('--dataset_config', default='', type=str) 45 | parser.add_argument('--feature_file', default='', type=str) 46 | parser.add_argument('--seed', default=0, type=int) 47 | 48 | # device params 49 | parser.add_argument("--gpus", dest="gpu", default="0", type=str) 50 | parser.add_argument('--num_workers', default=16, type=int) 51 | 52 | ## model setting 53 | parser.add_argument("--model",default='s3d',type=str,help='i3d,s3d') 54 | parser.add_argument('--resume', default=-1, type=int) 55 | parser.add_argument('--dropout', default=0.8, type=float) 56 | 57 | parser.add_argument('--N', default=4, type=int,help='Number of layers in the temporal decoder') 58 | parser.add_argument('--H', default=4, type=int,help='Number of heads in the temporal decoder') 59 | parser.add_argument('--K', default=2, type=int,help='Number of clips updated per batch') 60 | 61 | parser.add_argument('--d_model', default=1024, type=int) 62 | parser.add_argument('--pretrained_weights_path', default='', type=str) 63 | 64 | ## optim setting 65 | parser.add_argument('--lr', default=0.001, type=float) 66 | parser.add_argument('--momentum', default=0.9, type=float) 67 | parser.add_argument('--weight_decay', default=1e-5, type=float) 68 | parser.add_argument('--optim', default='adam', type=str, help='sgd, adam, adadelta') 69 | parser.add_argument('--max_norm', default=400, type=int, help='Norm cutoff to prevent explosion of gradients') 70 | parser.add_argument('--max_epoches', default=1000000, type=int) 71 | parser.add_argument('--best_acc', default=0, type=float) 72 | 73 | 74 | ## frequency setting 75 | parser.add_argument('--print_iter', default=5, type=int) 76 | parser.add_argument('--eval_epoch', default=1, type=int) 77 | parser.add_argument('--save_epoch', default=5, type=int) 78 | 79 | parser.add_argument('--max_iter', default=20000000, type=int) 80 | parser.add_argument('--lr_steps', default=[10, 20], type=float, nargs="+", 81 | metavar='LRSteps', help='epochs to decay learning rate by 10') 82 | 83 | 84 | args = parser.parse_args() 85 | if args.dataset_config is not None: 86 | d = vars(args) 87 | with open(args.dataset_config, "r") as f: 88 | cfg = json.load(f) 89 | d.update(cfg) 90 | 91 | make_dirs(args) 92 | vid2id = cp.load(open(osp.join(args.root,args.feature_file).replace('features','vid2id'),'rb')) 93 | 94 | id2vid = {} 95 | for vid in vid2id.keys(): 96 | id2vid[vid2id[vid]]=vid 97 | 98 | torch.manual_seed(args.seed) 99 
| np.random.seed(args.seed) 100 | random.seed(args.seed) 101 | 102 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu 103 | device = torch.device("cuda") 104 | torch.set_default_tensor_type('torch.FloatTensor') 105 | 106 | ## Set Up Model 107 | 108 | net = TQN(args,SUFB=True).cuda() 109 | net = torch.nn.parallel.DataParallel(net) 110 | 111 | 112 | ## Load Model Weights 113 | 114 | if args.resume != -1: 115 | resume_file = osp.join(args.save_folder,str(args.resume)+'.pth') 116 | checkpoint = torch.load(resume_file) 117 | state_dict = checkpoint['model_state_dict'] 118 | net.load_state_dict(state_dict,strict=True) 119 | net.module.queue = checkpoint['queue'] 120 | args.best_acc = checkpoint['best_acc'] 121 | resume_epoch = checkpoint['epoch'] 122 | 123 | elif args.pretrained_weights_path!='': 124 | checkpoint = torch.load(osp.join(args.pretrained_weights_path)) 125 | net.load_state_dict(checkpoint['model_state_dict'],strict=True) 126 | net.module.queue= cp.load(open(args.feature_file,'rb')) 127 | resume_epoch = -1 128 | print('=== resumed from checkpoint:', args.pretrained_weights_path,'===') 129 | 130 | 131 | ## Set Up Dataloader 132 | 133 | transform = transforms.Compose([ 134 | A.RandomSizedCrop(size=args.img_dim, consistent=True, clip_len=args.clip_len, h_ratio=0.6,p=0.8), 135 | A.RandomHorizontalFlip(consistent=True, clip_len=args.clip_len), 136 | A.ColorJitter(brightness=0.4, contrast=0.7, saturation=0.7, hue=0.25, 137 | p=1.0, consistent=False, clip_len=args.clip_len), 138 | A.ToTensor(), 139 | A.Normalize(args.dataset)]) 140 | transform_test = transforms.Compose([ 141 | A.CenterCrop(size=args.img_dim), 142 | A.ToTensor(), 143 | A.Normalize(args.dataset)]) 144 | 145 | trainset = TQN_dataloader(args, 146 | transform=transform, mode='train', 147 | SUFB = True) 148 | testset = TQN_dataloader(args,mode='val', 149 | transform=transform_test, 150 | SUFB = True) 151 | 152 | 153 | train_loader = data.DataLoader(trainset, args.batch_size,num_workers=args.num_workers, 154 | pin_memory=True, worker_init_fn=worker_init_fn, shuffle=True, 155 | drop_last=True,collate_fn = SUFB_collate, sampler= None) 156 | test_loader = data.DataLoader(testset, args.batch_size,num_workers=args.num_workers, 157 | pin_memory=True, worker_init_fn=worker_init_fn, shuffle=False, 158 | collate_fn = SUFB_collate,drop_last=True,sampler = None) 159 | 160 | ## Set Up Optimizer 161 | 162 | parameters = net.parameters() 163 | params = [] 164 | print('=> [optimizer] finetune TFM with smaller lr') 165 | for name, param in net.named_parameters(): 166 | if ('attention' in name or 'decoder' in name) and int(args.resume)<10: 167 | params.append({'params': param, 'lr':args.lr/10}) 168 | else: 169 | params.append({'params': param, 'lr':args.lr}) 170 | 171 | optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) 172 | 173 | 174 | ## Set Up Tensorboard 175 | 176 | writer_val = SummaryWriter(logdir=osp.join(args.tbx_dir,'val')) 177 | writer_train = SummaryWriter(logdir=osp.join(args.tbx_dir, 'train')) 178 | 179 | args.val_plotter = PlotterThread(writer_val) 180 | args.train_plotter = PlotterThread(writer_train) 181 | 182 | 183 | ## Start Training 184 | 185 | cudnn.benchmark = True 186 | net.train() 187 | 188 | for epoch in range(args.max_epoches): 189 | if epoch <= resume_epoch: 190 | continue 191 | adjust_learning_rate(args,optimizer, epoch) 192 | train_one_epoch(args,epoch,net,optimizer,trainset,train_loader,SUFB=True) 193 | 194 | if epoch % args.eval_epoch == 0: 195 | 
eval_one_epoch(args,epoch,net,testset,test_loader,SUFB=True) 196 | 197 | 198 | 199 | 200 | def adjust_learning_rate(args,optimizer, epoch): 201 | """Sets the learning rate to the initial LR decayed by 10 """ 202 | epoch = epoch - 1 203 | decay = 0.1 ** (sum(epoch >= np.array(args.lr_steps))) 204 | lr = args.lr * decay 205 | decay = args.weight_decay 206 | print('current epoch:',epoch,'lr:',lr) 207 | if epoch >=10: 208 | for param_group in optimizer.param_groups: 209 | param_group['lr'] = lr 210 | param_group['weight_decay'] = decay 211 | 212 | 213 | 214 | if __name__ == '__main__': 215 | main() 216 | 217 | -------------------------------------------------------------------------------- /utils/augmentation.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/Lextal/pspnet-pytorch 2 | import random 3 | import numbers 4 | import math 5 | import collections 6 | import torchvision 7 | import statistics 8 | from scipy.special import softmax 9 | from torchvision import transforms 10 | import torchvision.transforms.functional as F 11 | from collections import Counter 12 | from itertools import groupby 13 | 14 | from PIL import ImageOps, Image 15 | import numpy as np 16 | import pickle as cp 17 | import os.path as osp 18 | 19 | class Padding: 20 | def __init__(self, pad): 21 | self.pad = pad 22 | 23 | def __call__(self, img): 24 | return ImageOps.expand(img, border=self.pad, fill=0) 25 | 26 | 27 | class Scale: 28 | def __init__(self, size, interpolation=Image.BICUBIC): 29 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 30 | self.size = size 31 | self.interpolation = interpolation 32 | 33 | def __call__(self, imgmap): 34 | # assert len(imgmap) > 1 # list of images, last one is target (for segmentation tasks only) 35 | img1 = imgmap[0] 36 | if isinstance(self.size, int): 37 | w, h = img1.size 38 | if (w <= h and w == self.size) or (h <= w and h == self.size): 39 | return imgmap 40 | if w < h: 41 | ow = self.size 42 | oh = int(self.size * h / w) 43 | return [i.resize((ow, oh), self.interpolation) for i in imgmap] 44 | else: 45 | oh = self.size 46 | ow = int(self.size * w / h) 47 | return [i.resize((ow, oh), self.interpolation) for i in imgmap] 48 | else: 49 | return [i.resize(self.size, self.interpolation) for i in imgmap] 50 | 51 | 52 | class CenterCrop: 53 | def __init__(self, size, consistent=True): 54 | if isinstance(size, numbers.Number): 55 | self.size = (int(size), int(size)) 56 | else: 57 | self.size = size 58 | 59 | def __call__(self, imgmap): 60 | img1 = imgmap[0] 61 | w, h = img1.size 62 | # imgmap = [i.resize((int(w*1.6),int(h*1.6))) for i in imgmap] 63 | # w, h = imgmap[0].size 64 | th, tw = self.size 65 | x1 = int(round((w - tw) / 2.)) 66 | y1 = int(round((h - th) / 2.)) 67 | 68 | 69 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 70 | 71 | 72 | class RandomSizedCrop: 73 | def __init__(self, size, interpolation=Image.BICUBIC, consistent=True, p=1.0, clip_len=0, h_ratio=0.7): 74 | self.size = size 75 | self.interpolation = interpolation 76 | self.consistent = consistent 77 | self.threshold = p 78 | self.clip_len = clip_len 79 | self.h_ratio = h_ratio 80 | 81 | def __call__(self, imgmap): 82 | img1 = imgmap[0] 83 | if random.random() < self.threshold: # do RandomSizedCrop 84 | for attempt in range(10): 85 | ori_w,ori_h = img1.size 86 | aspect_ratio = random.uniform(3. / 4, 4. 
/ 3) 87 | h = int(random.uniform(self.h_ratio, 1.0) * ori_h) 88 | w = int(h*aspect_ratio) 89 | if self.consistent: 90 | # if random.random() < 0.5: 91 | # w, h = h, w 92 | if w <= img1.size[0] and h <= img1.size[1]: 93 | mid_x = int(img1.size[0]//2) 94 | mid_h = int(img1.size[1]//2) 95 | 96 | # x1 = random.randint(int(mid_x-ori_w*0.15),int(mid_x+ori_w*0.15)) - w//2 97 | x1 = random.randint(0, img1.size[0] - w) 98 | y1 = random.randint(0, img1.size[1] - h) 99 | 100 | imgmap = [i.crop((x1, y1, x1 + w, y1 + h)) for i in imgmap] 101 | for i in imgmap: assert(i.size == (w, h)) 102 | 103 | return [i.resize((self.size, self.size), self.interpolation) for i in imgmap] 104 | else: 105 | result = [] 106 | 107 | if random.random() < 0.5: 108 | w, h = h, w 109 | 110 | for idx, i in enumerate(imgmap): 111 | if w <= img1.size[0] and h <= img1.size[1]: 112 | if idx % self.clip_len == 0: 113 | mid_x = int(img1.size[0]//2) 114 | 115 | x1 = random.randint(int(mid_x-ori_w*0.15),int(mid_x+ori_w*0.15)) - w//2 116 | y1 = random.randint(0, img1.size[1] - h) 117 | 118 | result.append(i.crop((x1, y1, x1 + w, y1 + h))) 119 | assert(result[-1].size == (w, h)) 120 | else: 121 | result.append(i) 122 | 123 | assert len(result) == len(imgmap) 124 | return [i.resize((self.size, self.size), self.interpolation) for i in result] 125 | 126 | # Fallback 127 | scale = Scale(self.size, interpolation=self.interpolation) 128 | crop = CenterCrop(self.size) 129 | return crop(scale(imgmap)) 130 | else: #don't do RandomSizedCrop, do CenterCrop 131 | crop = CenterCrop(self.size) 132 | return crop(imgmap) 133 | 134 | 135 | class RandomHorizontalFlip: 136 | def __init__(self, consistent=True, command=None, clip_len=0): 137 | self.consistent = consistent 138 | if command == 'left': 139 | self.threshold = 0 140 | elif command == 'right': 141 | self.threshold = 1 142 | else: 143 | self.threshold = 0.5 144 | self.clip_len = clip_len 145 | def __call__(self, imgmap): 146 | if self.consistent: 147 | if random.random() < self.threshold: 148 | return [i.transpose(Image.FLIP_LEFT_RIGHT) for i in imgmap] 149 | else: 150 | return imgmap 151 | else: 152 | result = [] 153 | for idx, i in enumerate(imgmap): 154 | if idx % self.clip_len == 0: th = random.random() 155 | if th < self.threshold: 156 | result.append(i.transpose(Image.FLIP_LEFT_RIGHT)) 157 | else: 158 | result.append(i) 159 | assert len(result) == len(imgmap) 160 | return result 161 | 162 | 163 | 164 | 165 | class ColorJitter(object): 166 | """Randomly change the brightness, contrast and saturation of an image. 167 | Args: 168 | brightness (float or tuple of float (min, max)): How much to jitter brightness. 169 | brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] 170 | or the given [min, max]. Should be non negative numbers. 171 | contrast (float or tuple of float (min, max)): How much to jitter contrast. 172 | contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] 173 | or the given [min, max]. Should be non negative numbers. 174 | saturation (float or tuple of float (min, max)): How much to jitter saturation. 175 | saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] 176 | or the given [min, max]. Should be non negative numbers. 177 | hue (float or tuple of float (min, max)): How much to jitter hue. 178 | hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. 179 | Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 
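        For video data (a note on how this class is used in this repository): passing consistent=False together with a
        non-zero clip_len re-samples the jitter factors at the start of every clip_len-frame block and reuses them
        within that block. The training scripts above call ColorJitter with brightness=0.4, contrast=0.7,
        saturation=0.7, hue=0.25, consistent=False and clip_len=args.clip_len.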
180 | """ 181 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, consistent=False, p=1.0, clip_len=0): 182 | self.brightness = self._check_input(brightness, 'brightness') 183 | self.contrast = self._check_input(contrast, 'contrast') 184 | self.saturation = self._check_input(saturation, 'saturation') 185 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), 186 | clip_first_on_zero=False) 187 | self.consistent = consistent 188 | self.threshold = p 189 | self.clip_len = clip_len 190 | 191 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): 192 | if isinstance(value, numbers.Number): 193 | if value < 0: 194 | raise ValueError("If {} is a single number, it must be non negative.".format(name)) 195 | value = [center - value, center + value] 196 | if clip_first_on_zero: 197 | value[0] = max(value[0], 0) 198 | elif isinstance(value, (tuple, list)) and len(value) == 2: 199 | if not bound[0] <= value[0] <= value[1] <= bound[1]: 200 | raise ValueError("{} values should be between {}".format(name, bound)) 201 | else: 202 | raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) 203 | 204 | # if value is 0 or (1., 1.) for brightness/contrast/saturation 205 | # or (0., 0.) for hue, do nothing 206 | if value[0] == value[1] == center: 207 | value = None 208 | return value 209 | 210 | @staticmethod 211 | def get_params(brightness, contrast, saturation, hue): 212 | """Get a randomized transform to be applied on image. 213 | Arguments are same as that of __init__. 214 | Returns: 215 | Transform which randomly adjusts brightness, contrast and 216 | saturation in a random order. 217 | """ 218 | transforms = [] 219 | 220 | if brightness is not None: 221 | brightness_factor = random.uniform(brightness[0], brightness[1]) 222 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) 223 | 224 | if contrast is not None: 225 | contrast_factor = random.uniform(contrast[0], contrast[1]) 226 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) 227 | 228 | if saturation is not None: 229 | saturation_factor = random.uniform(saturation[0], saturation[1]) 230 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) 231 | 232 | if hue is not None: 233 | hue_factor = random.uniform(hue[0], hue[1]) 234 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_hue(img, hue_factor))) 235 | 236 | random.shuffle(transforms) 237 | transform = torchvision.transforms.Compose(transforms) 238 | 239 | 240 | return transform 241 | 242 | def __call__(self, imgmap): 243 | if random.random() < self.threshold: # do ColorJitter 244 | if self.consistent: 245 | transform = self.get_params(self.brightness, self.contrast, 246 | self.saturation, self.hue) 247 | return [transform(i) for i in imgmap] 248 | else: 249 | if self.clip_len == 0: 250 | return [self.get_params(self.brightness, self.contrast, self.saturation, self.hue)(img) for img in imgmap] 251 | else: 252 | result = [] 253 | for idx, img in enumerate(imgmap): 254 | if idx % self.clip_len == 0: 255 | transform = self.get_params(self.brightness, self.contrast, 256 | self.saturation, self.hue) 257 | result.append(transform(img)) 258 | return result 259 | 260 | else: # don't do ColorJitter, do nothing 261 | return imgmap 262 | 263 | def __repr__(self): 264 | format_string = self.__class__.__name__ + '(' 265 | 
format_string += 'brightness={0}'.format(self.brightness) 266 | format_string += ', contrast={0}'.format(self.contrast) 267 | format_string += ', saturation={0}'.format(self.saturation) 268 | format_string += ', hue={0})'.format(self.hue) 269 | return format_string 270 | 271 | 272 | 273 | class ToTensor: 274 | def __call__(self, imgmap): 275 | totensor = transforms.ToTensor() 276 | return [totensor(i) for i in imgmap] 277 | 278 | class ToPIL: 279 | def __call__(self, imgmap): 280 | topil = transforms.ToPILImage() 281 | return [topil(i) for i in imgmap] 282 | 283 | class Normalize: 284 | def __init__(self, dataset=None,mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 285 | 286 | if 'diving' in dataset: 287 | self.mean = [0.3381, 0.5108, 0.5785] 288 | self.std = [0.2206, 0.2309, 0.2615] 289 | else: 290 | self.mean = mean 291 | self.std = std 292 | 293 | def __call__(self, imgmap): 294 | normalize = transforms.Normalize(mean=self.mean, std=self.std) 295 | return [normalize(i) for i in imgmap] 296 | 297 | 298 | def pil_loader(path): 299 | with open(path, 'rb') as f: 300 | with Image.open(f) as img: 301 | return img.convert('RGB') -------------------------------------------------------------------------------- /utils/plot_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from collections import deque 4 | from threading import Thread 5 | from queue import Queue 6 | 7 | class Logger(object): 8 | '''write something to txt file''' 9 | def __init__(self, path): 10 | self.birth_time = datetime.now() 11 | filepath = os.path.join(path, self.birth_time.strftime('%Y-%m-%d-%H:%M:%S')+'.log') 12 | self.filepath = filepath 13 | with open(filepath, 'a') as f: 14 | f.write(self.birth_time.strftime('%Y-%m-%d %H:%M:%S')+'\n') 15 | 16 | def log(self, string): 17 | with open(self.filepath, 'a') as f: 18 | time_stamp = datetime.now() - self.birth_time 19 | f.write(strfdelta(time_stamp,"{d}-{h:02d}:{m:02d}:{s:02d}")+'\t'+string+'\n') 20 | 21 | 22 | class AverageMeter(object): 23 | """Computes and stores the average and current value""" 24 | def __init__(self, name='null', fmt=':.4f'): 25 | self.name = name 26 | self.fmt = fmt 27 | self.reset() 28 | 29 | def reset(self): 30 | self.val = 0 31 | self.avg = 0 32 | self.sum = 0 33 | self.count = 0 34 | self.local_history = deque([]) 35 | self.local_avg = 0 36 | self.history = [] 37 | self.dict = {} # save all data values here 38 | self.save_dict = {} # save mean and std here, for summary table 39 | 40 | def update(self, val, n=1, history=0, step=5): 41 | self.val = val 42 | self.sum += val * n 43 | self.count += n 44 | if n == 0: return 45 | self.avg = self.sum / self.count 46 | if history: 47 | self.history.append(val) 48 | if step > 0: 49 | self.local_history.append(val) 50 | if len(self.local_history) > step: 51 | self.local_history.popleft() 52 | self.local_avg = np.average(self.local_history) 53 | 54 | 55 | def dict_update(self, val, key): 56 | if key in self.dict.keys(): 57 | self.dict[key].append(val) 58 | else: 59 | self.dict[key] = [val] 60 | 61 | def print_dict(self, title='IoU', save_data=False): 62 | """Print summary, clear self.dict and save mean+std in self.save_dict""" 63 | total = [] 64 | for key in self.dict.keys(): 65 | val = self.dict[key] 66 | avg_val = np.average(val) 67 | len_val = len(val) 68 | std_val = np.std(val) 69 | 70 | if key in self.save_dict.keys(): 71 | self.save_dict[key].append([avg_val, std_val]) 72 | else: 73 | self.save_dict[key] = [[avg_val, 
std_val]] 74 | 75 | print('Activity:%s, mean %s is %0.4f, std %s is %0.4f, length of data is %d' \ 76 | % (key, title, avg_val, title, std_val, len_val)) 77 | 78 | total.extend(val) 79 | 80 | self.dict = {} 81 | avg_total = np.average(total) 82 | len_total = len(total) 83 | std_total = np.std(total) 84 | print('\nOverall: mean %s is %0.4f, std %s is %0.4f, length of data is %d \n' \ 85 | % (title, avg_total, title, std_total, len_total)) 86 | 87 | if save_data: 88 | print('Save %s pickle file' % title) 89 | with open('img/%s.pickle' % title, 'wb') as f: 90 | pickle.dump(self.save_dict, f) 91 | 92 | def __len__(self): 93 | return self.count 94 | 95 | def __str__(self): 96 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 97 | return fmtstr.format(**self.__dict__) 98 | 99 | 100 | 101 | 102 | class PlotterThread(): 103 | def __init__(self, writer): 104 | self.writer = writer 105 | self.task_queue = Queue(maxsize=0) 106 | worker = Thread(target=self.do_work, args=(self.task_queue,)) 107 | worker.setDaemon(True) 108 | worker.start() 109 | 110 | def do_work(self, q): 111 | while True: 112 | content = q.get() 113 | if content[-1] == 'image': 114 | self.writer.add_image(*content[:-1]) 115 | elif content[-1] == 'scalar': 116 | self.writer.add_scalar(*content[:-1]) 117 | elif content[-1] == 'gif': 118 | self.writer.add_video(*content[:-1]) 119 | else: 120 | raise ValueError 121 | q.task_done() 122 | 123 | def add_data(self, name, value, step, data_type='scalar'): 124 | self.task_queue.put([name, value, step, data_type]) 125 | 126 | def __len__(self): 127 | return self.task_queue.qsize() 128 | 129 | 130 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import os.path as osp 4 | 5 | def tfm_mask(seg_per_video,temporal_mutliplier=1): 6 | """ 7 | Attention mask for padded sequence in the Transformer 8 | True: not allowed to attend to 9 | """ 10 | B = len(seg_per_video) 11 | L = max(seg_per_video) * temporal_mutliplier 12 | mask = torch.ones(B,L,dtype=torch.bool) 13 | for ind,l in enumerate(seg_per_video): 14 | mask[ind,:(l*temporal_mutliplier)] = False 15 | 16 | return mask 17 | 18 | 19 | 20 | def calc_topk_accuracy(output, target, topk=(1,)): 21 | """ 22 | Modified from: https://gist.github.com/agermanidis/275b23ad7a10ee89adccf021536bb97e 23 | Given predicted and ground truth labels, 24 | calculate top-k accuracies. 25 | """ 26 | maxk = max(topk) 27 | batch_size = target.size(0) 28 | 29 | _, pred = output.topk(maxk, 1, True, True) 30 | pred = pred.t() 31 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 32 | 33 | res = [] 34 | for k in topk: 35 | correct_k = correct[:k].view(-1).float().sum(0) 36 | res.append(correct_k.mul_(1 / batch_size)) 37 | return res 38 | 39 | 40 | 41 | def multihead_acc(preds,clabel,target,vocab,\ 42 | Q=4,return_probs=False): 43 | """ 44 | Args: 45 | preds: Predicted logits 46 | clabel: Class labels, 47 | List, [batch_size] 48 | target: Ground Truth attribute labels 49 | List, [batch_size,num_queries] 50 | vocab: The mapping between class index and attributes. 51 | List, [num_classes,num_queries] 52 | Q: Number of queries, Int 53 | 54 | Output: 55 | prob_acc: match predicted attibutes to ground-truth attibutes of N classes, 56 | class with the highest similarity is the predicted class. 
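               For example, with B=2 videos and Q=4 queries, preds is reshaped to (2, 4, C) and scored against the
               (num_classes, 4, C) one-hot encoding of vocab by summing the per-query scores at each class's
               attribute labels; the class with the highest total score is returned as the prediction.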
57 | """ 58 | 59 | # reshape the preds to (B,num_heads,num_classes) 60 | if len(preds.shape)==2: 61 | BQ,C = preds.shape 62 | B = BQ//Q 63 | preds = preds.view(-1,Q,C) 64 | elif len(preds.shape)==3: 65 | B,Q,C = preds.shape 66 | 67 | target = target.view(-1,Q) 68 | vocab_onehot = one_hot(vocab,C) 69 | 70 | cls_logits =torch.einsum('bhc,ahc->ba', preds, vocab_onehot.cuda()) 71 | cls_pred = torch.argmax(cls_logits,dim=-1) 72 | prob_acc = (cls_pred == clabel).sum()*1.0 /B 73 | 74 | if return_probs: 75 | return prob_acc,cls_logits 76 | else: 77 | return prob_acc 78 | 79 | 80 | 81 | def one_hot(indices,depth): 82 | """ 83 | make one hot vectors from indices 84 | """ 85 | y = indices.unsqueeze(-1).long() 86 | y_onehot = torch.zeros(*indices.shape,depth) 87 | if indices.is_cuda: 88 | y_onehot = y_onehot.cuda() 89 | return y_onehot.scatter(-1,y,1) 90 | 91 | 92 | 93 | def make_dirs(args): 94 | 95 | if osp.exists(args.save_folder) == False: 96 | os.mkdir(args.save_folder) 97 | args.save_folder = osp.join(args.save_folder ,args.name) 98 | if osp.exists(args.save_folder) == False: 99 | os.mkdir(args.save_folder) 100 | 101 | args.tbx_dir =osp.join(args.tbx_folder,args.name) 102 | if osp.exists(args.tbx_folder) == False: 103 | os.mkdir(args.tbx_folder) 104 | 105 | if osp.exists(args.tbx_dir) == False: 106 | os.mkdir(args.tbx_dir) 107 | 108 | result_dir = osp.join(args.tbx_dir,'results') 109 | if osp.exists(result_dir) == False: 110 | os.mkdir(result_dir) 111 | 112 | 113 | 114 | def batch_denorm(tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], channel=1): 115 | """ 116 | De-normalization the images for viusalization 117 | """ 118 | shape = [1]*tensor.dim(); shape[channel] = 3 119 | dtype = tensor.dtype 120 | mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device).view(shape) 121 | std = torch.as_tensor(std, dtype=dtype, device=tensor.device).view(shape) 122 | output = tensor.mul(std).add(mean) 123 | return output 124 | --------------------------------------------------------------------------------
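The helpers in `utils/utils.py` can be sanity-checked in isolation. The sketch below is illustrative only and is not part of the repository; it assumes the repository root is on the Python path (the scripts achieve this with `sys.path.append('../')`), and uses toy tensors to show the shapes that `tfm_mask`, `one_hot` and `calc_topk_accuracy` expect. `multihead_acc` follows the same pattern but hard-codes `.cuda()`, so trying it additionally requires a GPU.

```
import torch
from utils.utils import tfm_mask, one_hot, calc_topk_accuracy

# Padding mask for the temporal decoder: two videos with 3 and 5 clips.
# True marks padded positions that must not be attended to.
mask = tfm_mask([3, 5])
assert mask.shape == (2, 5) and mask[0, 3:].all() and not mask[1].any()

# Toy class-to-attribute vocabulary: 3 classes, Q=4 queries, 5 attribute
# values per query. one_hot gives the (num_classes, Q, C) tensor that
# multihead_acc matches per-query predictions against.
vocab = torch.tensor([[0, 1, 2, 3],
                      [1, 1, 0, 4],
                      [2, 0, 3, 3]])
assert one_hot(vocab, 5).shape == (3, 4, 5)

# Top-1 accuracy over a batch of two class predictions.
logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 0.1, 1.5]])
labels = torch.tensor([0, 2])
print(calc_topk_accuracy(logits, labels, topk=(1,))[0].item())  # 1.0
```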