├── LICENSE ├── generate_graph.py ├── models_gnn.py ├── README.md ├── train_val.py └── data_loader.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Kyle Min 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /generate_graph.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from data_loader import AVADataset 4 | 5 | parser = argparse.ArgumentParser(description='generate_graph') 6 | parser.add_argument('--feature', type=str, default='resnet18-tsm-aug', help='name of the features') 7 | parser.add_argument('--numv', type=int, default=2000, help='number of nodes') 8 | parser.add_argument('--time_edge', type=float, default=0.9, help='time threshold') 9 | parser.add_argument('--cross_identity', type=str, default='cin', help='whether to allow cross-identity edges') 10 | parser.add_argument('--edge_weight', type=str, default='fsimy', help='how to decide edge weights') 11 | 12 | 13 | def main(): 14 | args = parser.parse_args() 15 | 16 | # dict that stores graph parameters 17 | graph_data={} 18 | graph_data['numv'] = args.numv 19 | graph_data['skip'] = graph_data['numv'] ## if 'skip' < 'numv' then there will be overlap between graphs of length numv-skip 20 | graph_data['time_edge'] = args.time_edge ## time support of the graph 21 | graph_data['cross_identity'] = args.cross_identity ## 'ciy' allows cross-identity edges, 'cin': No cross-idenity edges 22 | graph_data['edge_weight'] = args.edge_weight ## fsimn vs fsimy as above 23 | 24 | # target path for storing graphs 25 | tpath_key = os.path.join('graphs', '{}_{}_{}_{}_{}'.format(args.feature, graph_data['numv'], graph_data['time_edge'], graph_data['cross_identity'], graph_data['edge_weight'])) 26 | 27 | for mode in ['train', 'val']: 28 | # specifies location of the features within feature path 29 | dpath_mode = os.path.join('features', args.feature, '{}_forward'.format(mode), '*.csv') 30 | 31 | # specifies location of the graphs 32 | tpath_mode = os.path.join(tpath_key, mode) 33 | 34 | graph_gen(dpath_mode, tpath_mode, graph_data, mode) 35 | 36 | 37 | # function that takes input of feature path and target path for storing graphs and creates graphs using the dataloader AVADataset 38 | def graph_gen(dpath, tpath, graph_data, mode, cont=0): 39 | os.makedirs(tpath, 
exist_ok=True) 40 | Fdataset = AVADataset(dpath, graph_data, cont, tpath, mode) 41 | 42 | 43 | if __name__ == '__main__': 44 | main() 45 | -------------------------------------------------------------------------------- /models_gnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import BatchNorm, SAGEConv, EdgeConv 5 | from torch_geometric.utils.dropout import dropout_adj 6 | 7 | 8 | class SPELL(torch.nn.Module): 9 | def __init__(self, channels, feature_dim=1024, dropout=0, dropout_a=0, da_true=False, proj_dim=64): 10 | self.channels = channels 11 | self.feature_dim = feature_dim 12 | self.dropout = dropout 13 | self.dropout_a = dropout_a 14 | self.da_true = da_true 15 | super(SPELL, self).__init__() 16 | 17 | self.layerspf = nn.Linear(4, proj_dim) # projection layer for spatial features (4 -> 64) 18 | self.layer011 = nn.Linear(self.feature_dim//2+proj_dim, self.channels[0]) 19 | self.layer012 = nn.Linear(self.feature_dim//2, self.channels[0]) 20 | 21 | self.batch01 = BatchNorm(self.channels[0]) 22 | 23 | self.layer11 = EdgeConv(nn.Sequential(nn.Linear(2*self.channels[0], self.channels[0]), nn.ReLU(), nn.Linear(self.channels[0], self.channels[0]))) 24 | self.batch11 = BatchNorm(self.channels[0]) 25 | self.layer12 = EdgeConv(nn.Sequential(nn.Linear(2*self.channels[0], self.channels[0]), nn.ReLU(), nn.Linear(self.channels[0], self.channels[0]))) 26 | self.batch12 = BatchNorm(self.channels[0]) 27 | self.layer13 = EdgeConv(nn.Sequential(nn.Linear(2*self.channels[0], self.channels[0]), nn.ReLU(), nn.Linear(self.channels[0], self.channels[0]))) 28 | self.batch13 = BatchNorm(self.channels[0]) 29 | 30 | self.layer21 = SAGEConv(self.channels[0], self.channels[1]) 31 | self.batch21 = BatchNorm(self.channels[1]) 32 | 33 | self.layer31 = SAGEConv(self.channels[1], 1) 34 | self.layer32 = SAGEConv(self.channels[1], 1) 35 | self.layer33 = SAGEConv(self.channels[1], 1) 36 | 37 | def forward(self, data): 38 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 39 | 40 | spf = x[:, self.feature_dim:self.feature_dim+4] # coordinates for the spatial features (dim: 4) 41 | edge_index1 = edge_index[:, edge_attr>=0] 42 | edge_index2 = edge_index[:, edge_attr<=0] 43 | 44 | x_visual = self.layer011(torch.cat((x[:,self.feature_dim//2:self.feature_dim], self.layerspf(spf)), dim=1)) 45 | x_audio = self.layer012(x[:,:self.feature_dim//2]) 46 | x = x_audio + x_visual 47 | 48 | x = self.batch01(x) 49 | x = F.relu(x) 50 | 51 | edge_index1m, _ = dropout_adj(edge_index=edge_index1, p=self.dropout_a, training=self.training if not self.da_true else True) 52 | x1 = self.layer11(x, edge_index1m) 53 | x1 = self.batch11(x1) 54 | x1 = F.relu(x1) 55 | x1 = F.dropout(x1, p=self.dropout, training=self.training) 56 | x1 = self.layer21(x1, edge_index1) 57 | x1 = self.batch21(x1) 58 | x1 = F.relu(x1) 59 | x1 = F.dropout(x1, p=self.dropout, training=self.training) 60 | 61 | edge_index2m, _ = dropout_adj(edge_index=edge_index2, p=self.dropout_a, training=self.training if not self.da_true else True) 62 | x2 = self.layer12(x, edge_index2m) 63 | x2 = self.batch12(x2) 64 | x2 = F.relu(x2) 65 | x2 = F.dropout(x2, p=self.dropout, training=self.training) 66 | x2 = self.layer21(x2, edge_index2) 67 | x2 = self.batch21(x2) 68 | x2 = F.relu(x2) 69 | x2 = F.dropout(x2, p=self.dropout, training=self.training) 70 | 71 | # Undirected graph 72 | edge_index3m, _ = dropout_adj(edge_index=edge_index, 
p=self.dropout_a, training=self.training if not self.da_true else True) 73 | x3 = self.layer13(x, edge_index3m) 74 | x3 = self.batch13(x3) 75 | x3 = F.relu(x3) 76 | x3 = F.dropout(x3, p=self.dropout, training=self.training) 77 | x3 = self.layer21(x3, edge_index) 78 | x3 = self.batch21(x3) 79 | x3 = F.relu(x3) 80 | x3 = F.dropout(x3, p=self.dropout, training=self.training) 81 | 82 | x1 = self.layer31(x1, edge_index1) 83 | x2 = self.layer32(x2, edge_index2) 84 | x3 = self.layer33(x3, edge_index) 85 | 86 | x = x1 + x2 + x3 87 | x = torch.sigmoid(x) 88 | 89 | return x 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | :exclamation:**Please consider using the most recent version of our graph learning framework: [GraVi-T](https://github.com/IntelLabs/GraVi-T)** 2 | 3 | # SPELL 4 | Learning Long-Term Spatial-Temporal Graphs for Active Speaker Detection (ECCV 2022)\ 5 | [**paper**](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136950367.pdf) | [**poster**](https://drive.google.com/file/d/1q4ds3p1X7mfdpvROMYrBChrt2Zr55sfx/view?usp=sharing) | [**presentation**](https://youtu.be/wqb3crJ47KM) 6 | 7 | ## Overview 8 | SPELL is a novel spatial-temporal graph learning framework for active speaker detection (ASD). It can model a minute-long temporal context without relying on computationally expensive networks. Through extensive experiments on the AVA-ActiveSpeaker dataset, we demonstrate that learning graph-based representations significantly improves the detection performance thanks to its explicit spatial and temporal structure. Specifically, SPELL outperforms all previous state-of-the-art approaches while requiring significantly lower memory and computation resources. 9 | 10 | ## Ego4D Challenges 11 | SPELL and its improved version ([STHG](https://arxiv.org/abs/2306.10608)) achieved 1st place in the Ego4D Challenges [@ECCV22](https://ego4d-data.org/workshops/eccv22/) and [@CVPR23](https://ego4d-data.org/workshops/cvpr23/), respectively. We summarize ASD performance comparisons on the validation set of the Ego4D dataset: 12 | | ASD Model | ASD mAP(%)↑ | ASD mAP@0.5(%)↑ | 13 | |:------------|:-------------------:|:-----------------------:| 14 | | RegionCls | - | 24.6 | 15 | | TalkNet | - | 50.6 | 16 | | SPELL | 71.3 | 60.7 | 17 | | [STHG](https://arxiv.org/abs/2306.10608) | **75.7** | **63.7** | 18 | 19 | :bulb:In this table, we report two metrics to evaluate ASD performance: mAP quantifies the ASD results by assuming that the face bounding-box detections are the ground truth (i.e., assuming a perfect face detector), whereas mAP@0.5 quantifies the ASD results on the detected face bounding boxes (i.e., a face detection is considered positive only if the IoU between a detected face bounding box and the ground truth exceeds 0.5). For more information, please refer to our technical reports for the challenge. 20 | 21 | :bulb:We computed mAP@0.5 by using [Ego4D's official evaluation tool](https://github.com/EGO4D/audio-visual/tree/main/active-speaker-detection/active_speaker/active_speaker_evaluation). 22 | 23 | ## ActivityNet 2022 24 | SPELL achieved 2nd place in the [AVA-ActiveSpeaker Challenge](https://research.google.com/ava/challenge.html) at ActivityNet 2022. For the challenge, we used a visual input spanning a longer period of time (23 consecutive face-crops instead of 11). 
We also found that using a larger `channel1` can further boost the performance.\ 25 | [**tech report**](https://static.googleusercontent.com/media/research.google.com/en//ava/2022/S2_SPELL_ActivityNet_Challenge_2022.pdf) | [**presentation**](https://youtu.be/WCOOxsY0z34) 26 | 27 | ## Dependency 28 | We used python=3.6, pytorch=1.9.1, and torch-geometric=2.0.3 in our experiments. 29 | 30 | ## Code Usage 31 | 1) Download the audio-visual features and the annotation csv files from [Google Drive](https://drive.google.com/drive/folders/1_vr3Wxf6yZRA3IjWgelnf0TQqzKzDNeu?usp=sharing). The directories should look as follows: 32 | ``` 33 | |-- features 34 | |-- resnet18-tsm-aug 35 | |-- train_forward 36 | |-- val_forward 37 | |-- resnet50-tsm-aug 38 | |-- train_forward 39 | |-- val_forward 40 | |-- csv_files 41 | |-- ava_activespeaker_train.csv 42 | |-- ava_activespeaker_val.csv 43 | ``` 44 | 45 | 2) Run `generate_graph.py` to create the spatial-temporal graphs from the features: 46 | ``` 47 | python generate_graph.py --feature resnet18-tsm-aug 48 | ``` 49 | Although this script takes some time to finish in its current form, it can be modified to run in parallel and create the graphs for multiple videos at once. For example, you can change the `files` variable in line 81 of `data_loader.py`. One possible way to parallelize is sketched below the Note section. 50 | 51 | 3) Use `train_val.py` to train and evaluate the model: 52 | ``` 53 | python train_val.py --feature resnet18-tsm-aug 54 | ``` 55 | You can change the `--feature` argument to `resnet50-tsm-aug` for SPELL with ResNet-50-TSM. 56 | 57 | ## Note 58 | - We used the official code of [Active Speakers in Context (ASC)](https://github.com/fuankarion/active-speakers-context) to extract the audio-visual features (Stage-1). Specifically, we used `STE_train.py` and `STE_forward.py` of the ASC repository to train our two-stream ResNet-TSM encoders and extract the audio-visual features. We did not use any other components such as the postprocessing module or the context refinement modules. Please refer to `models_stage1_tsm.py` and the checkpoints from this [link](https://drive.google.com/drive/folders/1-EiPau0uzRA7pesuD5D-f6LZD6mxmYhz?usp=sharing) to see how we integrated the TSM into the two-stream ResNets. 
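As mentioned in step 2 of Code Usage, graph generation can be parallelized across videos. The snippet below is a minimal, untested sketch of one way to do it (the file name `parallel_graphs.py`, the pool size, and the hard-coded `graph_data` values are illustrative assumptions, not part of this repository). It reuses `graph_gen()` from `generate_graph.py` with a one-file glob per worker, so each process builds the graphs of a single video:
```
# parallel_graphs.py -- hypothetical helper, illustration only
import os
import glob
from multiprocessing import Pool
from generate_graph import graph_gen

FEATURE, MODE = 'resnet18-tsm-aug', 'train'
GRAPH_DATA = {'numv': 2000, 'skip': 2000, 'time_edge': 0.9,
              'cross_identity': 'cin', 'edge_weight': 'fsimy'}
TPATH = os.path.join('graphs', '{}_{}_{}_{}_{}'.format(
    FEATURE, GRAPH_DATA['numv'], GRAPH_DATA['time_edge'],
    GRAPH_DATA['cross_identity'], GRAPH_DATA['edge_weight']), MODE)

def build_one(csv_path):
    # each worker processes the feature csv of a single video
    graph_gen(csv_path, TPATH, GRAPH_DATA, MODE)

if __name__ == '__main__':
    pattern = os.path.join('features', FEATURE, '{}_forward'.format(MODE), '*.csv')
    with Pool(processes=8) as pool:  # pool size is an arbitrary choice
        pool.map(build_one, sorted(glob.glob(pattern)))
```
Every worker writes into the same target directory that the serial script uses, so the resulting graphs land where `train_val.py` expects them.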
59 | 60 | ## Citation 61 | ECCV 2022 paper: 62 | ```bibtex 63 | @inproceedings{min2022learning, 64 | title={Learning Long-Term Spatial-Temporal Graphs for Active Speaker Detection}, 65 | author={Min, Kyle and Roy, Sourya and Tripathi, Subarna and Guha, Tanaya and Majumdar, Somdeb}, 66 | booktitle={European Conference on Computer Vision}, 67 | pages={371--387}, 68 | year={2022}, 69 | organization={Springer} 70 | } 71 | ``` 72 | 73 | Technical report for AVA-ActiveSpeaker challenge 2022: 74 | ```bibtex 75 | @article{minintel, 76 | title={Intel Labs at ActivityNet Challenge 2022: SPELL for Long-Term Active Speaker Detection}, 77 | author={Min, Kyle and Roy, Sourya and Tripathi, Subarna and Guha, Tanaya and Majumdar, Somdeb}, 78 | journal={The ActivityNet Large-Scale Activity Recognition Challenge}, 79 | year={2022}, 80 | note={\url{https://research.google.com/ava/2022/S2_SPELL_ActivityNet_Challenge_2022.pdf}} 81 | } 82 | ``` 83 | -------------------------------------------------------------------------------- /train_val.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import random 5 | import argparse 6 | from models_gnn import SPELL 7 | from data_loader import AVADataset 8 | from torch_geometric.loader import DataLoader 9 | from sklearn.metrics import average_precision_score 10 | 11 | 12 | parser = argparse.ArgumentParser(description='SPELL') 13 | parser.add_argument('--gpu_id', type=int, default=0, help='which gpu to run the train_val') 14 | parser.add_argument('--feature', type=str, default='resnet18-tsm-aug', help='name of the features') 15 | parser.add_argument('--numv', type=int, default=2000, help='number of nodes (n in our paper)') 16 | parser.add_argument('--time_edge', type=float, default=0.9, help='time threshold (tau in our paper)') 17 | parser.add_argument('--cross_identity', type=str, default='cin', help='whether to allow cross-identity edges') 18 | parser.add_argument('--edge_weight', type=str, default='fsimy', help='how to decide edge weights') 19 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') # 5e-4 or 1e-3 works well 20 | parser.add_argument('--sch_param', type=int, default=100, help='parameter for lr scheduler') # 10 or 100 21 | parser.add_argument('--channel1', type=int, default=64, help='filter dimension of GCN layers (layer1-2)') 22 | parser.add_argument('--channel2', type=int, default=16, help='filter dimension of GCN layers (layer2-3)') 23 | parser.add_argument('--proj_dim', type=int, default=64, help='projection of 4->proj_dim for spatial feature') 24 | parser.add_argument('--batch_size', type=int, default=16, help='batch size') 25 | parser.add_argument('--dropout', type=float, default=0.2, help='dropout for SAGEConv') # 0.2 ~ 0.4 26 | parser.add_argument('--dropout_a', type=float, default=0, help='dropout value for dropout_adj') 27 | parser.add_argument('--da_true', action='store_true', help='always apply dropout_adj for both the training and testing') 28 | parser.add_argument('--seed', type=int, default=0, help='random seed for reproducibility') 29 | parser.add_argument('--num_epoch', type=int, default=70, help='total number of epochs') 30 | parser.add_argument('--eval_freq', type=int, default=1, help='how frequently run the evaluation') 31 | 32 | 33 | def main(): 34 | args = parser.parse_args() 35 | 36 | np.random.seed(args.seed) 37 | random.seed(args.seed) 38 | torch.manual_seed(args.seed) 39 | torch.backends.cudnn.deterministic = True 40 | 
41 | graph_data = {} 42 | graph_data['numv'] = args.numv 43 | graph_data['skip'] = graph_data['numv'] 44 | graph_data['time_edge'] = args.time_edge 45 | graph_data['cross_identity'] = args.cross_identity 46 | graph_data['edge_weight'] = args.edge_weight 47 | 48 | # path of the audio-visual features 49 | dpath_root = os.path.join('features', '{}_features'.format(args.feature)) 50 | 51 | # path of the generated graphs 52 | exp_key = '{}_{}_{}_{}_{}'.format(args.feature, graph_data['numv'], graph_data['time_edge'], graph_data['cross_identity'], graph_data['edge_weight']) 53 | tpath_root = os.path.join('graphs', exp_key) 54 | 55 | # path for the results and model checkpoints 56 | exp_name = '{}_lr{}-{}_c{}-{}_d{}-{}_s{}'.format(exp_key, args.lr, args.sch_param, args.channel1, args.channel2, args.dropout, args.dropout_a, args.seed) 57 | 58 | print (exp_name) 59 | 60 | result_path = os.path.join('results', exp_name) 61 | os.makedirs(result_path, exist_ok=True) 62 | 63 | dpath_train = os.path.join(dpath_root, 'train_forward', '*.csv') 64 | tpath_train = os.path.join(tpath_root, 'train') 65 | dpath_val = os.path.join(dpath_root, 'val_forward', '*.csv') 66 | tpath_val = os.path.join(tpath_root, 'val') 67 | 68 | cont = 1 69 | Fdataset_train = AVADataset(dpath_train, graph_data, cont, tpath_train, mode='train') 70 | Fdataset_val = AVADataset(dpath_val, graph_data, cont, tpath_val, mode='val') 71 | 72 | train_loader = DataLoader(Fdataset_train, batch_size=args.batch_size, shuffle=True, num_workers=4) 73 | val_loader = DataLoader(Fdataset_val, batch_size=1, shuffle=False, num_workers=4) 74 | 75 | # gpu and learning parameter settings 76 | feature_dim = 1024 77 | if 'resnet50' in args.feature: 78 | feature_dim = 4096 79 | 80 | device = ('cuda:{}'.format(args.gpu_id) if torch.cuda.is_available() else 'cpu') 81 | model = SPELL([args.channel1, args.channel2], feature_dim, args.dropout, args.dropout_a, args.da_true, proj_dim=args.proj_dim) 82 | model.to(device) 83 | 84 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 85 | criterion = torch.nn.BCELoss() 86 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.sch_param) 87 | 88 | flog = open(os.path.join(result_path, 'log.txt'), mode = 'w') 89 | max_mAP = 0 90 | for epoch in range(1, args.num_epoch+1): 91 | loss = train(model, train_loader, device, optimizer, criterion, scheduler) 92 | str_print = '[{:3d}|{:3d}]: Training loss: {:.4f}'.format(epoch, args.num_epoch, loss) 93 | 94 | if epoch % args.eval_freq == 0: 95 | mAP = evaluation(model, val_loader, device, feature_dim) 96 | if mAP > max_mAP: 97 | max_mAP = mAP 98 | epoch_max = epoch 99 | torch.save(model.state_dict(), os.path.join(result_path, 'chckpoint_{:03d}.pt'.format(epoch))) 100 | 101 | str_print += ', mAP: {:.4f} (max_mAP: {:.4f} at epoch: {})'.format(mAP, max_mAP, epoch_max) 102 | 103 | print (str_print) 104 | flog.write(str_print+'\n') 105 | flog.flush() 106 | 107 | flog.close() 108 | 109 | 110 | def train(model, train_loader, device, optimizer, criterion, scheduler): 111 | model.train() 112 | loss_sum = 0. 
113 | 114 | for data in train_loader: 115 | data = data.to(device) 116 | optimizer.zero_grad() 117 | 118 | output = model(data) 119 | 120 | loss = criterion(output, data.y) 121 | loss.backward() 122 | loss_sum += loss.item() 123 | optimizer.step() 124 | 125 | scheduler.step() 126 | 127 | return loss_sum/len(train_loader) 128 | 129 | 130 | def evaluation(model, val_loader, device, feature_dim): 131 | model.eval() 132 | target_total = [] 133 | soft_total = [] 134 | #stamp_total = [] 135 | 136 | with torch.no_grad(): 137 | for data in val_loader: 138 | data = data.to(device) 139 | x = data.x 140 | y = data.y 141 | 142 | scores = model(data) 143 | scores = scores[:, 0].tolist() 144 | preds = [1.0 if i >= 0.5 else 0.0 for i in scores] 145 | 146 | soft_total.extend(scores) 147 | target_total.extend(y[:, 0].tolist()) 148 | #stamp_total.extend(x[:, feature_dim+4:].tolist()) # you can use the stamps to make the results in the official ActivityNet format 149 | 150 | # this does not produce the official mAP score (but the difference is negligible) 151 | # we report the scores computed by the official ActivityNet evaluation script in our paper 152 | mAP = average_precision_score(target_total, soft_total) 153 | 154 | return mAP 155 | 156 | 157 | if __name__ == '__main__': 158 | main() 159 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import csv 4 | import glob 5 | import numpy as np 6 | from tqdm import tqdm 7 | import torch 8 | from torch_geometric.data import Data 9 | from torch_geometric.data import Dataset 10 | from torch_geometric.data.makedirs import makedirs 11 | 12 | def files_exist(files) -> bool: 13 | # NOTE: We return `False` in case `files` is empty, leading to a 14 | # re-processing of files on every instantiation. 15 | return len(files) != 0 and all([os.path.exists(f) for f in files]) 16 | 17 | ## this class is used both as a dataloader for training the GNN and for constructing the graph data 18 | ## if parameter cont==1, it assumes the dataset already exists and samples from the dataset path during training 19 | ## during the graph-generation phase, cont is set to any other value except 1 (e.g. 
0) 20 | class AVADataset(Dataset): 21 | def __init__(self, dpath, graph_data, cont, root, mode = 'train'): 22 | # parsing graph parameters-------------------------- 23 | self.dpath = dpath 24 | self.numv = graph_data['numv'] 25 | self.skip = graph_data['skip'] 26 | self.cont = cont 27 | self.time_edge = graph_data['time_edge'] 28 | self.cross_identity = graph_data['cross_identity'] 29 | self.edge_weight = graph_data['edge_weight'] 30 | self.mode = mode 31 | #--------------------------------------------------- 32 | 33 | super(AVADataset, self).__init__(root) 34 | self.all_files = self.processed_file_names 35 | 36 | @property 37 | def raw_file_names(self): 38 | return [] 39 | 40 | @property 41 | ### this function is used to name the graphs when cont!=1; 42 | ### when cont==1 this function simply reads the names of processed graphs from 'self.processed_dir' 43 | def processed_file_names(self): 44 | files = glob.glob(self.dpath) 45 | files = sorted(files) 46 | 47 | files = [os.path.splitext(os.path.basename(f))[0] for f in files] 48 | if self.cont == 1: 49 | files = sorted(os.listdir(self.processed_dir)) 50 | 51 | return files 52 | 53 | def _download(self): 54 | return 55 | 56 | def _process(self): 57 | if files_exist(self.processed_paths) or files_exist([d+'_001.pt' for d in self.processed_paths]): # pragma: no cover 58 | return 59 | 60 | print('Processing...', file=sys.stderr) 61 | 62 | makedirs(self.processed_dir) 63 | self.process() 64 | 65 | print('Done!', file=sys.stderr) 66 | 67 | def process(self): 68 | files = glob.glob(self.dpath) 69 | files = sorted(files) 70 | 71 | id_dict = {} 72 | vstamp_dict = {} 73 | id_ct = 0 74 | ustamp = 0 75 | 76 | dict_vte_spe = {} 77 | with open('csv_files/ava_activespeaker_{}.csv'.format(self.mode)) as f: 78 | reader = csv.reader(f) 79 | data_gt = list(reader) 80 | 81 | for video_id, frame_timestamp, x1, y1, x2, y2, label, entity_id in data_gt: 82 | if video_id == 'video_id': 83 | continue 84 | vte = (video_id, float(frame_timestamp), entity_id) 85 | x1, y1, x2, y2 = float(x1), float(y1), float(x2), float(y2) 86 | if vte not in dict_vte_spe: 87 | dict_vte_spe[vte] = [(x1+x2)/2, (y1+y2)/2, x2-x1, y2-y1] 88 | 89 | ## iterating over videos (features) in the training/validation set 90 | for ct, fl in enumerate(files): 91 | if self.cont == 1: 92 | continue 93 | 94 | ## load the current feature csv file 95 | with open(fl, newline='') as f: 96 | reader = csv.reader(f) 97 | data_f = list(reader) 98 | 99 | #------Note-------------------- 100 | ## data_f contains the feature data of the current video 101 | ## the format is the following: each row of data_f is itself a list and corresponds to a face-box 102 | ## format of data_f: For any row=i, data_f[i][0]=video_id, data_f[i][1]=time_stamp, data_f[i][2]=entity_id, data_f[i][3]= facebox's label, data_f[i][-1]=facebox feature 103 | #------------ 104 | 105 | # we sort the rows by their time-stamps 106 | data_f.sort(key = lambda x: float(x[1])) 107 | 108 | num_v = self.numv 109 | count_gp = 1 110 | len_data = len(data_f) 111 | 112 | # iterating over blocks of face-boxes (i.e., rows) of the current feature file 113 | for i in tqdm(range(0, len_data, self.skip)): 114 | if os.path.isfile(self.processed_paths[ct]+ '_{}.pt'.format(count_gp)): 115 | print('skipped') 116 | continue 117 | 118 | ## in PyTorch Geometric, edges are stored in a directed source-target format, i.e., for us (source_vertices[i], target_vertices[i]) is an edge for all i 119 | source_vertices = [] 120 | target_vertices = [] 121 | 122 | # x is the list to store the 
vertex features; x[i,:] is the feature of the i-th vertex 123 | x = [] 124 | # y is the list to store the vertex labels; y[i] is the label of the i-th vertex 125 | y = [] 126 | # identity and times are two lists that keep track of the identity and time stamp of each vertex 127 | identity = [] 128 | times = [] 129 | 130 | unique_id = [] 131 | 132 | ##------------------------------ 133 | ## this block computes the indices of the first and last facebox of the current partition 134 | if i+num_v <= len_data: 135 | start_g = i 136 | end_g = i+num_v 137 | else: 138 | print("i is", i) 139 | start_g = i 140 | end_g = len_data 141 | ##-------------------------------------- 142 | 143 | ### we go over the face-boxes of the current partition and construct their edges, collect their features within this for loop 144 | for j in range(start_g, end_g): 145 | #----------------------------------------------- 146 | # optional 147 | # note: often we might want a global time-stamp and identity index; the markers below build those dictionaries 148 | stamp_marker = data_f[j][1] + data_f[j][0] 149 | id_marker = data_f[j][2] + str(ct) 150 | 151 | if stamp_marker not in vstamp_dict: 152 | vstamp_dict[stamp_marker] = ustamp 153 | ustamp = ustamp + 1 154 | 155 | if id_marker not in id_dict: 156 | id_dict[id_marker] = id_ct 157 | id_ct = id_ct + 1 158 | #--------------------------------------------- 159 | 160 | vte = (data_f[j][0], float(data_f[j][1]), data_f[j][2]) 161 | 162 | ## parse the current facebox's feature from data_f 163 | feat = self.decode_feature(data_f[j][-1]) 164 | 165 | # append the feature vector to the list of facebox (or vertex) features 166 | ## in addition to the A-V feature, we can append additional information (like the time stamp) to the feature vector for later usage 167 | tail = [] 168 | tail.extend(dict_vte_spe[vte]) 169 | tail.extend([id_dict[data_f[j][2]+str(ct)], vstamp_dict[stamp_marker]]) 170 | feat = np.append(feat, tail) 171 | feat = np.expand_dims(feat, axis=0) 172 | 173 | x.append(feat) 174 | 175 | # append the current vertex's label 176 | y.append(float(data_f[j][3])) 177 | 178 | ## append the time and identity of the current vertex to the lists of time stamps and identities 179 | times.append(float(data_f[j][1])) 180 | identity.append(data_f[j][2]) 181 | 182 | edge_attr = [] 183 | num_edge = 0 184 | 185 | ## iterate over pairs of vertices of the current partition and assign edges according to the time/identity criterion 186 | for j in range(0, end_g - start_g): 187 | for k in range(0, end_g - start_g): 188 | 189 | if self.cross_identity == 'cin': 190 | id_cond = identity[j]==identity[k] 191 | else: 192 | id_cond = True 193 | 194 | # time difference between j-th and k-th vertex 195 | time_gap = times[j]-times[k] 196 | 197 | if 0
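# --------------------------------------------------------------------------------
# Illustrative sketch (not the repository's code): `add_edges_sketch` below is a
# hypothetical, minimal example of the kind of edge criterion the comments above
# describe -- connect a pair of face-box vertices when their time gap is within
# `time_edge` and, for the 'cin' setting, they share the same entity id. The
# signed time gap is kept as the edge attribute, which is consistent with how
# models_gnn.py splits the edges into forward (edge_attr>=0), backward
# (edge_attr<=0), and undirected sets.
def add_edges_sketch(times, identity, time_edge, cross_identity='cin'):
    source_vertices, target_vertices, edge_attr = [], [], []
    for j in range(len(times)):
        for k in range(len(times)):
            # optionally forbid cross-identity edges
            if cross_identity == 'cin' and identity[j] != identity[k]:
                continue
            time_gap = times[j] - times[k]
            if abs(time_gap) <= time_edge:
                source_vertices.append(j)
                target_vertices.append(k)
                edge_attr.append(time_gap)  # sign encodes temporal direction
    return source_vertices, target_vertices, edge_attr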