├── README.md
├── data
│   ├── ava_dataset.py
│   └── collate_batch.py
├── images
│   ├── graph.jpg
│   └── graph2.PNG
├── models
│   ├── attention_layer.py
│   ├── gat.py
│   └── model.py
├── test.py
├── train.py
└── utils
    ├── boxlist_ops.py
    └── checkpoints.py
/README.md:
--------------------------------------------------------------------------------
1 | # STAGE: Spatio-Temporal Attention on Graph Entities
2 | This repository contains the training and test code for the paper _[STAGE: Spatio-Temporal Attention on Graph Entities for Video Action Detection](https://arxiv.org/abs/1912.04316)_.
3 |
4 |
5 |
6 |
7 |
8 | ## Requirements
9 | The required Python packages are:
10 | * torch>=1.0.0
11 | * h5py>=2.8.0
12 | * tensorboardX>=1.6
13 |
14 | ## Features
15 | In order to train and test the module, you need pre-computed actor and object features extracted by a backbone pre-trained on the [AVA dataset](https://research.google.com/ava/). Features must be organized in h5py files as follows:
16 |
17 | **Actors features**
18 |
19 |     actors_features_dir
20 |     |
21 |     |-> <clipID>
22 |     |    |-> <timestamp>.h5
23 |     |    |-> <timestamp>.h5
24 |     |    |-> ...
25 |     |
26 |     |-> <clipID>
27 |     |    |-> <timestamp>.h5
28 |     |    |-> <timestamp>.h5
29 |     |    |-> ...
30 |     |
31 |     |-> ...
32 | Each .h5 file should contain the following data (a short read-back sketch follows the list):
33 | * "features" -> a torch tensor with shape (num_actors, feature_size, t, h, w) containing the actor features
34 | * "boxes" -> a torch tensor with shape (num_actors, 4) containing the bounding box coordinates of each actor
35 | * "labels" -> a torch tensor with shape (num_actors, 81) containing ones and zeros for performed/not performed actions
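
For reference, a minimal sketch of reading one of these files back with h5py (the path is only an illustrative example; adjust it to your layout):

```
import h5py
import torch

# hypothetical path: actors_features_dir/<clipID>/<timestamp>.h5
with h5py.File("actors_features_dir/-5KQ66BBWC4/0902.h5", "r") as hf:
    features = torch.from_numpy(hf["features"][()])  # (num_actors, feature_size, t, h, w)
    boxes = torch.from_numpy(hf["boxes"][()])         # (num_actors, 4)
    labels = torch.from_numpy(hf["labels"][()])       # (num_actors, 81)
```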
36 |
37 | **Objects features**
38 |
39 | objects_features.h5
40 |
41 | The file should contain the following data:
42 | * `<clipID>_<timestamp>_features` -> a torch tensor with shape (num_objects, feature_size) containing the object features
43 | * `<clipID>_<timestamp>_boxes` -> a torch tensor with shape (num_objects, 4) containing the bounding box coordinates of each object
44 | * `<clipID>_<timestamp>_cls_prob` -> a torch tensor with shape (num_objects, num_classes) containing the class probabilities of each object
45 |
46 | For example, the object features of clip '-5KQ66BBWC4' at timestamp '902' will be stored under
47 | 
48 |     objects_features["-5KQ66BBWC4_902_features"]
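
Similarly, a minimal sketch of reading the object entries for one clip and timestamp (the file name is only an example):

```
import h5py
import torch

with h5py.File("objects_features.h5", "r") as hf:
    key = "-5KQ66BBWC4_902"  # "<clipID>_<timestamp>"
    obj_features = torch.from_numpy(hf[key + "_features"][()])   # (num_objects, feature_size)
    obj_boxes = torch.from_numpy(hf[key + "_boxes"][()])         # (num_objects, 4)
    obj_cls_prob = torch.from_numpy(hf[key + "_cls_prob"][()])   # (num_objects, num_classes)
```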
49 |
50 | I3D actors features are available at the following links:
51 | * [[I3D_actors_train]](https://drive.google.com/open?id=1RlciPLrEQcY0uYecS_cEWydrpvWg9DZv)
52 | * [[I3D_actors_val]](https://drive.google.com/open?id=1HCjezdcr2BkVUIEJgzBKPYSYLA0a9vxw)
53 |
54 | Each tar.gz contains a directory, which corresponds to the "actors_features_dir" root.
55 |
56 | Faster-RCNN objects features are available at the following links:
57 | * [[Faster-RCNN_objects_train]](https://drive.google.com/file/d/13PrXvAR-Rw9MaTAJA5hInpJG4V_FLNuB/view?usp=sharing)
58 | * [[Faster-RCNN_objects_val]](https://drive.google.com/open?id=17_9NkM0kB_j0YEersD6y5WRPcKL6fiLp)
59 |
60 | Each tar.gz contains an h5py file, which corresponds to the "objects_features.h5" file.
61 |
62 | The size of all the features is ~90 GB.
63 |
64 | **SlowFast features**
65 |
66 | You can find SlowFast features at the following links:
67 | * [[slowfast_ava2.1_32x2_features_train]](https://drive.google.com/file/d/1DW0b3Cc4d64P5Ir40cxpquGwYXkA_g-P/view?usp=sharing)
68 | * [[slowfast_ava2.1_32x2_features_val]](https://drive.google.com/file/d/1GbCLQ5jK8tk5FBj_DCKouQwyEApkkUfk/view?usp=sharing)
69 | * [[slowfast_ava2.2_32x2_features_train]](https://drive.google.com/file/d/1dV1F1wYDBl4M8BRp8_uKlGG4Vszzv9dH/view?usp=sharing)
70 | * [[slowfast_ava2.2_32x2_features_val]](https://drive.google.com/file/d/1C9m0DvE0rEyrmG36RggUCLlZOd2AqbVO/view?usp=sharing)
71 | * [[slowfast_ava2.2_64x2_features_train]](https://drive.google.com/file/d/1D-W7mJsWAt843GA0IwLaAD7qOpxIQESC/view?usp=sharing)
72 | * [[slowfast_ava2.2_64x2_features_val]](https://drive.google.com/file/d/1_Wd89_kQYtwL5IBmj7skJB0uVEUaW5V8/view?usp=sharing)
73 |
74 | **Note**: these features are organized differently from the I3D ones (you will need to write a specific dataloader or modify the provided one):
75 | each h5py file is a dictionary whose keys have the format `<clipID>_<timestamp>`. For each key, the corresponding value is another dictionary with keys "boxes", "features" and "labels", containing the actors' box coordinates, the features extracted from the last SlowFast layer before classification, and the ground-truth labels for that specific clip.
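
A hedged sketch of iterating over this layout, assuming the nested dictionaries map to HDF5 groups and datasets (the file name below is only an example):

```
import h5py
import torch

with h5py.File("slowfast_ava2.2_32x2_features_train.h5", "r") as hf:
    for key in hf:                     # keys look like "<clipID>_<timestamp>"
        clip = hf[key]                 # a group holding "boxes", "features", "labels"
        boxes = torch.from_numpy(clip["boxes"][()])
        features = torch.from_numpy(clip["features"][()])
        labels = torch.from_numpy(clip["labels"][()])
        # ...build a custom Dataset / collate function from here
```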
76 |
77 |
78 | ## Training
79 |
80 | Run `python train.py` using the following arguments:
81 |
82 | | Argument | Description |
83 | |------|------|
84 | | `--actors_dir` | Path to the train actors_features_dir |
85 | | `--objects_file` | Path to the train objects_features.h5 file |
86 | | `--output_dir` | Path to the directory where checkpoints will be stored |
87 | | `--log_tensorboard_dir` | Path to the directory where tensorboard logs will be stored |
88 | | `--batch_size` | The batch size. Must be > 1 to allow temporal connections |
89 | | `--n_workers` | The number of workers |
90 | | `--lr` | The learning rate |
91 |
92 | For example, use:
93 | ```
94 | python train.py --actors_dir "./actors_features_dir" --objects_file "./objects_features.h5" --output_dir "./out_checkpoints" --log_tensorboard_dir "./out_tensorboard" --batch_size 6 --n_workers 8 --lr 0.0000625
95 | ```
96 |
97 | ## Testing
98 |
99 | Run `python test.py` using the following arguments:
100 |
101 | | Argument | Description |
102 | |------|------|
103 | | `--actors_dir` | Path to the val actors_features_dir |
104 | | `--objects_file` | Path to the val objects_features.h5 file |
105 | | `--output_dir` | Path to the directory where the checkpoint to load is stored |
106 | | `--batch_size` | The batch size. Must be > 1 to allow temporal connections |
107 | | `--n_workers` | The number of workers |
108 |
109 | A `results.csv` file will be created in the `output_dir` directory; it should be used for evaluation as explained [here](https://research.google.com/ava/download.html). The row format is shown below the example command.
110 |
111 | For example, use:
112 | ```
113 | python test.py --actors_dir "./actors_features_dir" --objects_file "./objects_features.h5" --output_dir "./out_checkpoints" --batch_size 6 --n_workers 8
114 | ```
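
For reference, each row of `results.csv` written by `test.py` has the following columns, with one action prediction per row:

```
clipID, timestamp, x1, y1, x2, y2, action_id, score
```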
115 |
116 |
--------------------------------------------------------------------------------
/data/ava_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data_utl
3 | import h5py
4 | import os
5 | import random
6 |
7 |
8 | def sort_function(filename):
9 | return (filename.split('/')[-2], int(filename.split('/')[-1].split('.')[-2]))
10 |
11 | class AVADataset(data_utl.Dataset):
12 |
13 | def __init__(self, split='train', videodir='./train_features_I3D', objectsfile='./ava_objects_fasterrcnn.hdf5'):
14 | self.split = split
15 | self.objectsfile = objectsfile
16 |
17 | self.filenames = []
18 | for dirname in os.listdir(videodir):
19 | for filename in os.listdir(os.path.join(videodir, dirname)):
20 | self.filenames.append(os.path.join(videodir, dirname, filename))
21 |
22 | if self.split == "val":
23 | self.filenames.sort(key=sort_function)
24 |
25 | def __getitem__(self, index):
26 | filename = self.filenames[index]
27 | clip_id = filename.split('/')[-2]
28 | timestamp = filename.split('/')[-1].split('.')[0]
29 |
30 |         hf_actors = h5py.File(filename, 'r')
31 |         actors_features = torch.from_numpy(hf_actors['features'][()])  # [()] reads the full dataset (the deprecated .value was removed in h5py 3.x)
32 |         actors_labels = torch.from_numpy(hf_actors['labels'][()])
33 |         actors_boxes = torch.from_numpy(hf_actors['boxes'][()])
34 | 
35 |         hf_objects = h5py.File(self.objectsfile, 'r')
36 |         objects_features = torch.from_numpy(hf_objects[clip_id + '_' + timestamp.lstrip("0") + '_' + 'features'][()])
37 |         objects_boxes = torch.from_numpy(hf_objects[clip_id + '_' + timestamp.lstrip("0") + '_' + 'boxes'][()])
38 |
39 | return actors_features, actors_labels, actors_boxes, [(clip_id, timestamp) for _ in range(actors_features.shape[0])], objects_features, objects_boxes, [(clip_id, timestamp) for _ in range(objects_features.shape[0])]
40 |
41 | def __len__(self):
42 | return len(self.filenames)
43 |
44 | def rotate(self, n):
45 | self.filenames = self.filenames[n:] + self.filenames[:n]
46 |
47 | def shuffle_filename_blocks(self, block_size, epoch):
48 | self.filenames.sort(key=sort_function)
49 | self.rotate(int(block_size/4)*(epoch-1))
50 | self.filenames = [self.filenames[i:i+block_size] for i in range(0,len(self.filenames),block_size)]
51 | random.shuffle(self.filenames)
52 | self.filenames[:] = [b for bs in self.filenames for b in bs]
53 |
--------------------------------------------------------------------------------
/data/collate_batch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils import boxlist_ops
3 |
4 |
5 | class BatchCollator(object):
6 |
7 | def __call__(self, batch):
8 | transposed_batch = list(zip(*batch))
9 | actors_features = torch.cat(transposed_batch[0], dim=0)
10 | actors_labels = torch.cat(transposed_batch[1], dim=0)
11 | actors_boxes = torch.cat(transposed_batch[2], dim=0)
12 | actors_filenames = sum(transposed_batch[3], [])
13 |
14 | objects_features = torch.cat(transposed_batch[4], dim=0)
15 | objects_boxes = torch.cat(transposed_batch[5], dim=0)
16 | objects_filenames = sum(transposed_batch[6], [])
17 |
18 | num_actor_proposals = actors_boxes.shape[0]
19 | num_object_proposals = objects_boxes.shape[0]
20 |
21 | adj = torch.zeros((num_actor_proposals + num_object_proposals, num_actor_proposals + num_object_proposals))
22 |
23 | cur_actors = 0
24 | cur_objects = 0
25 |
26 |         tau = 1  # decay factor for the exponential distance weighting in boxlist_ops.boxlist_distance
27 | # populate the adj matrix in the actor-actor and actor-object sections
28 | for i in range(len(transposed_batch[0])):
29 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0],cur_actors:cur_actors + transposed_batch[0][i].shape[0]] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[2][i], tau=tau)
30 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0], num_actor_proposals + cur_objects:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0]] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[5][i], tau=tau)
31 | if i==0:
32 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0], cur_actors + transposed_batch[0][i].shape[0]:cur_actors + transposed_batch[0][i].shape[0] + transposed_batch[0][i+1].shape[0]] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[2][i+1], tau=tau)
33 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0], num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0]:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0] + transposed_batch[4][i+1].shape[0]] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[5][i+1], tau=tau)
34 | elif i == len(transposed_batch[3]) - 1:
35 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0], cur_actors - transposed_batch[0][i-1].shape[0]:cur_actors] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[2][i-1], tau=tau)
36 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0], num_actor_proposals + cur_objects - transposed_batch[4][i-1].shape[0]:num_actor_proposals + cur_objects] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[5][i-1], tau=tau)
37 | else:
38 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0], cur_actors + transposed_batch[0][i].shape[0]:cur_actors + transposed_batch[0][i].shape[0] + transposed_batch[0][i + 1].shape[0]] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[2][i + 1], tau=tau)
39 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0], cur_actors - transposed_batch[0][i - 1].shape[0]:cur_actors] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[2][i - 1], tau=tau)
40 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0], num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0]:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0] + transposed_batch[4][i + 1].shape[0]] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[5][i + 1], tau=tau)
41 | adj[cur_actors:cur_actors + transposed_batch[0][i].shape[0], num_actor_proposals + cur_objects - transposed_batch[4][i - 1].shape[0]:num_actor_proposals + cur_objects] = boxlist_ops.boxlist_distance(transposed_batch[2][i], transposed_batch[5][i - 1], tau=tau)
42 |
43 | cur_actors += transposed_batch[0][i].shape[0]
44 | cur_objects += transposed_batch[4][i].shape[0]
45 |
46 | # populate the adj matrix in the object-actor section
47 | adj[num_actor_proposals:, :num_actor_proposals] = torch.t(adj[:num_actor_proposals, num_actor_proposals:])
48 | cur_objects = 0
49 |
50 | # populate the adj matrix in the object-object section
51 | for i in range(len(transposed_batch[4])):
52 | adj[num_actor_proposals + cur_objects:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0], num_actor_proposals + cur_objects:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0]] = boxlist_ops.boxlist_distance(transposed_batch[5][i], transposed_batch[5][i], tau=tau)
53 | if i==0:
54 | adj[num_actor_proposals + cur_objects:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0], num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0]:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0] + transposed_batch[4][i+1].shape[0]] = boxlist_ops.boxlist_distance(transposed_batch[5][i], transposed_batch[5][i+1], tau=tau)
55 | elif i == len(transposed_batch[3]) - 1:
56 | adj[num_actor_proposals + cur_objects:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0], num_actor_proposals + cur_objects - transposed_batch[4][i-1].shape[0]:num_actor_proposals+cur_objects] = boxlist_ops.boxlist_distance(transposed_batch[5][i], transposed_batch[5][i-1], tau=tau)
57 | else:
58 | adj[num_actor_proposals + cur_objects:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0], num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0]:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0] + transposed_batch[4][i + 1].shape[0]] = boxlist_ops.boxlist_distance(transposed_batch[5][i], transposed_batch[5][i + 1], tau=tau)
59 | adj[num_actor_proposals + cur_objects:num_actor_proposals + cur_objects + transposed_batch[4][i].shape[0], num_actor_proposals + cur_objects - transposed_batch[4][i - 1].shape[0]:num_actor_proposals + cur_objects] = boxlist_ops.boxlist_distance(transposed_batch[5][i], transposed_batch[5][i - 1], tau=tau)
60 |
61 | cur_objects += transposed_batch[4][i].shape[0]
62 |
63 | return actors_features, actors_labels, actors_boxes, actors_filenames, objects_features, objects_boxes, objects_filenames, adj
--------------------------------------------------------------------------------
/images/graph.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimagelab/STAGE_action_detection/a5d76114f47c103deef79ac6056b1794961fb58c/images/graph.jpg
--------------------------------------------------------------------------------
/images/graph2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimagelab/STAGE_action_detection/a5d76114f47c103deef79ac6056b1794961fb58c/images/graph2.PNG
--------------------------------------------------------------------------------
/models/attention_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class GraphAttentionLayer(nn.Module):
7 | def __init__(self, in_features, out_features, dropout, alpha):
8 | super(GraphAttentionLayer, self).__init__()
9 | self.dropout = dropout
10 | self.in_features = in_features
11 | self.out_features = out_features
12 | self.alpha = alpha
13 |
14 | self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
15 | nn.init.xavier_uniform_(self.W.data, gain=1.414)
16 | self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1)))
17 | nn.init.xavier_uniform_(self.a.data, gain=1.414)
18 |
19 | self.leakyrelu = nn.LeakyReLU(self.alpha)
20 |
21 | def forward(self, input, adj):
22 | h = torch.mm(input, self.W)
23 | N = h.size()[0]
24 |
25 | a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)
26 | e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))
27 |
28 | zero_vec = -9e15 * torch.ones_like(e)
29 | attention = torch.where(adj > 0, e*adj, zero_vec)
30 | tau = 1
31 | attention = F.softmax(attention / tau, dim=1)
32 | h_prime = torch.matmul(attention, h)
33 | h_prime = F.dropout(h_prime, self.dropout, training=self.training)
34 |
35 | return F.elu(h_prime)
36 |
37 |
--------------------------------------------------------------------------------
/models/gat.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models.attention_layer import GraphAttentionLayer
5 |
6 |
7 | class GAT(nn.Module):
8 | def __init__(self, nfeat, nhid, dropout, alpha, nheads):
9 | super(GAT, self).__init__()
10 | self.dropout = dropout
11 |
12 | self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha) for _ in range(nheads)]
13 | for i, attention in enumerate(self.attentions):
14 | self.add_module('attention_{}'.format(i), attention)
15 |
16 | def forward(self, x, adj):
17 | x = F.dropout(x, self.dropout, training=self.training)
18 | x = torch.cat([v(x, adj) for k, v in self._modules.items() if k.startswith("attention")], dim=1)
19 | return x
20 |
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models.gat import GAT
3 | from torch import nn
4 |
5 |
6 | class Stage(torch.nn.Module):
7 |
8 | def __init__(self, num_classes, actors_features_size, objects_features_size, n_heads):
9 | super(Stage, self).__init__()
10 |
11 | self.num_classes = num_classes
12 | self.actors_features_size = actors_features_size
13 | self.objects_features_size = objects_features_size
14 | self.n_heads = n_heads
15 |
16 | if self.objects_features_size > self.actors_features_size:
17 | self.obj_reducer = nn.Linear(in_features=self.objects_features_size, out_features=self.actors_features_size)
18 |
19 | self.gat1 = GAT(self.actors_features_size + 4, int(self.actors_features_size/self.n_heads) + int(4/self.n_heads), 0.5, 0.2, self.n_heads) #we add 4 because we are going to add h,w,xc,yc to the channel axis
20 | self.gat_fc = nn.Linear(in_features=self.actors_features_size +4, out_features=self.actors_features_size +4)
21 | self.l_norm = nn.LayerNorm(self.actors_features_size +4)
22 |
23 | self.gat2 = GAT(self.actors_features_size + 4, int(self.actors_features_size/self.n_heads) + int(4/self.n_heads), 0.5, 0.2, self.n_heads)
24 | self.gat_fc2 = nn.Linear(in_features=self.actors_features_size +4, out_features=self.actors_features_size +4)
25 | self.l_norm2 = nn.LayerNorm(self.actors_features_size +4)
26 |
27 | self.logits = nn.Linear(in_features=self.actors_features_size +4, out_features=self.num_classes)
28 |
29 |
30 | def forward(self, actors_features, actors_labels, actors_boxes, objects_features, objects_boxes, adj):
31 | #compute h, w, xc, yc for each actor/object
32 | actors_h = actors_boxes[:, 3] - actors_boxes[:, 1]
33 | objects_h = objects_boxes[:, 3] - objects_boxes[:, 1]
34 | actors_w = actors_boxes[:, 2] - actors_boxes[:, 0]
35 | objects_w = objects_boxes[:, 2] - objects_boxes[:, 0]
36 | actors_centers_x = ((actors_boxes[:, 2] - actors_boxes[:, 0]) / 2) + actors_boxes[:, 0]
37 | actors_centers_y = ((actors_boxes[:, 3] - actors_boxes[:, 1]) / 2) + actors_boxes[:, 1]
38 | objects_centers_x = ((objects_boxes[:, 2] - objects_boxes[:, 0]) / 2) + objects_boxes[:, 0]
39 | objects_centers_y = ((objects_boxes[:, 3] - objects_boxes[:, 1]) / 2) + objects_boxes[:, 1]
40 |
41 | with torch.no_grad():
42 | actors_features = torch.mean(torch.mean(torch.mean(actors_features, dim=2), dim=-1), dim=-1)
43 | actors_features = torch.cat((actors_features, actors_h.unsqueeze(1), actors_w.unsqueeze(1), actors_centers_x.unsqueeze(1), actors_centers_y.unsqueeze(1)), dim=1)
44 |
45 | if self.objects_features_size > self.actors_features_size:
46 | objects_features = nn.functional.relu(self.obj_reducer(objects_features))
47 |
48 | objects_features = torch.cat((objects_features, objects_h.unsqueeze(1), objects_w.unsqueeze(1), objects_centers_x.unsqueeze(1), objects_centers_y.unsqueeze(1)), dim=1)
49 |
50 | all_features = torch.cat((actors_features, objects_features), dim=0)
51 |
52 | gat_pred = self.gat1(all_features, adj)
53 | all_features = all_features + self.gat_fc(gat_pred)
54 | all_features = self.l_norm(all_features)
55 |
56 | gat_pred = self.gat2(all_features, adj)
57 | all_features = all_features + self.gat_fc2(gat_pred)
58 | all_features = self.l_norm2(all_features)
59 |
60 | pred = self.logits(all_features[:actors_features.shape[0], :])
61 |
62 | loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, actors_labels)
63 |
64 | pred = torch.sigmoid(pred)
65 | return pred, loss
66 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from data.ava_dataset import AVADataset
2 | import torch
3 | from utils import checkpoints
4 | import os
5 | import argparse
6 | from data.collate_batch import BatchCollator
7 | from models.model import Stage
8 | import csv
9 | from tqdm import tqdm
10 | import sys
11 |
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--actors_dir", type=str, help="path to the directory containing actors features")
15 | parser.add_argument("--objects_file", type=str, help="path to the file containing objects features")
16 | parser.add_argument("--output_dir", type=str, help="path to the directory containing the checkpoint to load; results.csv will be written here")
17 | parser.add_argument("--batch_size", type=int)
18 | parser.add_argument("--n_workers", type=int)
19 | parser.add_argument("--num_classes", type=int, default=81)
20 | parser.add_argument("--actors_features_size", type=int, default=1024)
21 | parser.add_argument("--objects_features_size", type=int, default=2048)
22 | parser.add_argument("--n_heads", type=int, default=4, help="only 2 or 4 heads supported at the moment")
23 |
24 |
25 | def main():
26 | args = parser.parse_args()
27 |
28 | num_classes = args.num_classes
29 | output_dir = args.output_dir
30 |
31 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32 |
33 | assert os.path.isdir(output_dir)
34 |
35 |     ava_val = AVADataset(split='val', videodir=args.actors_dir, objectsfile=args.objects_file)
36 | data_loader_val = torch.utils.data.DataLoader(ava_val, batch_size=args.batch_size, num_workers=args.n_workers, collate_fn=BatchCollator(), shuffle=False)
37 |
38 | model = Stage(num_classes, args.actors_features_size, args.objects_features_size, args.n_heads)
39 |
40 | model.eval()
41 | model.to(device)
42 |
43 | checkpoint = checkpoints.load(output_dir)
44 | if checkpoint:
45 | model.load_state_dict(checkpoint["model_state"], strict=False)
46 | else:
47 | sys.exit('No checkpoint found!')
48 |
49 | with torch.no_grad():
50 | with open(os.path.join(output_dir, "results.csv"), mode='w') as csv_file:
51 | csv_writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
52 |
53 | for iteration_val, (actors_features, actors_labels, actors_boxes, actors_filenames, objects_features, objects_boxes, objects_filenames, adj) in enumerate(tqdm(data_loader_val), 1):
54 | actors_features = actors_features.to(device)
55 | actors_labels = actors_labels.to(device)
56 | actors_boxes = actors_boxes.to(device)
57 | objects_features = objects_features.to(device)
58 | objects_boxes = objects_boxes.to(device)
59 | adj = adj.to(device)
60 |
61 | pred, loss = model(actors_features, actors_labels, actors_boxes, objects_features, objects_boxes, adj)
62 |
63 | for i, prop in enumerate(pred):
64 | classes = torch.nonzero(prop)
65 | for c in classes:
66 | if int(c) != 0: # do not consider background class
67 | csv_writer.writerow(
68 | [actors_filenames[i][0], str(int(actors_filenames[i][1])),
69 | str(actors_boxes[i, 0].item()), str(actors_boxes[i, 1].item()),
70 | str(actors_boxes[i, 2].item()), str(actors_boxes[i, 3].item()),
71 | int(c),
72 | prop[int(c)].item()])
73 | csv_file.flush()
74 |
75 | csv_file.close()
76 |
77 |
78 | if __name__ == '__main__':
79 | main()
80 |
81 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from data.ava_dataset import AVADataset
2 | import torch
3 | from utils import checkpoints
4 | from tensorboardX import SummaryWriter
5 | import os
6 | import errno
7 | import argparse
8 | from data.collate_batch import BatchCollator
9 | from models.model import Stage
10 |
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--actors_dir", type=str, help="path to the directory containing actors features")
14 | parser.add_argument("--objects_file", type=str, help="path to the file containing objects features")
15 | parser.add_argument("--output_dir", type=str, help="path to the directory where checkpoints will be stored")
16 | parser.add_argument("--log_tensorboard_dir", type=str, help="path to the directory where tensorboard logs will be stored")
17 | parser.add_argument("--batch_size", type=int)
18 | parser.add_argument("--n_workers", type=int)
19 | parser.add_argument("--lr", type=float)
20 | parser.add_argument("--num_classes", type=int, default=81)
21 | parser.add_argument("--actors_features_size", type=int, default=1024)
22 | parser.add_argument("--objects_features_size", type=int, default=2048)
23 | parser.add_argument("--n_heads", type=int, default=4, help="only 2 or 4 heads supported at the moment")
24 | parser.add_argument("--impose_lr", type=float, default=0.0, help="if not zero, will impose the specified learning rate to all the optimizer's parameters")
25 | parser.add_argument("--n_epochs", type=int, default=50)
26 |
27 |
28 | def main():
29 | args = parser.parse_args()
30 |
31 | num_classes = args.num_classes
32 | start_epoch = 1
33 | start_iter = 1
34 | output_dir = args.output_dir
35 | tensorboard_dir = args.log_tensorboard_dir
36 |
37 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
38 |
39 | try:
40 | os.makedirs(output_dir)
41 | except OSError as e:
42 | if e.errno != errno.EEXIST:
43 | raise
44 |
45 | ava_train = AVADataset(split='train', videodir=args.actors_dir, objectsfile=args.objects_file)
46 | data_loader_train = torch.utils.data.DataLoader(ava_train, batch_size=args.batch_size, num_workers=args.n_workers, collate_fn=BatchCollator(), shuffle=False)
47 |
48 | writer = SummaryWriter(tensorboard_dir)
49 |
50 | model = Stage(num_classes, args.actors_features_size, args.objects_features_size, args.n_heads)
51 |
52 | model.train()
53 | model.to(device)
54 |
55 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
56 |
57 | checkpoint = checkpoints.load(output_dir)
58 | if checkpoint:
59 | model.load_state_dict(checkpoint["model_state"], strict=False)
60 | optimizer.load_state_dict(checkpoint["optimizer_state"])
61 | if args.impose_lr != 0.0:
62 | for param_group in optimizer.param_groups:
63 | param_group['lr'] = args.impose_lr
64 | start_epoch = checkpoint["epoch"]
65 | start_iter = checkpoint["iteration"] + 1
66 | else:
67 | print("No checkpoint found: initializing model from scratch")
68 |
69 | for current_epoch in range(start_epoch, args.n_epochs + 1):
70 |
71 | data_loader_train.dataset.shuffle_filename_blocks(args.batch_size, current_epoch)
72 |
73 | for iteration, (actors_features, actors_labels, actors_boxes, actors_filenames, objects_features, objects_boxes, objects_filenames, adj) in enumerate(data_loader_train, start_iter):
74 | actors_features = actors_features.to(device)
75 | actors_labels = actors_labels.to(device)
76 | actors_boxes = actors_boxes.to(device)
77 | objects_features = objects_features.to(device)
78 | objects_boxes = objects_boxes.to(device)
79 | adj = adj.to(device)
80 |
81 | pred, loss = model(actors_features, actors_labels, actors_boxes, objects_features, objects_boxes, adj)
82 |
83 | loss = torch.mean(loss)
84 |
85 | optimizer.zero_grad()
86 | loss.backward()
87 | optimizer.step()
88 |
89 | if iteration % 20 == 0:
90 | writer.add_scalar('train/class_loss', loss, (current_epoch-1) * len(data_loader_train) + iteration)
91 | print("epoch: " + str(current_epoch) + ", iter: " + str(iteration) + "/" + str(len(data_loader_train)) + ", lr: " + str(optimizer.param_groups[0]["lr"]) + ", class_loss: " + str(loss.item()))
92 |
93 | if iteration >= len(data_loader_train):
94 |                 print("Epoch " + str(current_epoch) + " ended.")
95 | start_iter = 1
96 | models = {"model_state": model}
97 | checkpoints.save("model_ep_{:03}_iter_{:07d}".format(current_epoch, iteration), models, optimizer, output_dir, current_epoch, iteration)
98 |
99 | break
100 |
101 | checkpoints.save("final_model", models, optimizer, output_dir, current_epoch, iteration)
102 |
103 |
104 | if __name__ == '__main__':
105 | main()
106 |
--------------------------------------------------------------------------------
/utils/boxlist_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def area(boxes):
5 | area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
6 | return area
7 |
8 | def boxlist_iou(boxlist1, boxlist2):
9 |     """Compute the intersection over union of two sets of boxes.
10 |     The box order must be (xmin, ymin, xmax, ymax).
11 |     Arguments:
12 |         boxlist1: (tensor) bounding boxes, sized [N,4].
13 |         boxlist2: (tensor) bounding boxes, sized [M,4].
14 |     Returns:
15 |         (tensor) iou, sized [N,M].
16 |     """
17 |
18 | # N = boxlist1.shape[0]
19 | # M = boxlist2.shape[1]
20 |
21 | area1 = area(boxlist1)
22 | area2 = area(boxlist2)
23 |
24 | lt = torch.max(boxlist1[:, None, :2], boxlist2[:, :2]) # [N,M,2]
25 | rb = torch.min(boxlist1[:, None, 2:], boxlist2[:, 2:]) # [N,M,2]
26 |
27 | wh = (rb - lt ).clamp(min=0) # [N,M,2]
28 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
29 |
30 | iou = inter / (area1[:, None] + area2 - inter)
31 | return iou
32 |
33 | def boxlist_distance(boxlist1, boxlist2, tau=1):
34 |     """Compute exp(-tau * d), where d is the Euclidean distance between the centers of two sets of boxes.
35 |     The box order must be (xmin, ymin, xmax, ymax).
36 |     Arguments:
37 |         boxlist1: (tensor) bounding boxes, sized [N,4].
38 |         boxlist2: (tensor) bounding boxes, sized [M,4].
39 |     Returns:
40 |         (tensor) exp(-tau * distance), sized [N,M].
41 |     """
42 | center1 = torch.cat((((boxlist1[:, None, 2] - boxlist1[:, None, 0]) / 2) + boxlist1[:, None,0], ((boxlist1[:, None, 3] - boxlist1[:, None, 1]) / 2) + boxlist1[:, None,1]), dim=1)
43 | center2 = torch.cat((((boxlist2[:, None, 2] - boxlist2[:, None, 0]) / 2) + boxlist2[:, None,0], ((boxlist2[:, None, 3] - boxlist2[:, None, 1]) / 2) + boxlist2[:, None,1]), dim=1)
44 |
45 | center1 = center1.unsqueeze(1)
46 | center2 = center2.unsqueeze(0)
47 |
48 | d = torch.sqrt((center1[:, :, 0] - center2[:, :, 0]) ** 2 + (center1[:, :, 1] - center2[:, :, 1]) ** 2)
49 |
50 | return torch.exp(-1*tau*d)
--------------------------------------------------------------------------------
/utils/checkpoints.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 |
4 |
5 | def tag_last_checkpoint(output_dir, last_filename):
6 | save_file = os.path.join(output_dir, "last_checkpoint")
7 | with open(save_file, "w") as f:
8 | f.write(last_filename)
9 |
10 |
11 | def get_checkpoint_file(output_dir):
12 | save_file = os.path.join(output_dir, "last_checkpoint")
13 | try:
14 | with open(save_file, "r") as f:
15 | last_saved = f.read()
16 | last_saved = last_saved.strip()
17 | except IOError:
18 | last_saved = ""
19 | return last_saved
20 |
21 |
22 | def save(name, model, optimizer, output_dir, epoch, iteration):
23 |
24 | optimizer_state = optimizer.state_dict()
25 |
26 | data = {"optimizer_state": optimizer_state,
27 | "epoch": epoch, "iteration": iteration}
28 | for model_name, model_obj in model.items():
29 | data[model_name] = model_obj.state_dict()
30 |
31 | save_file = os.path.join(output_dir, "{}.pth".format(name))
32 |
33 | print("Saving checkpoint to {}".format(save_file))
34 | torch.save(data, save_file)
35 | tag_last_checkpoint(output_dir, save_file)
36 |
37 |
38 | def load(output_dir):
39 | if os.path.exists(os.path.join(output_dir, "last_checkpoint")):
40 | f = get_checkpoint_file(output_dir)
41 | else:
42 | return {}
43 | print("Loading checkpoint from {}".format(f))
44 | checkpoint = torch.load(f, map_location=torch.device("cpu"))
45 | data = {}
46 | data["optimizer_state"] = checkpoint.pop("optimizer_state")
47 | data["epoch"] = checkpoint.pop("epoch")
48 | data["iteration"] = checkpoint.pop("iteration")
49 | data["model_state"] = checkpoint.pop("model_state")
50 | return data
--------------------------------------------------------------------------------