├── README.md
├── data
│   ├── ava_dataset.py
│   └── collate_batch.py
├── images
│   ├── graph.jpg
│   └── graph2.PNG
├── models
│   ├── attention_layer.py
│   ├── gat.py
│   └── model.py
├── test.py
├── train.py
└── utils
    ├── boxlist_ops.py
    └── checkpoints.py

/README.md:
--------------------------------------------------------------------------------
# STAGE: Spatio-Temporal Attention on Graph Entities
This repository contains the train and test code for the paper _[STAGE: Spatio-Temporal Attention on Graph Entities for Video Action Detection](https://arxiv.org/abs/1912.04316)_.

![STAGE](images/graph.jpg)
## Requirements
The required Python packages are:
* torch>=1.0.0
* h5py>=2.8.0
* tensorboardX>=1.6

## Features
In order to train and test the module, you need pre-computed actor and object features extracted by a backbone pre-trained on the [AVA dataset](https://research.google.com/ava/). Features must be organized in h5py files as follows:

**Actors features**

    actors_features_dir
    |
    |-> <clipID>
    |     |-> <timestamp>.h5
    |     |-> <timestamp>.h5
    |     |   ...
    |
    |-> <clipID>
    |     |-> <timestamp>.h5
    |     |-> <timestamp>.h5
    |     |   ...
    |
    ...

Each .h5 file should contain the following data:
* "features" -> a torch tensor with shape (num_actors, feature_size, t, h, w) containing the actors' features
* "boxes" -> a torch tensor with shape (num_actors, 4) containing the bounding box coordinates of each actor
* "labels" -> a torch tensor with shape (num_actors, 81) containing ones and zeros for performed/not performed actions

**Objects features**

    objects_features.h5

The file should contain the following data:
* `<clipID>_<timestamp>_features` -> a torch tensor with shape (num_objects, feature_size) containing the objects' features
* `<clipID>_<timestamp>_boxes` -> a torch tensor with shape (num_objects, 4) containing the bounding box coordinates of each object
* `<clipID>_<timestamp>_cls_prob` -> a torch tensor with shape (num_objects, num_classes) containing the objects' probabilities for each class

For example, the objects features of the clipID '-5KQ66BBWC4' at timestamp '902' will be in

    objects_features["-5KQ66BBWC4_902_features"]

I3D actors features are available at the following links:
* [[I3D_actors_train]](https://drive.google.com/open?id=1RlciPLrEQcY0uYecS_cEWydrpvWg9DZv)
* [[I3D_actors_val]](https://drive.google.com/open?id=1HCjezdcr2BkVUIEJgzBKPYSYLA0a9vxw)

Each tar.gz contains a directory, which corresponds to the "actors_features_dir" root.

Faster-RCNN objects features are available at the following links:
* [[Faster-RCNN_objects_train]](https://drive.google.com/file/d/13PrXvAR-Rw9MaTAJA5hInpJG4V_FLNuB/view?usp=sharing)
* [[Faster-RCNN_objects_val]](https://drive.google.com/open?id=17_9NkM0kB_j0YEersD6y5WRPcKL6fiLp)

Each tar.gz contains an h5py file, which corresponds to the "objects_features.h5" file.

The total size of all the features is ~90 GB.

**SlowFast features**

You can find SlowFast features at the following links:
* [[slowfast_ava2.1_32x2_features_train]](https://drive.google.com/file/d/1DW0b3Cc4d64P5Ir40cxpquGwYXkA_g-P/view?usp=sharing)
* [[slowfast_ava2.1_32x2_features_val]](https://drive.google.com/file/d/1GbCLQ5jK8tk5FBj_DCKouQwyEApkkUfk/view?usp=sharing)
* [[slowfast_ava2.2_32x2_features_train]](https://drive.google.com/file/d/1dV1F1wYDBl4M8BRp8_uKlGG4Vszzv9dH/view?usp=sharing)
* [[slowfast_ava2.2_32x2_features_val]](https://drive.google.com/file/d/1C9m0DvE0rEyrmG36RggUCLlZOd2AqbVO/view?usp=sharing)
* [[slowfast_ava2.2_64x2_features_train]](https://drive.google.com/file/d/1D-W7mJsWAt843GA0IwLaAD7qOpxIQESC/view?usp=sharing)
* [[slowfast_ava2.2_64x2_features_val]](https://drive.google.com/file/d/1_Wd89_kQYtwL5IBmj7skJB0uVEUaW5V8/view?usp=sharing)

**Note**: these features are organized differently from the I3D ones (you will need to write a specific dataloader or modify the provided one): each h5py file is a dictionary whose keys follow the format `<clipID>_<timestamp>`. For each key, the corresponding value is another dictionary with keys "boxes", "features" and "labels", containing the actors' box coordinates, the features extracted from the last SlowFast layer before classification, and the ground-truth labels for that specific clip.
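As a quick reference, the following minimal sketch (not part of the original codebase) shows how the I3D/Faster-RCNN files described above can be inspected with `h5py`; it mirrors what `data/ava_dataset.py` does internally. The clip ID and timestamp are the example values used above, and the file paths are placeholders to adapt to your setup.

```
import h5py
import torch

# Placeholder paths: point these to the extracted features on your machine.
actor_file = "./actors_features_dir/-5KQ66BBWC4/0902.h5"   # <clipID>/<timestamp>.h5
objects_file = "./objects_features.h5"

# Actor features, boxes and labels for one clip/timestamp.
with h5py.File(actor_file, "r") as hf:
    actors_features = torch.from_numpy(hf["features"][()])   # (num_actors, feature_size, t, h, w)
    actors_boxes = torch.from_numpy(hf["boxes"][()])          # (num_actors, 4)
    actors_labels = torch.from_numpy(hf["labels"][()])        # (num_actors, 81)

# Object features and boxes for the same clip/timestamp. Note that the keys use
# the timestamp without leading zeros (the provided dataloader strips them).
with h5py.File(objects_file, "r") as hf:
    objects_features = torch.from_numpy(hf["-5KQ66BBWC4_902_features"][()])  # (num_objects, feature_size)
    objects_boxes = torch.from_numpy(hf["-5KQ66BBWC4_902_boxes"][()])        # (num_objects, 4)
```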
## Training

Run `python train.py` using the following arguments:

| Argument | Value |
|------|------|
| `--actors_dir` | Path to the train actors_features_dir |
| `--objects_file` | Path to the train objects_features.h5 file |
| `--output_dir` | Path to the directory where checkpoints will be stored |
| `--log_tensorboard_dir` | Path to the directory where tensorboard logs will be stored |
| `--batch_size` | The batch size. Must be > 1 to allow temporal connections |
| `--n_workers` | The number of workers |
| `--lr` | The learning rate |

For example, use:
```
python train.py --actors_dir "./actors_features_dir" --objects_file "./objects_features.h5" --output_dir "./out_checkpoints" --log_tensorboard_dir "./out_tensorboard" --batch_size 6 --n_workers 8 --lr 0.0000625
```

## Testing

Run `python test.py` using the following arguments:

| Argument | Value |
|------|------|
| `--actors_dir` | Path to the val actors_features_dir |
| `--objects_dir` | Path to the val objects_features.h5 file |
| `--output_dir` | Path to the directory where the checkpoint to load is stored |
| `--batch_size` | The batch size. Must be > 1 to allow temporal connections |
| `--n_workers` | The number of workers |

A "results.csv" file will be created under the "output_dir" directory; this file can then be used for evaluation as explained [here](https://research.google.com/ava/download.html).

For example, use:
```
python test.py --actors_dir "./actors_features_dir" --objects_dir "./objects_features.h5" --output_dir "./out_checkpoints" --batch_size 6 --n_workers 8
```

--------------------------------------------------------------------------------
/data/ava_dataset.py:
--------------------------------------------------------------------------------
1 | import torch 2 | import torch.utils.data as data_utl 3 | import h5py 4 | import os 5 | import random 6 | 7 | 8 | def sort_function(filename): 9 | return (filename.split('/')[-2], int(filename.split('/')[-1].split('.')[-2])) 10 | 11 | class AVADataset(data_utl.Dataset): 12 | 13 | def __init__(self, split='train', videodir='./train_features_I3D', objectsfile='./ava_objects_fasterrcnn.hdf5'): 14 | self.split = split 15 | self.objectsfile = objectsfile 16 | 17 | self.filenames = [] 18 | for dirname in os.listdir(videodir): 19 | for filename in os.listdir(os.path.join(videodir, dirname)): 20 | self.filenames.append(os.path.join(videodir, dirname, filename)) 21 | 22 | if self.split == "val": 23 | self.filenames.sort(key=sort_function) 24 | 25 | def __getitem__(self, index): 26 | filename = self.filenames[index] 27 | clip_id = filename.split('/')[-2] 28 | timestamp = filename.split('/')[-1].split('.')[0] 29 | 30 | hf_actors = h5py.File(filename, 'r') 31 | actors_features = torch.from_numpy(hf_actors.get("features").value) 32 | actors_labels = torch.from_numpy(hf_actors.get('labels').value) 33 | actors_boxes = torch.from_numpy(hf_actors.get('boxes').value) 34 | 35 | hf_objects = h5py.File(self.objectsfile, 'r') 36 | objects_features = torch.from_numpy(hf_objects.get(clip_id + '_' + timestamp.lstrip("0") + 
'_' + 'features').value) 37 | objects_boxes = torch.from_numpy(hf_objects.get(clip_id + '_' + timestamp.lstrip("0") + '_' + 'boxes').value) 38 | 39 | return actors_features, actors_labels, actors_boxes, [(clip_id, timestamp) for _ in range(actors_features.shape[0])], objects_features, objects_boxes, [(clip_id, timestamp) for _ in range(objects_features.shape[0])] 40 | 41 | def __len__(self): 42 | return len(self.filenames) 43 | 44 | def rotate(self, n): 45 | self.filenames = self.filenames[n:] + self.filenames[:n] 46 | 47 | def shuffle_filename_blocks(self, block_size, epoch): 48 | self.filenames.sort(key=sort_function) 49 | self.rotate(int(block_size/4)*(epoch-1)) 50 | self.filenames = [self.filenames[i:i+block_size] for i in range(0,len(self.filenames),block_size)] 51 | random.shuffle(self.filenames) 52 | self.filenames[:] = [b for bs in self.filenames for b in bs] 53 | -------------------------------------------------------------------------------- /data/collate_batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import boxlist_ops 3 | 4 | 5 | class BatchCollator(object): 6 | 7 | def __call__(self, batch): 8 | transposed_batch = list(zip(*batch)) 9 | actors_features = torch.cat(transposed_batch[0], dim=0) 10 | actors_labels = torch.cat(transposed_batch[1], dim=0) 11 | actors_boxes = torch.cat(transposed_batch[2], dim=0) 12 | actors_filenames = sum(transposed_batch[3], []) 13 | 14 | objects_features = torch.cat(transposed_batch[4], dim=0) 15 | objects_boxes = torch.cat(transposed_batch[5], dim=0) 16 | objects_filenames = sum(transposed_batch[6], []) 17 | 18 | num_actor_proposals = actors_boxes.shape[0] 19 | num_object_proposals = objects_boxes.shape[0] 20 | 21 | adj = torch.zeros((num_actor_proposals + num_object_proposals, num_actor_proposals + num_object_proposals)) 22 | 23 | cur_actors = 0 24 | cur_objects = 0 25 | 26 | tau=1 27 | # populate the adj matrix in the actor-actor and actor-object sections 28 | for i in range(len(transposed_batch[0])): 29 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0],cur_actors:cur_actors + transposed_batch[0][i].shape[0]] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[2][i], tau=tau) 30 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0], num_actor_proposals + cur_objects:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0]] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[5][i], tau=tau) 31 | if i==0: 32 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0], cur_actors + transposed_batch[0][i].shape[0]:cur_actors + transposed_batch[0][i].shape[0] + transposed_batch[0][i+1].shape[0]] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[2][i+1], tau=tau) 33 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0], num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0]:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0] + transposed_batch[4][i+1].shape[0]] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[5][i+1], tau=tau) 34 | elif i == len(transposed_batch[3]) - 1: 35 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0], cur_actors - transposed_batch[0][i-1].shape[0]:cur_actors] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[2][i-1], tau=tau) 36 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0], num_actor_proposals + cur_objects - 
transposed_batch[4][i-1].shape[0]:num_actor_proposals + cur_objects] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[5][i-1], tau=tau) 37 | else: 38 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0], cur_actors + transposed_batch[0][i].shape[0]:cur_actors + transposed_batch[0][i].shape[0] + transposed_batch[0][i + 1].shape[0]] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[2][i + 1], tau=tau) 39 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0], cur_actors - transposed_batch[0][i - 1].shape[0]:cur_actors] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[2][i - 1], tau=tau) 40 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0], num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0]:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0] + transposed_batch[4][i + 1].shape[0]] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[5][i + 1], tau=tau) 41 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0], num_actor_proposals + cur_objects - transposed_batch[4][i - 1].shape[0]:num_actor_proposals + cur_objects] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[5][i - 1], tau=tau) 42 | 43 | cur_actors += transposed_batch[0][i].shape[0] 44 | cur_objects += transposed_batch[4][i].shape[0] 45 | 46 | # populate the adj matrix in the object-actor section 47 | adj[num_actor_proposals:, :num_actor_proposals] = torch.t(adj[:num_actor_proposals, num_actor_proposals:]) 48 | cur_objects = 0 49 | 50 | # populate the adj matrix in the object-object section 51 | for i in range(len(transposed_batch[4])): 52 | adj[num_actor_proposals + cur_objects:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0], num_actor_proposals + cur_objects:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0]] = boxlist_ops.boxlist_distance(transposed_batch[5][i], transposed_batch[5][i], tau=tau) 53 | if i==0: 54 | adj[num_actor_proposals + cur_objects:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0], num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0]:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0] + transposed_batch[4][i+1].shape[0]] = boxlist_ops.boxlist_distance(transposed_batch[5][i], transposed_batch[5][i+1], tau=tau) 55 | elif i == len(transposed_batch[3]) - 1: 56 | adj[num_actor_proposals + cur_objects:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0], num_actor_proposals + cur_objects - transposed_batch[4][i-1].shape[0]:num_actor_proposals+cur_objects] = boxlist_ops.boxlist_distance(transposed_batch[5][i], transposed_batch[5][i-1], tau=tau) 57 | else: 58 | adj[num_actor_proposals + cur_objects:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0], num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0]:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0] + transposed_batch[4][i + 1].shape[0]] = boxlist_ops.boxlist_distance(transposed_batch[5][i], transposed_batch[5][i + 1], tau=tau) 59 | adj[num_actor_proposals + cur_objects:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0], num_actor_proposals + cur_objects - transposed_batch[4][i - 1].shape[0]:num_actor_proposals + cur_objects] = boxlist_ops.boxlist_distance(transposed_batch[5][i], transposed_batch[5][i - 1], tau=tau) 60 | 61 | cur_objects += transposed_batch[4][i].shape[0] 62 | 63 | return actors_features, actors_labels, actors_boxes, 
actors_filenames, objects_features, objects_boxes, objects_filenames, adj -------------------------------------------------------------------------------- /images/graph.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/STAGE_action_detection/a5d76114f47c103deef79ac6056b1794961fb58c/images/graph.jpg -------------------------------------------------------------------------------- /images/graph2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/STAGE_action_detection/a5d76114f47c103deef79ac6056b1794961fb58c/images/graph2.PNG -------------------------------------------------------------------------------- /models/attention_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class GraphAttentionLayer(nn.Module): 7 | def __init__(self, in_features, out_features, dropout, alpha): 8 | super(GraphAttentionLayer, self).__init__() 9 | self.dropout = dropout 10 | self.in_features = in_features 11 | self.out_features = out_features 12 | self.alpha = alpha 13 | 14 | self.W = nn.Parameter(torch.zeros(size=(in_features, out_features))) 15 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 16 | self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1))) 17 | nn.init.xavier_uniform_(self.a.data, gain=1.414) 18 | 19 | self.leakyrelu = nn.LeakyReLU(self.alpha) 20 | 21 | def forward(self, input, adj): 22 | h = torch.mm(input, self.W) 23 | N = h.size()[0] 24 | 25 | a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features) 26 | e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2)) 27 | 28 | zero_vec = -9e15 * torch.ones_like(e) 29 | attention = torch.where(adj > 0, e*adj, zero_vec) 30 | tau = 1 31 | attention = F.softmax(attention / tau, dim=1) 32 | h_prime = torch.matmul(attention, h) 33 | h_prime = F.dropout(h_prime, self.dropout, training=self.training) 34 | 35 | return F.elu(h_prime) 36 | 37 | -------------------------------------------------------------------------------- /models/gat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.attention_layer import GraphAttentionLayer 5 | 6 | 7 | class GAT(nn.Module): 8 | def __init__(self, nfeat, nhid, dropout, alpha, nheads): 9 | super(GAT, self).__init__() 10 | self.dropout = dropout 11 | 12 | self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha) for _ in range(nheads)] 13 | for i, attention in enumerate(self.attentions): 14 | self.add_module('attention_{}'.format(i), attention) 15 | 16 | def forward(self, x, adj): 17 | x = F.dropout(x, self.dropout, training=self.training) 18 | x = torch.cat([v(x, adj) for k, v in self._modules.items() if k.startswith("attention")], dim=1) 19 | return x 20 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.gat import GAT 3 | from torch import nn 4 | 5 | 6 | class Stage(torch.nn.Module): 7 | 8 | def __init__(self, num_classes, actors_features_size, objects_features_size, n_heads): 9 | super(Stage, self).__init__() 10 | 11 | self.num_classes = num_classes 12 | 
self.actors_features_size = actors_features_size 13 | self.objects_features_size = objects_features_size 14 | self.n_heads = n_heads 15 | 16 | if self.objects_features_size > self.actors_features_size: 17 | self.obj_reducer = nn.Linear(in_features=self.objects_features_size, out_features=self.actors_features_size) 18 | 19 | self.gat1 = GAT(self.actors_features_size + 4, int(self.actors_features_size/self.n_heads) + int(4/self.n_heads), 0.5, 0.2, self.n_heads) #we add 4 because we are going to add h,w,xc,yc to the channel axis 20 | self.gat_fc = nn.Linear(in_features=self.actors_features_size +4, out_features=self.actors_features_size +4) 21 | self.l_norm = nn.LayerNorm(self.actors_features_size +4) 22 | 23 | self.gat2 = GAT(self.actors_features_size + 4, int(self.actors_features_size/self.n_heads) + int(4/self.n_heads), 0.5, 0.2, self.n_heads) 24 | self.gat_fc2 = nn.Linear(in_features=self.actors_features_size +4, out_features=self.actors_features_size +4) 25 | self.l_norm2 = nn.LayerNorm(self.actors_features_size +4) 26 | 27 | self.logits = nn.Linear(in_features=self.actors_features_size +4, out_features=self.num_classes) 28 | 29 | 30 | def forward(self, actors_features, actors_labels, actors_boxes, objects_features, objects_boxes, adj): 31 | #compute h, w, xc, yc for each actor/object 32 | actors_h = actors_boxes[:, 3] - actors_boxes[:, 1] 33 | objects_h = objects_boxes[:, 3] - objects_boxes[:, 1] 34 | actors_w = actors_boxes[:, 2] - actors_boxes[:, 0] 35 | objects_w = objects_boxes[:, 2] - objects_boxes[:, 0] 36 | actors_centers_x = ((actors_boxes[:, 2] - actors_boxes[:, 0]) / 2) + actors_boxes[:, 0] 37 | actors_centers_y = ((actors_boxes[:, 3] - actors_boxes[:, 1]) / 2) + actors_boxes[:, 1] 38 | objects_centers_x = ((objects_boxes[:, 2] - objects_boxes[:, 0]) / 2) + objects_boxes[:, 0] 39 | objects_centers_y = ((objects_boxes[:, 3] - objects_boxes[:, 1]) / 2) + objects_boxes[:, 1] 40 | 41 | with torch.no_grad(): 42 | actors_features = torch.mean(torch.mean(torch.mean(actors_features, dim=2), dim=-1), dim=-1) 43 | actors_features = torch.cat((actors_features, actors_h.unsqueeze(1), actors_w.unsqueeze(1), actors_centers_x.unsqueeze(1), actors_centers_y.unsqueeze(1)), dim=1) 44 | 45 | if self.objects_features_size > self.actors_features_size: 46 | objects_features = nn.functional.relu(self.obj_reducer(objects_features)) 47 | 48 | objects_features = torch.cat((objects_features, objects_h.unsqueeze(1), objects_w.unsqueeze(1), objects_centers_x.unsqueeze(1), objects_centers_y.unsqueeze(1)), dim=1) 49 | 50 | all_features = torch.cat((actors_features, objects_features), dim=0) 51 | 52 | gat_pred = self.gat1(all_features, adj) 53 | all_features = all_features + self.gat_fc(gat_pred) 54 | all_features = self.l_norm(all_features) 55 | 56 | gat_pred = self.gat2(all_features, adj) 57 | all_features = all_features + self.gat_fc2(gat_pred) 58 | all_features = self.l_norm2(all_features) 59 | 60 | pred = self.logits(all_features[:actors_features.shape[0], :]) 61 | 62 | loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, actors_labels) 63 | 64 | pred = torch.sigmoid(pred) 65 | return pred, loss 66 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from data.ava_dataset import AVADataset 2 | import torch 3 | from utils import checkpoints 4 | import os 5 | import argparse 6 | from data.collate_batch import BatchCollator 7 | from models.model import Stage 8 | import 
csv 9 | from tqdm import tqdm 10 | import sys 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--actors_dir", type=str, help="path to the directory containing actors features") 15 | parser.add_argument("--objects_dir", type=str, help="path to the file containing objects features") 16 | parser.add_argument("--output_dir", type=str, help="path to the directory where checkpoints will be stored") 17 | parser.add_argument("--batch_size", type=int) 18 | parser.add_argument("--n_workers", type=int) 19 | parser.add_argument("--num_classes", type=int, default=81) 20 | parser.add_argument("--actors_features_size", type=int, default=1024) 21 | parser.add_argument("--objects_features_size", type=int, default=2048) 22 | parser.add_argument("--n_heads", type=int, default=4, help="only 2 or 4 heads supported at the moment") 23 | 24 | 25 | def main(): 26 | args = parser.parse_args() 27 | 28 | num_classes = args.num_classes 29 | output_dir = args.output_dir 30 | 31 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 32 | 33 | assert os.path.isdir(output_dir) 34 | 35 | ava_val = AVADataset(split='val', videodir=args.actors_dir, objectsfile=args.objects_dir) 36 | data_loader_val = torch.utils.data.DataLoader(ava_val, batch_size=args.batch_size, num_workers=args.n_workers, collate_fn=BatchCollator(), shuffle=False) 37 | 38 | model = Stage(num_classes, args.actors_features_size, args.objects_features_size, args.n_heads) 39 | 40 | model.eval() 41 | model.to(device) 42 | 43 | checkpoint = checkpoints.load(output_dir) 44 | if checkpoint: 45 | model.load_state_dict(checkpoint["model_state"], strict=False) 46 | else: 47 | sys.exit('No checkpoint found!') 48 | 49 | with torch.no_grad(): 50 | with open(os.path.join(output_dir, "results.csv"), mode='w') as csv_file: 51 | csv_writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) 52 | 53 | for iteration_val, (actors_features, actors_labels, actors_boxes, actors_filenames, objects_features, objects_boxes, objects_filenames, adj) in enumerate(tqdm(data_loader_val), 1): 54 | actors_features = actors_features.to(device) 55 | actors_labels = actors_labels.to(device) 56 | actors_boxes = actors_boxes.to(device) 57 | objects_features = objects_features.to(device) 58 | objects_boxes = objects_boxes.to(device) 59 | adj = adj.to(device) 60 | 61 | pred, loss = model(actors_features, actors_labels, actors_boxes, objects_features, objects_boxes, adj) 62 | 63 | for i, prop in enumerate(pred): 64 | classes = torch.nonzero(prop) 65 | for c in classes: 66 | if int(c) != 0: # do not consider background class 67 | csv_writer.writerow( 68 | [actors_filenames[i][0], str(int(actors_filenames[i][1])), 69 | str(actors_boxes[i, 0].item()), str(actors_boxes[i, 1].item()), 70 | str(actors_boxes[i, 2].item()), str(actors_boxes[i, 3].item()), 71 | int(c), 72 | prop[int(c)].item()]) 73 | csv_file.flush() 74 | 75 | csv_file.close() 76 | 77 | 78 | if __name__ == '__main__': 79 | main() 80 | 81 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from data.ava_dataset import AVADataset 2 | import torch 3 | from utils import checkpoints 4 | from tensorboardX import SummaryWriter 5 | import os 6 | import errno 7 | import argparse 8 | from data.collate_batch import BatchCollator 9 | from models.model import Stage 10 | 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--actors_dir", 
type=str, help="path to the directory containing actors features") 14 | parser.add_argument("--objects_file", type=str, help="path to the file containing objects features") 15 | parser.add_argument("--output_dir", type=str, help="path to the directory where checkpoints will be stored") 16 | parser.add_argument("--log_tensorboard_dir", type=str, help="path to the directory where tensorboard logs will be stored") 17 | parser.add_argument("--batch_size", type=int) 18 | parser.add_argument("--n_workers", type=int) 19 | parser.add_argument("--lr", type=float) 20 | parser.add_argument("--num_classes", type=int, default=81) 21 | parser.add_argument("--actors_features_size", type=int, default=1024) 22 | parser.add_argument("--objects_features_size", type=int, default=2048) 23 | parser.add_argument("--n_heads", type=int, default=4, help="only 2 or 4 heads supported at the moment") 24 | parser.add_argument("--impose_lr", type=float, default=0.0, help="if not zero, will impose the specified learning rate to all the optimizer's parameters") 25 | parser.add_argument("--n_epochs", type=int, default=50) 26 | 27 | 28 | def main(): 29 | args = parser.parse_args() 30 | 31 | num_classes = args.num_classes 32 | start_epoch = 1 33 | start_iter = 1 34 | output_dir = args.output_dir 35 | tensorboard_dir = args.log_tensorboard_dir 36 | 37 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 38 | 39 | try: 40 | os.makedirs(output_dir) 41 | except OSError as e: 42 | if e.errno != errno.EEXIST: 43 | raise 44 | 45 | ava_train = AVADataset(split='train', videodir=args.actors_dir, objectsfile=args.objects_file) 46 | data_loader_train = torch.utils.data.DataLoader(ava_train, batch_size=args.batch_size, num_workers=args.n_workers, collate_fn=BatchCollator(), shuffle=False) 47 | 48 | writer = SummaryWriter(tensorboard_dir) 49 | 50 | model = Stage(num_classes, args.actors_features_size, args.objects_features_size, args.n_heads) 51 | 52 | model.train() 53 | model.to(device) 54 | 55 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 56 | 57 | checkpoint = checkpoints.load(output_dir) 58 | if checkpoint: 59 | model.load_state_dict(checkpoint["model_state"], strict=False) 60 | optimizer.load_state_dict(checkpoint["optimizer_state"]) 61 | if args.impose_lr != 0.0: 62 | for param_group in optimizer.param_groups: 63 | param_group['lr'] = args.impose_lr 64 | start_epoch = checkpoint["epoch"] 65 | start_iter = checkpoint["iteration"] + 1 66 | else: 67 | print("No checkpoint found: initializing model from scratch") 68 | 69 | for current_epoch in range(start_epoch, args.n_epochs + 1): 70 | 71 | data_loader_train.dataset.shuffle_filename_blocks(args.batch_size, current_epoch) 72 | 73 | for iteration, (actors_features, actors_labels, actors_boxes, actors_filenames, objects_features, objects_boxes, objects_filenames, adj) in enumerate(data_loader_train, start_iter): 74 | actors_features = actors_features.to(device) 75 | actors_labels = actors_labels.to(device) 76 | actors_boxes = actors_boxes.to(device) 77 | objects_features = objects_features.to(device) 78 | objects_boxes = objects_boxes.to(device) 79 | adj = adj.to(device) 80 | 81 | pred, loss = model(actors_features, actors_labels, actors_boxes, objects_features, objects_boxes, adj) 82 | 83 | loss = torch.mean(loss) 84 | 85 | optimizer.zero_grad() 86 | loss.backward() 87 | optimizer.step() 88 | 89 | if iteration % 20 == 0: 90 | writer.add_scalar('train/class_loss', loss, (current_epoch-1) * len(data_loader_train) + iteration) 91 | print("epoch: " + 
str(current_epoch) + ", iter: " + str(iteration) + "/" + str(len(data_loader_train)) + ", lr: " + str(optimizer.param_groups[0]["lr"]) + ", class_loss: " + str(loss.item())) 92 | 93 | if iteration >= len(data_loader_train): 94 | print("Epoch " + str(current_epoch) + "ended.") 95 | start_iter = 1 96 | models = {"model_state": model} 97 | checkpoints.save("model_ep_{:03}_iter_{:07d}".format(current_epoch, iteration), models, optimizer, output_dir, current_epoch, iteration) 98 | 99 | break 100 | 101 | checkpoints.save("final_model", models, optimizer, output_dir, current_epoch, iteration) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /utils/boxlist_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def area(boxes): 5 | area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 6 | return area 7 | 8 | def boxlist_iou(boxlist1, boxlist2): 9 | """Compute the intersection over union of two set of boxes. 10 | The box order must be (xmin, ymin, xmax, ymax). 11 | Arguments: 12 | box1: (BoxList) bounding boxes, sized [N,4]. 13 | box2: (BoxList) bounding boxes, sized [M,4]. 14 | Returns: 15 | (tensor) iou, sized [N,M]. 16 | """ 17 | 18 | # N = boxlist1.shape[0] 19 | # M = boxlist2.shape[1] 20 | 21 | area1 = area(boxlist1) 22 | area2 = area(boxlist2) 23 | 24 | lt = torch.max(boxlist1[:, None, :2], boxlist2[:, :2]) # [N,M,2] 25 | rb = torch.min(boxlist1[:, None, 2:], boxlist2[:, 2:]) # [N,M,2] 26 | 27 | wh = (rb - lt ).clamp(min=0) # [N,M,2] 28 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 29 | 30 | iou = inter / (area1[:, None] + area2 - inter) 31 | return iou 32 | 33 | def boxlist_distance(boxlist1, boxlist2, tau=1): 34 | """Compute the Euclidean distance between centers of two set of boxes. 35 | The box order must be (xmin, ymin, xmax, ymax). 36 | Arguments: 37 | box1: (BoxList) bounding boxes, sized [N,4]. 38 | box2: (BoxList) bounding boxes, sized [M,4]. 39 | Returns: 40 | (tensor) distance, sized [N,M]. 
41 | """ 42 | center1 = torch.cat((((boxlist1[:, None, 2] - boxlist1[:, None, 0]) / 2) + boxlist1[:, None,0], ((boxlist1[:, None, 3] - boxlist1[:, None, 1]) / 2) + boxlist1[:, None,1]), dim=1) 43 | center2 = torch.cat((((boxlist2[:, None, 2] - boxlist2[:, None, 0]) / 2) + boxlist2[:, None,0], ((boxlist2[:, None, 3] - boxlist2[:, None, 1]) / 2) + boxlist2[:, None,1]), dim=1) 44 | 45 | center1 = center1.unsqueeze(1) 46 | center2 = center2.unsqueeze(0) 47 | 48 | d = torch.sqrt((center1[:, :, 0] - center2[:, :, 0]) ** 2 + (center1[:, :, 1] - center2[:, :, 1]) ** 2) 49 | 50 | return torch.exp(-1*tau*d) -------------------------------------------------------------------------------- /utils/checkpoints.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | 5 | def tag_last_checkpoint(output_dir, last_filename): 6 | save_file = os.path.join(output_dir, "last_checkpoint") 7 | with open(save_file, "w") as f: 8 | f.write(last_filename) 9 | 10 | 11 | def get_checkpoint_file(output_dir): 12 | save_file = os.path.join(output_dir, "last_checkpoint") 13 | try: 14 | with open(save_file, "r") as f: 15 | last_saved = f.read() 16 | last_saved = last_saved.strip() 17 | except IOError: 18 | last_saved = "" 19 | return last_saved 20 | 21 | 22 | def save(name, model, optimizer, output_dir, epoch, iteration): 23 | 24 | optimizer_state = optimizer.state_dict() 25 | 26 | data = {"optimizer_state": optimizer_state, 27 | "epoch": epoch, "iteration": iteration} 28 | for model_name, model_obj in model.items(): 29 | data[model_name] = model_obj.state_dict() 30 | 31 | save_file = os.path.join(output_dir, "{}.pth".format(name)) 32 | 33 | print("Saving checkpoint to {}".format(save_file)) 34 | torch.save(data, save_file) 35 | tag_last_checkpoint(output_dir, save_file) 36 | 37 | 38 | def load(output_dir): 39 | if os.path.exists(os.path.join(output_dir, "last_checkpoint")): 40 | f = get_checkpoint_file(output_dir) 41 | else: 42 | return {} 43 | print("Loading checkpoint from {}".format(f)) 44 | checkpoint = torch.load(f, map_location=torch.device("cpu")) 45 | data = {} 46 | data["optimizer_state"] = checkpoint.pop("optimizer_state") 47 | data["epoch"] = checkpoint.pop("epoch") 48 | data["iteration"] = checkpoint.pop("iteration") 49 | data["model_state"] = checkpoint.pop("model_state") 50 | return data --------------------------------------------------------------------------------
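For reference, the following self-contained snippet (not part of the repository, and assuming it is run from the repository root like `train.py`) sketches the tensor shapes that `Stage.forward` expects; the actor/object counts, boxes, labels and the dense adjacency matrix are dummy values chosen only to illustrate the interface, while the feature sizes and the number of classes follow the argparse defaults above.

```
import torch
from models.model import Stage

# Dummy sizes: 3 actors and 5 objects in a single clip; feature sizes and the
# number of classes follow the defaults of train.py / test.py.
num_actors, num_objects = 3, 5
actors_features = torch.randn(num_actors, 1024, 4, 7, 7)   # (num_actors, feature_size, t, h, w)
actors_labels = torch.zeros(num_actors, 81)                # multi-label targets in {0, 1}
actors_labels[:, 12] = 1.0
actors_boxes = torch.tensor([[0.1, 0.1, 0.4, 0.9],
                             [0.3, 0.2, 0.6, 0.8],
                             [0.5, 0.1, 0.9, 0.7]])        # (xmin, ymin, xmax, ymax)
objects_features = torch.randn(num_objects, 2048)          # (num_objects, feature_size)
objects_boxes = torch.rand(num_objects, 2).repeat(1, 2) + torch.tensor([0.0, 0.0, 0.1, 0.1])
adj = torch.rand(num_actors + num_objects, num_actors + num_objects)  # dense stand-in for the collated adjacency

model = Stage(num_classes=81, actors_features_size=1024, objects_features_size=2048, n_heads=4)
model.eval()

with torch.no_grad():
    pred, loss = model(actors_features, actors_labels, actors_boxes,
                       objects_features, objects_boxes, adj)

print(pred.shape)   # torch.Size([3, 81]) -- per-actor class probabilities
print(loss.item())  # scalar BCE loss against the dummy labels
```

In the real pipeline these tensors are produced by `AVADataset` and `BatchCollator`, and the adjacency entries are the `exp(-tau * d)` center-distance weights computed by `boxlist_distance` in `utils/boxlist_ops.py`.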