├── README.md
├── data
│   ├── episode_summary.csv
│   ├── scene_summary.csv
│   └── scenes_descriptions.csv
├── environment.yml
├── fusion_MW.py
├── fusion_data_sample.py
├── fusion_main_train.py
├── images
│   └── model.png
├── license.txt
├── multi_head_attention.py
├── stream_data_sample.py
├── stream_main_train.py
├── train_stream.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | ## On the hidden treasure of dialog in video question answering
2 |
3 | [Deniz Engin](https://engindeniz.github.io/), [François Schnitzler](https://sites.google.com/site/francoisschnitzler/), [Ngoc Q. K. Duong](https://www.interdigital.com/talent/?id=88) and [Yannis Avrithis](https://avrithis.net/), On the hidden treasure of dialog in video question answering, ICCV 2021.
4 |
5 | [Project page](https://engindeniz.github.io/dialogsummary-videoqa) | [arXiv](https://arxiv.org/abs/2103.14517)
6 |
7 | ---
8 | ### Model Overview
9 | 
10 | ![Model overview](images/model.png)
11 |
12 | Our VideoQA system converts dialog and video inputs to episode dialog summaries and video descriptions, respectively. Converted inputs and dialog are processed independently in streams, along with the question and each answer,
13 | producing a score per answer. Finally, stream embeddings are fused separately per answer and a prediction is made.
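
The per-answer fusion can be pictured with a minimal sketch (an illustration only: plain concatenation and a single shared scoring head stand in for the attention-based, per-answer fusion blocks implemented in `fusion_main_train.py`):

```
import torch
from torch import nn

# Illustration: each stream yields a 768-d embedding per (question, answer) pair;
# per answer, the stream embeddings are fused and scored, then the best answer is picked.
batch, n_streams, n_answers, dim = 2, 3, 4, 768
streams = [torch.randn(batch, n_answers, dim) for _ in range(n_streams)]  # one tensor per stream

score_head = nn.Linear(n_streams * dim, 1)  # simplified stand-in for a fusion block
scores = []
for a in range(n_answers):
    fused = torch.cat([s[:, a, :] for s in streams], dim=1)  # fuse streams for answer a
    scores.append(score_head(fused))
logits = torch.cat(scores, dim=1)   # (batch, n_answers)
prediction = logits.argmax(dim=1)   # predicted answer index per sample
```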
14 |
15 |
16 | ### Environment Setup
17 |
18 | To create a conda environment:
19 | ````
20 | conda env create -f environment.yml
21 | ````
22 |
23 | or
24 |
25 | ````
26 | conda create --name dialog-videoqa python=3.6
27 | conda activate dialog-videoqa
28 | conda install -c anaconda numpy pandas scikit-learn
29 | conda install -c conda-forge tqdm
30 | conda install pytorch==1.0.1 torchvision==0.2.2 cudatoolkit=10.0 -c pytorch
31 | pip install pytorch-transformers
32 | ````
33 |
34 | ### Data preparation
35 | * Download the [KnowIT VQA](https://knowit-vqa.github.io/) dataset and extract it into the [data folder](data).
36 | * Extracted scene and episode dialog summaries are provided as separate files in the [data folder](data).
37 | * Plot summaries are those used in [ROLL-VideoQA](https://arxiv.org/pdf/2007.08751.pdf); they can be downloaded from [here](https://github.com/noagarcia/ROLL-VideoQA/blob/master/Data/knowledge_base/tbbt_summaries.csv).
38 | * Scene descriptions are obtained by following [ROLL-VideoQA](https://github.com/noagarcia/ROLL-VideoQA); the generated descriptions are provided in the [data folder](data). The expected layout of the data folder is sketched below.
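
After these steps, the data folder is expected to look roughly as follows (a sketch only; the `knowit_data_*.csv` names are assumed from the split files referenced by the training scripts, e.g. `knowit_data_test.csv`):

```
data/
├── knowit_data_train.csv      # KnowIT VQA splits (assumed file names)
├── knowit_data_val.csv
├── knowit_data_test.csv
├── scene_summary.csv          # provided in this repo
├── episode_summary.csv        # provided in this repo
├── scenes_descriptions.csv    # provided in this repo
└── tbbt_summaries.csv         # ROLL-VideoQA plot summaries (only needed for the plot stream)
```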
39 |
40 | ### Training Models
41 |
42 | This section explains how to train the single-stream and multi-stream QA models.
43 |
44 | #### Single-Stream QA
45 |
46 | Our main streams are video, scene dialog summary, and episode dialog summary. The dialog and plot streams are used for comparison. The streams are trained as follows:
47 |
48 | Training video stream:
49 | ```
50 | python stream_main_train.py --train_name video --max_seq_length 512
51 | ```
52 | Training scene dialog summary stream:
53 | ```
54 | python stream_main_train.py --train_name scene_dialog_summary --max_seq_length 512
55 | ```
56 | Training episode dialog summary stream:
57 | ```
58 | python stream_main_train.py --train_name episode_dialog_summary --max_seq_length 300 --seq_stride 200 --mini_batch_size 2 --eval_batch_size 16
59 | ```
60 | Training dialog stream:
61 | ```
62 | python stream_main_train.py --train_name dialog --max_seq_length 512
63 |
64 | ```
65 | Training plot stream:
66 | ```
67 | python stream_main_train.py --train_name plot --max_seq_length 200 --seq_stride 100 --mini_batch_size 2
68 | ```
69 |
70 | All single-stream models are trained on 2 Tesla V100 GPUs (32 GB), except the plot stream, which is trained on 1 Tesla V100.
71 | Gradient accumulation is used to fit into GPU memory whenever a `--mini_batch_size` is given; the idea is sketched below.
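
For reference, a minimal, self-contained sketch of gradient accumulation with dummy data (an illustration of the idea, not the exact loop used in `stream_main_train.py`):

```
import torch
from torch import nn

# An effective batch of 8 is split into mini-batches of 2, i.e. 4 accumulation steps.
model = nn.Linear(768, 4)                    # stand-in for a stream model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
mini_batch_size, accumulation_steps = 2, 4

optimizer.zero_grad()
for step in range(accumulation_steps):
    x = torch.randn(mini_batch_size, 768)    # dummy mini-batch
    y = torch.randint(0, 4, (mini_batch_size,))
    loss = criterion(model(x), y) / accumulation_steps  # scale so gradients average over the effective batch
    loss.backward()                                     # gradients accumulate across mini-batches
optimizer.step()                                        # single update for the effective batch
optimizer.zero_grad()
```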
72 |
73 | #### Multi-Stream QA
74 |
75 | Our main proposed model fuses the video, scene dialog summary, and episode dialog summary streams with the multi-stream attention method:
76 | ```
77 | python fusion_main_train.py --fuse_stream_list video scene_dialog_summary episode_dialog_summary --fusion_method multi-stream-attention
78 | ```
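
Other fusion variants are available through the same script's `--fusion_method` option (`self-attention`, `multi-stream-self-attention`, `product`), and `fusion_MW.py` implements a loss-weighted fusion of the stream scores. Two example runs, assuming the default loss weights and training folder paths:

```
python fusion_main_train.py --fuse_stream_list video scene_dialog_summary episode_dialog_summary --fusion_method product
python fusion_MW.py --fuse_stream_list video scene_dialog_summary episode_dialog_summary
```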
79 |
80 | ### License
81 |
82 | Please check the [license file](license.txt) for more information.
83 |
84 | ### Acknowledgments
85 |
86 | This code is built on top of [ROLL-VideoQA](https://github.com/noagarcia/ROLL-VideoQA).
87 |
88 | ### Citation
89 |
90 | If this code is helpful for you, please cite the following:
91 |
92 | ````
93 | @inproceedings{engin2021hidden,
94 | title={On the hidden treasure of dialog in video question answering},
95 | author={Engin, Deniz and Schnitzler, Fran{\c{c}}ois and Duong, Ngoc QK and Avrithis, Yannis},
96 | booktitle={ICCV},
97 | year={2021}
98 | }
99 |
100 | ````
101 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: dialog-videoqa
2 | channels:
3 | - pytorch
4 | - anaconda
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=conda_forge
9 | - _openmp_mutex=4.5=1_gnu
10 | - blas=1.0=mkl
11 | - ca-certificates=2021.5.30=ha878542_0
12 | - certifi=2021.5.30=py36h5fab9bb_0
13 | - cffi=1.14.5=py36hc120d54_0
14 | - cudatoolkit=10.0.130=hf841e97_8
15 | - freetype=2.10.4=h0708190_1
16 | - intel-openmp=2020.2=254
17 | - jbig=2.1=h7f98852_2003
18 | - joblib=0.17.0=py_0
19 | - jpeg=9d=h36c2ea0_0
20 | - lcms2=2.12=hddcbb42_0
21 | - ld_impl_linux-64=2.35.1=hea4e1c9_2
22 | - lerc=2.2.1=h9c3ff4c_0
23 | - libdeflate=1.7=h7f98852_5
24 | - libffi=3.3=h58526e2_2
25 | - libgcc-ng=9.3.0=h2828fa1_19
26 | - libgfortran-ng=7.3.0=hdf63c60_0
27 | - libgomp=9.3.0=h2828fa1_19
28 | - libpng=1.6.37=h21135ba_2
29 | - libstdcxx-ng=9.3.0=h6de172a_19
30 | - libtiff=4.3.0=hf544144_1
31 | - libwebp-base=1.2.0=h7f98852_2
32 | - lz4-c=1.9.3=h9c3ff4c_0
33 | - mkl=2019.4=243
34 | - mkl-service=2.3.0=py36he904b0f_0
35 | - mkl_fft=1.2.0=py36h23d657b_0
36 | - mkl_random=1.0.4=py36hd81dba3_0
37 | - ncurses=6.2=h58526e2_4
38 | - ninja=1.10.2=h4bd325d_0
39 | - numpy=1.19.1=py36hbc911f0_0
40 | - numpy-base=1.19.1=py36hfa32c7d_0
41 | - olefile=0.46=pyh9f0ad1d_1
42 | - openjpeg=2.4.0=hb52868f_1
43 | - openssl=1.1.1k=h7f98852_0
44 | - pandas=1.1.3=py36he6710b0_0
45 | - pillow=8.2.0=py36ha6010c0_1
46 | - pip=21.1.2=pyhd8ed1ab_0
47 | - pycparser=2.20=pyh9f0ad1d_2
48 | - python=3.6.13=hffdb5ce_0_cpython
49 | - python-dateutil=2.8.1=py_0
50 | - python_abi=3.6=1_cp36m
51 | - pytorch=1.0.1=py3.6_cuda10.0.130_cudnn7.4.2_2
52 | - pytz=2020.1=py_0
53 | - readline=8.1=h46c0cb4_0
54 | - scikit-learn=0.23.2=py36h0573a6f_0
55 | - scipy=1.5.2=py36h0b6359f_0
56 | - setuptools=49.6.0=py36h5fab9bb_3
57 | - six=1.15.0=py_0
58 | - sqlite=3.35.5=h74cdb3f_0
59 | - threadpoolctl=2.1.0=pyh5ca1d4c_0
60 | - tk=8.6.10=h21135ba_1
61 | - torchvision=0.2.2=py_3
62 | - tqdm=4.61.1=pyhd8ed1ab_0
63 | - wheel=0.36.2=pyhd3deb0d_0
64 | - xz=5.2.5=h516909a_1
65 | - zlib=1.2.11=h516909a_1010
66 | - zstd=1.5.0=ha95c52a_0
67 | - pip:
68 | - boto3==1.17.94
69 | - botocore==1.20.94
70 | - chardet==4.0.0
71 | - click==8.0.1
72 | - idna==2.10
73 | - importlib-metadata==4.5.0
74 | - jmespath==0.10.0
75 | - pytorch-transformers==1.2.0
76 | - regex==2021.4.4
77 | - requests==2.25.1
78 | - s3transfer==0.4.2
79 | - sacremoses==0.0.45
80 | - sentencepiece==0.1.95
81 | - typing-extensions==3.10.0.0
82 | - urllib3==1.26.5
83 | - zipp==3.4.1
--------------------------------------------------------------------------------
/fusion_MW.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """Code by Noa Garcia and Yuta Nakashima"""
3 | import argparse
4 | import json
5 | import logging
6 | import os
7 | import random
8 | import time
9 |
10 | import numpy as np
11 | import pandas as pd
12 | import torch
13 | from torch import nn
14 | from torch.optim.lr_scheduler import ReduceLROnPlateau
15 | from torch.utils.data import DataLoader
16 |
17 | import utils
18 | from fusion_data_sample import FusionDataSample
19 | from utils import create_folder_with_timestamp, str2bool
20 |
21 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
22 | datefmt='%m/%d/%Y %H:%M:%S',
23 | level=logging.INFO)
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | def get_params():
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument("--data_dir", default='data/', type=str)
30 | parser.add_argument("--bert_model", default='bert-base-uncased', type=str)
31 | parser.add_argument("--do_lower_case", default=True)
32 | parser.add_argument('--seed', type=int, default=181)
33 | parser.add_argument("--lr", default=0.0001, type=float)
34 | parser.add_argument("--workers", default=8)
35 | parser.add_argument("--device", default='cuda', type=str, help="cuda, cpu")
36 | parser.add_argument("--batch_size", default=32, type=int)
37 | parser.add_argument('--momentum', default=0.9)
38 | parser.add_argument('--nepochs', default=100, help='Number of epochs', type=int)
39 | parser.add_argument('--patience', default=15, type=int)
40 | parser.add_argument('--no_cuda', action='store_true')
41 | parser.add_argument("--num_max_slices_plot", default=10, type=int)
42 | parser.add_argument("--num_max_slices_episode_dialog_summary", default=10, type=int)
43 | parser.add_argument('--weight_loss_final', default=0.7, type=float)
44 |
45 | """Code by InterDigital"""
46 | parser.add_argument('--fuse_stream_list', nargs='+', required=True, type=str)
47 | parser.add_argument('--fuse_loss_weight_list', nargs='+', default=None, type=float)
48 | parser.add_argument("--load_pretrained_model_exists", default=False)
49 | parser.add_argument("--eval_split", default="test", type=str)
50 | parser.add_argument('--lr_patience', default=5, type=int)
51 | parser.add_argument("--stream_train_folder_path", default='Training/main_stream_trainings', type=str)
52 | parser.add_argument("--fusion_train_folder_path", default='Training/fusion', type=str)
53 | parser.add_argument("--part_selection_with_soft_temporal_attention", default=True, type=str2bool)
54 | parser.add_argument('--ss_max_temperature', default=2, type=int)
55 | args, unknown = parser.parse_known_args()
56 | return args
57 |
58 |
59 | class FusionMW(nn.Module):
60 | def __init__(self, args):
61 | self.args = args
62 | super(FusionMW, self).__init__()
63 |
64 | number_of_streams = len(args.fuse_loss_weight_list)
65 | self.module_list = nn.ModuleList([nn.Sequential(nn.Linear(768, 1)) for _ in range(number_of_streams)])
66 |
67 | self.dropout = nn.Dropout(0.5)
68 | self.classifier = nn.Sequential(nn.Linear(number_of_streams, 1))
69 |
70 | def forward(self, inputs):
71 | assert len(self.module_list) == len(inputs)
72 | num_choices = inputs[0].shape[1]
73 | reshaped_scores_list = []
74 | score_list = []
75 | for module_per_stream, input_per_stream in zip(self.module_list, inputs):
76 | flat_in = input_per_stream.view(-1, input_per_stream.size(-1))
77 | score = module_per_stream(self.dropout(flat_in))
78 | score_list.append(score)
79 | """Code by Noa Garcia and Yuta Nakashima"""
80 | reshaped_scores_list.append(score.view(-1, num_choices))
81 |
82 | # Final score
83 | all_feat = torch.squeeze(torch.cat(score_list, 1), 1)
84 | final_scores = self.classifier(all_feat)
85 | reshaped_final_scores = final_scores.view(-1, num_choices)
86 | return reshaped_scores_list, reshaped_final_scores
87 |
88 |
89 | def trainEpoch(args, train_loader, model, criterion, optimizer, epoch):
90 | losses = utils.AverageMeter()
91 | model.train()
92 | targets = []
93 | outs = []
94 | for batch_idx, (input, target) in enumerate(train_loader):
95 |
96 | # Inputs to Variable type
97 | input_var = list()
98 | for j in range(len(input)):
99 | input_var.append(torch.autograd.Variable(input[j]).cuda())
100 |
101 | # Targets to Variable type
102 | target_var = list()
103 | for j in range(len(target)):
104 | target[j] = target[j].cuda(non_blocking=True)
105 | target_var.append(torch.autograd.Variable(target[j]))
106 |
107 | # Output of the model
108 | output, final_scores = model(input_var)
109 |
110 | # Compute loss
111 | final_loss = criterion(final_scores, target_var[0])
112 | train_loss = 0
113 |
114 | """Code by InterDigital"""
115 | for idx in range(len(output)):
116 | stream_loss = criterion(output[idx], target_var[0])
117 | # Track loss
118 | train_loss += args.fuse_loss_weight_list[idx] * stream_loss
119 | train_loss += final_loss * args.weight_loss_final
120 |
121 | """Code by Noa Garcia and Yuta Nakashima"""
122 | losses.update(train_loss.data.cpu().numpy(), input[0].size(0))
123 |
124 | # for plot
125 | outs.append(torch.max(final_scores, 1)[1].data.cpu().numpy())
126 | targets.append(target[0].cpu().numpy())
127 |
128 | # Backpropagate loss and update weights
129 | optimizer.zero_grad()
130 | train_loss.backward()
131 | optimizer.step()
132 |
133 | # Print info
134 | logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\t'
135 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
136 | epoch, batch_idx, len(train_loader), 100. * batch_idx / len(train_loader), loss=losses))
137 |
138 | outs = np.concatenate(outs).flatten()
139 | targets = np.concatenate(targets).flatten()
140 |
141 | acc = np.sum(outs == targets) / len(outs)
142 |
143 | return epoch, losses.avg, acc, None
144 |
145 |
146 | def valEpoch(args, val_loader, model, criterion, epoch):
147 | losses = utils.AverageMeter()
148 | model.eval()
149 | for batch_idx, (input, target) in enumerate(val_loader):
150 |
151 | # Inputs to Variable type
152 | input_var = list()
153 | for j in range(len(input)):
154 | input_var.append(torch.autograd.Variable(input[j]).cuda())
155 |
156 | # Targets to Variable type
157 | target_var = list()
158 | for j in range(len(target)):
159 | target[j] = target[j].cuda(non_blocking=True)
160 | target_var.append(torch.autograd.Variable(target[j]))
161 |
162 | # Output of the model
163 | with torch.no_grad():
164 | output, final_scores = model(input_var)
165 |
166 | # Compute loss
167 | predicted = torch.max(final_scores, 1)[1]
168 |
169 | stream_predictions = [torch.max(p, 1)[1] for p in output]
170 |
171 | final_loss = criterion(final_scores, target_var[0]) * args.weight_loss_final
172 | train_loss = 0
173 |
174 | """Code by InterDigital"""
175 | for idx in range(len(output)):
176 | weighted_stream_loss = args.fuse_loss_weight_list[idx] * criterion(output[idx], target_var[0])
177 | train_loss += weighted_stream_loss
178 | train_loss += final_loss
179 |
180 | losses.update(train_loss.data.cpu().numpy(), input[0].size(0))
181 |
182 | # Save predictions to compute accuracy
183 | if batch_idx == 0:
184 | out = predicted.data.cpu().numpy()
185 | out_stream_list = []
186 | for p in stream_predictions:
187 | out_stream_list.append(p.data.cpu().numpy())
188 | label = target[0].cpu().numpy()
189 | else:
190 | out = np.concatenate((out, predicted.data.cpu().numpy()), axis=0)
191 | label = np.concatenate((label, target[0].cpu().numpy()), axis=0)
192 | for idx in range(len(stream_predictions)):
193 | out_stream_list[idx] = np.concatenate(
194 | (out_stream_list[idx], stream_predictions[idx].data.cpu().numpy()), axis=0)
195 |
196 |
197 | """Code by Noa Garcia and Yuta Nakashima"""
198 | # Accuracy
199 | acc = np.sum(out == label) / len(out)
200 | logger.info('Validation set: Average loss: {:.4f}\t'
201 | 'Accuracy {acc}'.format(losses.avg, acc=acc))
202 |
203 | logger.info('Acc Streams: %s' % [a + ": " + str(b) for a, b in
204 | zip(args.fuse_stream_list,
205 | [(np.sum(o == label) / len(o)) for o in out_stream_list])])
206 |
207 | return epoch, losses.avg, acc, None
208 |
209 |
210 | def train(args, modeldir):
211 | # Set GPU
212 | n_gpu = torch.cuda.device_count()
213 | logger.info("device: {} n_gpu: {}".format(args.device, n_gpu))
214 | random.seed(args.seed)
215 | np.random.seed(args.seed)
216 | torch.manual_seed(args.seed)
217 | if n_gpu > 0:
218 | torch.cuda.manual_seed_all(args.seed)
219 |
220 | # Model, optimizer and loss
221 |
222 | model = FusionMW(args)
223 | if args.device == "cuda":
224 | model.cuda()
225 | print(model)
226 |
227 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
228 | scheduler = ReduceLROnPlateau(optimizer, patience=args.lr_patience)
229 | class_loss = nn.CrossEntropyLoss().cuda()
230 |
231 | # Data
232 | trainDataObject = FusionDataSample(args, split='train')
233 | valDataObject = FusionDataSample(args, split='val')
234 | train_loader = torch.utils.data.DataLoader(trainDataObject, batch_size=args.batch_size, shuffle=True,
235 | pin_memory=True, num_workers=args.workers)
236 | val_loader = torch.utils.data.DataLoader(valDataObject, batch_size=args.batch_size, shuffle=True, pin_memory=True,
237 | num_workers=args.workers)
238 |
239 | logger.info('Training loader with %d samples' % train_loader.__len__())
240 | logger.info('Validation loader with %d samples' % val_loader.__len__())
241 | logger.info('Training...')
242 | pattrack = 0
243 | best_val = 0
244 |
245 | for epoch in range(0, args.nepochs):
246 |
247 | trainEpoch(args, train_loader, model, class_loss, optimizer, epoch)
248 |
249 | epoch_plot_val, loss_plot_val, acc_plot_val, stream_losses_plot_val = valEpoch(args, val_loader, model,
250 | class_loss, epoch)
251 |
252 | current_val = acc_plot_val
253 |
254 | scheduler.step(loss_plot_val)
255 |
256 | # Check patience
257 | is_best = current_val > best_val
258 | best_val = max(current_val, best_val)
259 | if not is_best:
260 | pattrack += 1
261 | else:
262 | pattrack = 0
263 | if pattrack >= args.patience:
264 | break
265 |
266 | logger.info('** Validation information: %f (this accuracy) - %f (best accuracy) - %d (patience valtrack)' % (
267 | current_val, best_val, pattrack))
268 |
269 | # Save
270 | state = {'state_dict': model.state_dict(),
271 | 'best_val': best_val,
272 | 'optimizer': optimizer.state_dict(),
273 | 'pattrack': pattrack,
274 | 'curr_val': current_val}
275 | filename = os.path.join(modeldir, 'model_latest.pth.tar')
276 | torch.save(state, filename)
277 | if is_best:
278 | filename = os.path.join(modeldir, 'model_best.pth.tar')
279 | torch.save(state, filename)
280 |
281 |
282 | def evaluate(args, modeldir):
283 | model = FusionMW(args)
284 | if args.device == "cuda":
285 | model.cuda()
286 | logger.info("=> loading checkpoint from '{}'".format(modeldir))
287 | checkpoint = torch.load(os.path.join(modeldir, 'model_best.pth.tar'))
288 | model.load_state_dict(checkpoint['state_dict'])
289 |
290 | # Data
291 | evalDataObject = FusionDataSample(args, split=args.eval_split)
292 | test_loader = torch.utils.data.DataLoader(evalDataObject, batch_size=args.batch_size, shuffle=False,
293 | pin_memory=(not args.no_cuda), num_workers=args.workers)
294 | logger.info('Evaluation loader with %d samples' % test_loader.__len__())
295 |
296 | # Switch to evaluation mode & compute test samples embeddings
297 | batch_time = utils.AverageMeter()
298 | end = time.time()
299 | model.eval()
300 | for i, (input, target) in enumerate(test_loader):
301 |
302 | # Inputs to Variable type
303 | input_var = list()
304 | for j in range(len(input)):
305 | input_var.append(torch.autograd.Variable(input[j]).cuda())
306 |
307 | # Targets to Variable type
308 | target_var = list()
309 | for j in range(len(target)):
310 | target[j] = target[j].cuda(non_blocking=True)
311 | target_var.append(torch.autograd.Variable(target[j]))
312 |
313 | # Output of the model
314 | with torch.no_grad():
315 | output, final_scores = model(input_var)
316 | # Compute final loss
317 | predicted = torch.max(final_scores, 1)[1]
318 |
319 | # measure elapsed time
320 | batch_time.update(time.time() - end)
321 | end = time.time()
322 |
323 | # Store outputs
324 | if i == 0:
325 | out = predicted.data.cpu().numpy()
326 | label = target[0].cpu().numpy()
327 | index = target[1].cpu().numpy()
328 |
329 | score_list = []
330 | for o in output:
331 | score_list.append(o.data.cpu().numpy())
332 | scores_final = final_scores.data.cpu().numpy()
333 | else:
334 | out = np.concatenate((out, predicted.data.cpu().numpy()), axis=0)
335 | label = np.concatenate((label, target[0].cpu().numpy()), axis=0)
336 | index = np.concatenate((index, target[1].cpu().numpy()), axis=0)
337 |
338 | for idx in range(len(score_list)):
339 | score_list[idx] = np.concatenate((score_list[idx], output[idx].data.cpu().numpy()), axis=0)
340 |
341 | scores_final = np.concatenate((scores_final, final_scores.cpu().numpy()), axis=0)
342 |
343 | df = pd.read_csv(os.path.join(args.data_dir, 'knowit_data_%s.csv' % args.eval_split), delimiter='\t')
344 |
345 | """Code by InterDigital"""
346 | logger.info("Eval on %s data from fusion final output" % args.eval_split)
347 | if args.eval_split == 'test':
348 | utils.accuracy(df, out, label, index)
349 | logger.info("Eval on %s data from streams" % args.eval_split)
350 | for o, str_stream in zip(score_list, args.fuse_stream_list):
351 | logger.info("Stream: %s" % str_stream)
352 | utils.accuracy(df, np.argmax(o, 1), label, index)
353 | else:
354 | """Code by Noa Garcia and Yuta Nakashima"""
355 | utils.accuracy_val(out, label)
356 |
357 |
358 | if __name__ == "__main__":
359 |
360 | args = get_params()
361 | """Code by InterDigital"""
362 | assert (args.fuse_loss_weight_list is not None) or (
363 | args.weight_loss_final is not None) # At least one loss weight should be given
364 |
365 | if args.weight_loss_final is None:
366 | args.weight_loss_final = 1 - sum(args.fuse_loss_weight_list)
367 | elif args.fuse_loss_weight_list is None:
368 | remaining_loss_weight = 1 - args.weight_loss_final
369 | args.fuse_loss_weight_list = [remaining_loss_weight / len(args.fuse_stream_list)] * len(args.fuse_stream_list)
370 |
371 | assert len(args.fuse_stream_list) >= 2 # Make sure at least two streams given
372 | assert len(args.fuse_stream_list) == len(
373 | args.fuse_loss_weight_list) # Make sure to give loss weight in the same amount of streams
374 | assert sum(args.fuse_loss_weight_list) + args.weight_loss_final < 1.1 # Normalize the loss two approx. 1
375 | assert sum(args.fuse_loss_weight_list) + args.weight_loss_final > 0.9
376 |
377 |
378 |
379 | model_name_path = os.path.join(args.fusion_train_folder_path, "-".join(args.fuse_stream_list)+"_"+str(args.weight_loss_final))
380 |
381 | modeldir = create_folder_with_timestamp(model_name_path,
382 | args.load_pretrained_model_exists)
383 |
384 | logger.info("Arguments: %s" % json.JSONEncoder().encode(vars(args)))
385 |
386 | with open(os.path.join(modeldir, "args.json"), 'w') as f:
387 | json.dump(vars(args), f)
388 |
389 | """Code by Noa Garcia and Yuta Nakashima"""
390 | # Train if model does not exist
391 | if not os.path.isfile(os.path.join(modeldir, 'model_best.pth.tar')):
392 | train(args, modeldir)
393 |
394 | # Evaluation
395 | evaluate(args, modeldir)
--------------------------------------------------------------------------------
/fusion_data_sample.py:
--------------------------------------------------------------------------------
1 | """Code by Noa Garcia and Yuta Nakashima"""
2 | import logging
3 | import os
4 |
5 | import numpy as np
6 | from torch.utils import data
7 |
8 | import utils
9 | from utils import SCENE_BASED_STREAMS, EPISODE_BASED_STREAMS
10 | from utils import load_knowit_data
11 | from scipy.special import softmax
12 |
13 |
14 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
15 | datefmt='%m/%d/%Y %H:%M:%S',
16 | level=logging.INFO)
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | class FusionDataSample(data.Dataset):
21 |
22 | def __init__(self, args, split):
23 | df = load_knowit_data(args, split)
24 | self.labels = (df['idxCorrect'] - 1).to_list()
25 | """Code by InterDigital"""
26 | self.scene_based_features = []
27 | self.episode_based_features = []
28 | self.episode_logits_slices = []
29 | self.scene_based_stream_names = []
30 | self.episode_based_stream_names = []
31 | """Code by Noa Garcia and Yuta Nakashima"""
32 | self.args = args
33 |
34 | for stream in args.fuse_stream_list:
35 |
36 | base_embedding_path = os.path.join(args.stream_train_folder_path, stream, 'embeddings')
37 | embeddings = utils.load_obj(
38 | os.path.join(base_embedding_path, stream+'_stream_embeddings_%s.pckl' % split))
39 |
40 | if stream in SCENE_BASED_STREAMS:
41 | self.scene_based_stream_names.append(stream)
42 | self.scene_based_features.append(np.reshape(embeddings, (int(embeddings.shape[0] / 4), 4, 768)))
43 | elif stream in EPISODE_BASED_STREAMS:
44 | self.episode_based_stream_names.append(stream)
45 | episode_based_reshaped_feature = np.reshape(embeddings[0], (
46 | int(embeddings[0].shape[0] / 4), args.__dict__['num_max_slices_' + stream], 4, 768))
47 | self.episode_based_features.append(episode_based_reshaped_feature)
48 | self.episode_logits_slices.append(embeddings[1])
49 | else:
50 | raise NotImplementedError
51 |
52 | self.num_samples = len(self.labels)
53 | logger.info('Dataloader with %d samples' % self.num_samples)
54 |
55 | def __len__(self):
56 | return self.num_samples
57 |
58 | def __getitem__(self, index):
59 | label = self.labels[index]
60 | outputs = [label, index]
61 | inputs = []
62 | """Code by InterDigital"""
63 | scene_based_inputs = []
64 | episode_based_inputs = []
65 | for stream in self.scene_based_features:
66 | scene_based_inputs.append(stream[index, :])
67 | for stream, slice in zip(self.episode_based_features, self.episode_logits_slices):
68 | stream_slices = stream[index, :]
69 | stream_logits_slice = slice[index, :]
70 |
71 | if self.args.part_selection_with_soft_temporal_attention:
72 | a = np.max(stream_logits_slice, axis=1).reshape(1, -1)  # per-slice max logit over the 4 answers
73 | s = softmax(a / self.args.ss_max_temperature, axis=1)  # temperature-scaled soft attention over episode slices
74 | results_embeddings = np.matmul(s, stream_slices.reshape(s.shape[1], -1)).reshape(4, 768)  # attention-weighted sum of slice embeddings
75 | episode_based_inputs.append(results_embeddings)
76 | else:
77 | idx_slice, _ = np.unravel_index(stream_logits_slice.argmax(), stream_logits_slice.shape)  # hard selection: slice with the highest answer logit
78 | episode_based_inputs.append(stream_slices[idx_slice, :])
79 |
80 | for stream in self.args.fuse_stream_list:
81 | if stream in self.scene_based_stream_names:
82 | stream_names_index = self.scene_based_stream_names.index(stream)
83 | inputs.append(scene_based_inputs[stream_names_index])
84 | else:
85 | stream_names_index = self.episode_based_stream_names.index(stream)
86 | inputs.append(episode_based_inputs[stream_names_index])
87 |
88 | return inputs, outputs
89 |
--------------------------------------------------------------------------------
/fusion_main_train.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """Code by Noa Garcia and Yuta Nakashima"""
3 | import argparse
4 | import json
5 | import logging
6 | import os
7 | import random
8 | import time
9 |
10 | import numpy as np
11 | import pandas as pd
12 | import torch
13 | from torch import nn
14 | from torch.nn import LayerNorm, Dropout
15 | from torch.optim.lr_scheduler import ReduceLROnPlateau
16 | from torch.utils.data import DataLoader
17 |
18 | import utils
19 | from fusion_data_sample import FusionDataSample
20 | from multi_head_attention import MultiHeadAttention
21 | from utils import create_folder_with_timestamp, str2bool
22 |
23 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
24 | datefmt='%m/%d/%Y %H:%M:%S',
25 | level=logging.INFO)
26 | logger = logging.getLogger(__name__)
27 |
28 |
29 | def get_params():
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument("--data_dir", default='data/', type=str)
32 | parser.add_argument("--bert_model", default='bert-base-uncased', type=str)
33 | parser.add_argument("--do_lower_case", default=True, type=bool)
34 | parser.add_argument('--seed', type=int, default=181)
35 | parser.add_argument("--lr", default=0.0001, type=float)
36 | parser.add_argument("--workers", default=8)
37 | parser.add_argument("--device", default='cuda', type=str, help="cuda, cpu")
38 | parser.add_argument("--batch_size", default=32, type=int)
39 | parser.add_argument('--momentum', default=0.9)
40 | parser.add_argument('--nepochs', default=100, help='Number of epochs', type=int)
41 | parser.add_argument('--patience', default=15, type=int)
42 | parser.add_argument('--no_cuda', action='store_true')
43 |
44 | """Code by InterDigital"""
45 | parser.add_argument("--num_max_slices_plot", default=10, type=int)
46 | parser.add_argument("--num_max_slices_episode_dialog_summary", default=10, type=int)
47 | parser.add_argument('--fuse_stream_list', nargs='+', required=True, type=str)
48 | parser.add_argument('--num_head', default=1, type=int)
49 | parser.add_argument("--load_pretrained_model_exists", default=False, type=bool)
50 | parser.add_argument('--ss_max_temperature', default=2, type=int)
51 | parser.add_argument('--lr_patience', default=5, type=int)
52 | parser.add_argument('--fusion_method', required=True, type=str)
53 | parser.add_argument("--fusion_train_folder_path", default='Training', type=str)
54 | parser.add_argument("--stream_train_folder_path", default='Training/', type=str)
55 | parser.add_argument("--pretrain_modeldir", type=str)
56 | parser.add_argument("--part_selection_with_soft_temporal_attention", default=True, type=str2bool)
57 | parser.add_argument("--save_multi_stream_attention_scores", default=False, type=str2bool)
58 |
59 | args, unknown = parser.parse_known_args()
60 | return args
61 |
62 | class TwoInputSequential(nn.Sequential):
63 | def forward(self, *inputs):
64 | for module in self._modules.values():
65 | if type(inputs) == tuple:
66 | inputs = module(*inputs)
67 | else:
68 | inputs = module(inputs)
69 | return inputs
70 |
71 |
72 | class MultiStreamAttention(nn.Module):
73 | def __init__(self, input_shape):
74 | super(MultiStreamAttention, self).__init__()
75 | self.layers = nn.Sequential(nn.Linear(input_shape, input_shape // 2), nn.ReLU(), nn.Dropout(0.5),
76 | nn.Linear(input_shape // 2, 1), nn.Dropout(0.5),
77 | nn.Softmax(dim=1))
78 |
79 | def forward(self, input):
80 | score = self.layers(input)
81 |
82 | input_ = input * score
83 |
84 | if save_multi_stream_scores:
85 | attention_scores.append(score)
86 |
87 | return input_
88 |
89 |
90 | class ResidualSelfAttention(nn.Module):
91 | def __init__(self, embed, num_head):
92 | super(ResidualSelfAttention, self).__init__()
93 |
94 | self.layer = MultiHeadAttention(embed, num_head)
95 | self.norm = LayerNorm(embed)
96 | self.dropout = Dropout(0.5)
97 |
98 | def forward(self, input):
99 | attended = self.layer(input, input, input)[0]
100 | output = self.norm(self.dropout(attended) + input)
101 | return output
102 |
103 |
104 | class Flatten(nn.Module):
105 | def __init__(self):
106 | super(Flatten, self).__init__()
107 |
108 | def forward(self, input):
109 | return torch.flatten(input,1)
110 |
111 |
112 | class FusionProduct(nn.Module):
113 | def __init__(self, args):
114 | self.args = args
115 | super(FusionProduct, self).__init__()
116 |
117 | self.stream_transformer_blocks = nn.ModuleList([nn.Sequential(Flatten(),
118 | nn.Linear(768, 1)) for _ in range(4)])
119 |
120 | def forward(self, inputs):
121 |
122 | num_choices = inputs[0].shape[1]
123 |
124 | reshaped_scores_list = []
125 | score_list = []
126 |
127 | for choice, module_per_answer in zip(range(num_choices), self.stream_transformer_blocks):
128 |
129 | result = torch.ones_like(inputs[0][:, 0, :])
130 |
131 | for br in inputs:
132 | result = br[:, choice, :] * result
133 |
134 | answer = module_per_answer(result)
135 | score_list.append(answer)
136 |
137 | # Final score
138 | all_feat = torch.squeeze(torch.cat(score_list, 1), 1)
139 |
140 | reshaped_final_scores = all_feat.view(-1, num_choices)
141 | return reshaped_scores_list, reshaped_final_scores
142 |
143 |
144 | class FusionMethods(nn.Module):
145 | def __init__(self, args):
146 | self.args = args
147 | super(FusionMethods, self).__init__()
148 |
149 | if args.fusion_method == 'multi-stream-attention':
150 | self.stream_transformer_blocks = nn.ModuleList([nn.Sequential(MultiStreamAttention(768), Flatten(),
151 | nn.Linear(768 * len(args.fuse_stream_list),
152 | 1)) for _ in range(4)])
153 |
154 |
155 | elif args.fusion_method == 'self-attention':
156 | self.stream_transformer_blocks = nn.ModuleList(
157 | [nn.Sequential(ResidualSelfAttention(768, args.num_head),ResidualSelfAttention(768, args.num_head),
158 | Flatten(),
159 | nn.Linear(768 * len(args.fuse_stream_list), 1)) for _ in range(4)])
160 |
161 | elif args.fusion_method == 'multi-stream-self-attention':
162 | self.stream_transformer_blocks = nn.ModuleList(
163 | [nn.Sequential(MultiStreamAttention(768), ResidualSelfAttention(768, args.num_head), Flatten(),
164 | nn.Linear(768 * len(args.fuse_stream_list), 1)) for _ in range(4)])
165 |
166 | elif args.fusion_method == 'product':
167 | pass
168 |
169 | else:
170 | raise NotImplementedError
171 |
172 | def forward(self, inputs):
173 |
174 | num_choices = inputs[0].shape[1]
175 |
176 | reshaped_scores_list = []
177 | score_list = []
178 |
179 | for choice, module_per_answer in zip(range(num_choices), self.stream_transformer_blocks):
180 |
181 | if self.args.fusion_method == 'product':
182 | result = torch.ones_like(inputs[0][:, 0, :])
183 | for br in inputs:
184 | result = br[:, choice, :] * result
185 | answer = module_per_answer(result)
186 | score_list.append(answer)
187 | else:
188 | answer_list = []
189 | for br in inputs:
190 | answer_list.append(br[:, choice, :])
191 | stack = torch.stack(answer_list, dim=1)
192 | answer = module_per_answer(stack)
193 | score_list.append(answer)
194 |
195 | # Final score
196 | all_feat = torch.squeeze(torch.cat(score_list, 1), 1)
197 |
198 | reshaped_final_scores = all_feat.view(-1, num_choices)
199 | return reshaped_scores_list, reshaped_final_scores
200 |
201 | """Code by Noa Garcia and Yuta Nakashima"""
202 | def trainEpoch(args, train_loader, model, criterion, optimizer, epoch):
203 | losses = utils.AverageMeter()
204 | model.train()
205 |
206 | targets = []
207 | outs = []
208 |
209 | for batch_idx, (input, target) in enumerate(train_loader):
210 |
211 | # Inputs to Variable type
212 | input_var = list()
213 | for j in range(len(input)):
214 | input_var.append(torch.autograd.Variable(input[j]).cuda())
215 |
216 | # Targets to Variable type
217 | target_var = list()
218 | for j in range(len(target)):
219 | target[j] = target[j].cuda(non_blocking=True)
220 | target_var.append(torch.autograd.Variable(target[j]))
221 |
222 | # Output of the model
223 | output, final_scores = model(input_var)
224 |
225 | # Compute loss
226 | final_loss = criterion(final_scores, target_var[0])
227 |
228 | train_loss = final_loss
229 |
230 | losses.update(train_loss.data.cpu().numpy(), input[0].size(0))
231 |
232 | # for plot
233 | outs.append(torch.max(final_scores, 1)[1].data.cpu().numpy())
234 | targets.append(target[0].cpu().numpy())
235 |
236 | # Backpropagate loss and update weights
237 | optimizer.zero_grad()
238 | train_loss.backward()
239 | optimizer.step()
240 |
241 | # Print info
242 | logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\t'
243 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
244 | epoch, batch_idx, len(train_loader), 100. * batch_idx / len(train_loader), loss=losses))
245 |
246 | outs = np.concatenate(outs).flatten()
247 | targets = np.concatenate(targets).flatten()
248 |
249 | acc = np.sum(outs == targets) / len(outs)
250 | return epoch, losses.avg, acc, None
251 |
252 |
253 | def valEpoch(args, val_loader, model, criterion, epoch):
254 | losses = utils.AverageMeter()
255 | # final_losses = utils.AverageMeter()
256 | model.eval()
257 | for batch_idx, (input, target) in enumerate(val_loader):
258 |
259 | # Inputs to Variable type
260 | input_var = list()
261 | for j in range(len(input)):
262 | input_var.append(torch.autograd.Variable(input[j]).cuda())
263 |
264 | # Targets to Variable type
265 | target_var = list()
266 | for j in range(len(target)):
267 | target[j] = target[j].cuda(non_blocking=True)
268 | target_var.append(torch.autograd.Variable(target[j]))
269 |
270 | # Output of the model
271 | with torch.no_grad():
272 | output, final_scores = model(input_var)
273 |
274 | # Compute loss
275 | predicted = torch.max(final_scores, 1)[1]
276 |
277 | stream_predictions = [torch.max(p, 1)[1] for p in output]
278 |
279 | final_loss = criterion(final_scores, target_var[0])
280 | train_loss = 0
281 |
282 | train_loss = final_loss
283 |
284 | losses.update(train_loss.data.cpu().numpy(), input[0].size(0))
285 |
286 | # Save predictions to compute accuracy
287 | if batch_idx == 0:
288 | out = predicted.data.cpu().numpy()
289 | out_stream_list = []
290 | for p in stream_predictions:
291 | out_stream_list.append(p.data.cpu().numpy())
292 | label = target[0].cpu().numpy()
293 | else:
294 | out = np.concatenate((out, predicted.data.cpu().numpy()), axis=0)
295 | label = np.concatenate((label, target[0].cpu().numpy()), axis=0)
296 | for idx in range(len(stream_predictions)):
297 | out_stream_list[idx] = np.concatenate(
298 | (out_stream_list[idx], stream_predictions[idx].data.cpu().numpy()), axis=0)
299 |
300 | # Accuracy
301 | acc = np.sum(out == label) / len(out)
302 | logger.info('Validation set: Average loss: {:.4f}\t'
303 | 'Accuracy {acc}'.format(losses.avg, acc=acc))
304 |
305 | return epoch, losses.avg, acc, None
306 |
307 |
308 | def train(args, modeldir):
309 | # Set GPU
310 | n_gpu = torch.cuda.device_count()
311 | logger.info("device: {} n_gpu: {}".format(args.device, n_gpu))
312 | random.seed(args.seed)
313 | np.random.seed(args.seed)
314 | torch.manual_seed(args.seed)
315 | if n_gpu > 0:
316 | torch.cuda.manual_seed_all(args.seed)
317 |
318 | # Model, optimizer and loss
319 |
320 | if args.fusion_method == "product":
321 | model = FusionProduct(args)
322 | else:
323 | model = FusionMethods(args)
324 | if args.device == "cuda":
325 | model.cuda()
326 |
327 | if n_gpu > 1:
328 | model = torch.nn.DataParallel(model)
329 |
330 | print(model)
331 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
332 | scheduler = ReduceLROnPlateau(optimizer, patience=args.lr_patience)
333 |
334 | class_loss = nn.CrossEntropyLoss().cuda()
335 |
336 | # Data
337 | trainDataObject = FusionDataSample(args, split='train')
338 | valDataObject = FusionDataSample(args, split='val')
339 | train_loader = torch.utils.data.DataLoader(trainDataObject, batch_size=args.batch_size, shuffle=True,
340 | pin_memory=True, num_workers=args.workers)
341 | val_loader = torch.utils.data.DataLoader(valDataObject, batch_size=args.batch_size, shuffle=True, pin_memory=True,
342 | num_workers=args.workers)
343 |
344 | # Now, let's start the training process!
345 | logger.info('Training loader with %d samples' % train_loader.__len__())
346 | logger.info('Validation loader with %d samples' % val_loader.__len__())
347 | logger.info('Training...')
348 | pattrack = 0
349 | best_val = 0
350 |
351 |
352 | for epoch in range(0, args.nepochs):
353 |
354 | # Epoch
355 | trainEpoch(args, train_loader, model, class_loss, optimizer, epoch)
356 | epoch_plot_val, loss_plot_val, acc_plot_val, stream_losses_plot_val = valEpoch(args, val_loader, model,
357 | class_loss, epoch)
358 | current_val = acc_plot_val
359 |
360 | scheduler.step(loss_plot_val)
361 |
362 | # Check patience
363 | is_best = current_val > best_val
364 | best_val = max(current_val, best_val)
365 | if not is_best:
366 | pattrack += 1
367 | else:
368 | pattrack = 0
369 | if pattrack >= args.patience:
370 | break
371 |
372 | logger.info('** Validation information: %f (this accuracy) - %f (best accuracy) - %d (patience valtrack)' % (
373 | current_val, best_val, pattrack))
374 |
375 | # Save
376 | state = {'state_dict': model.state_dict(),
377 | 'best_val': best_val,
378 | 'optimizer': optimizer.state_dict(),
379 | 'pattrack': pattrack,
380 | 'curr_val': current_val}
381 | filename = os.path.join(modeldir, 'model_latest.pth.tar')
382 | torch.save(state, filename)
383 | if is_best:
384 | filename = os.path.join(modeldir, 'model_best.pth.tar')
385 | torch.save(state, filename)
386 |
387 |
388 | def evaluate(args, modeldir):
389 | n_gpu = torch.cuda.device_count()
390 |
391 | """Code by InterDigital"""
392 | if args.fusion_method == "product":
393 | model = FusionProduct(args)
394 | else:
395 | model = FusionMethods(args)
396 |
397 | """Code by Noa Garcia and Yuta Nakashima"""
398 | if args.device == "cuda":
399 | model.cuda()
400 |
401 | if n_gpu > 1:
402 | model = torch.nn.DataParallel(model)
403 |
404 | logger.info("=> loading checkpoint from '{}'".format(modeldir))
405 | checkpoint = torch.load(os.path.join(modeldir, 'model_best.pth.tar'))
406 | model.load_state_dict(checkpoint['state_dict'])
407 |
408 | # Data
409 | evalDataObject = FusionDataSample(args, split='test')
410 | test_loader = torch.utils.data.DataLoader(evalDataObject, batch_size=args.batch_size, shuffle=False,
411 | pin_memory=(not args.no_cuda), num_workers=args.workers)
412 | logger.info('Evaluation loader with %d samples' % test_loader.__len__())
413 |
414 | # Switch to evaluation mode & compute test samples embeddings
415 | batch_time = utils.AverageMeter()
416 | end = time.time()
417 | model.eval()
418 | for i, (input, target) in enumerate(test_loader):
419 |
420 | # Inputs to Variable type
421 | input_var = list()
422 | for j in range(len(input)):
423 | input_var.append(torch.autograd.Variable(input[j]).cuda())
424 |
425 | # Targets to Variable type
426 | target_var = list()
427 | for j in range(len(target)):
428 | target[j] = target[j].cuda(non_blocking=True)
429 | target_var.append(torch.autograd.Variable(target[j]))
430 |
431 | # Output of the model
432 | with torch.no_grad():
433 | output, final_scores = model(input_var)
434 | # Compute final loss
435 | predicted = torch.max(final_scores, 1)[1]
436 |
437 | # measure elapsed time
438 | batch_time.update(time.time() - end)
439 | end = time.time()
440 |
441 | # Store outputs
442 | if i == 0:
443 | out = predicted.data.cpu().numpy()
444 | label = target[0].cpu().numpy()
445 | index = target[1].cpu().numpy()
446 |
447 | score_list = []
448 | for o in output:
449 | score_list.append(o.data.cpu().numpy())
450 | scores_final = final_scores.data.cpu().numpy()
451 | else:
452 | out = np.concatenate((out, predicted.data.cpu().numpy()), axis=0)
453 | label = np.concatenate((label, target[0].cpu().numpy()), axis=0)
454 | index = np.concatenate((index, target[1].cpu().numpy()), axis=0)
455 |
456 | for idx in range(len(score_list)):
457 | score_list[idx] = np.concatenate((score_list[idx], output[idx].data.cpu().numpy()), axis=0)
458 |
459 | scores_final = np.concatenate((scores_final, final_scores.cpu().numpy()), axis=0)
460 |
461 | df = pd.read_csv(os.path.join(args.data_dir, 'knowit_data_test.csv'), delimiter='\t')
462 |
463 | logger.info("Eval on test data from fusion final output")
464 |
465 | """Code by InterDigital"""
466 | with open(os.path.join(modeldir, 'test_results_fusion.npy'), 'wb') as f:
467 | np.save(f, out)
468 | with open(os.path.join(modeldir, 'test_labels_fusion.npy'), 'wb') as f:
469 | np.save(f, label)
470 | """Code by Noa Garcia and Yuta Nakashima"""
471 | utils.accuracy(df, out, label, index)
472 |
473 |
474 | if __name__ == "__main__":
475 |
476 | args = get_params()
477 |
478 | """Code by InterDigital"""
479 | assert len(args.fuse_stream_list) >= 2 # Make sure at least two streams given
480 |
481 | if args.load_pretrained_model_exists:
482 | if args.pretrain_modeldir is not None:
483 | modeldir = args.pretrain_modeldir
484 | else:
485 | raise FileNotFoundError
486 | else:
487 | # Create training and data directories
488 | modeldir = create_folder_with_timestamp(os.path.join(args.fusion_train_folder_path, "-".join(args.fuse_stream_list)+'_'+args.fusion_method),
489 | args.load_pretrained_model_exists)
490 |
491 | global attention_scores
492 | if args.save_multi_stream_attention_scores:
493 | attention_scores = []
494 |
495 | global save_multi_stream_scores
496 | save_multi_stream_scores = args.save_multi_stream_attention_scores
497 |
498 |
499 | args.modeldir = modeldir
500 |
501 | logger.info("Arguments: %s" % json.JSONEncoder().encode(vars(args)))
502 |
503 | with open(os.path.join(modeldir, "args.json"), 'w') as f:
504 | json.dump(vars(args), f)
505 |
506 | """Code by Noa Garcia and Yuta Nakashima"""
507 | # Train if model does not exist
508 | if not os.path.isfile(os.path.join(modeldir, 'model_best.pth.tar')):
509 | train(args, modeldir)
510 |
511 | # Evaluation
512 | evaluate(args, modeldir)
513 |
514 | """Code by InterDigital"""
515 | if args.save_multi_stream_attention_scores:
516 | with open(os.path.join(modeldir, "test_soft_attention_score.npy"), "wb") as f:
517 | np.save(f, np.array([a.cpu().numpy() for a in attention_scores]))
518 |
--------------------------------------------------------------------------------
/images/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InterDigitalInc/DialogSummary-VideoQA/2cba0ca17024bea499b0c141516430d348594d3a/images/model.png
--------------------------------------------------------------------------------
/license.txt:
--------------------------------------------------------------------------------
1 | LIMITED SOFTWARE EVALUATION LICENSE AGREEMENT
2 |
3 | The following Limited Software Evaluation License (the “License”) constitutes an agreement between you (the “Licensee”) and InterDigital Communications, Inc, a company organized and existing under the laws of the State of Delaware, USA, with its registered offices located at 200 Bellevue Parkway, Suite 300, Wilmington, DE 19809, USA (hereinafter “InterDigital”).
4 | This License governs the download and use of the Software (as defined below). Your use of the Software is subject to the terms and conditions set forth in this License. By installing, using, accessing or copying the Software, you hereby irrevocably accept the terms and conditions of this License. If you do not accept all parts of the terms and conditions of this License, you cannot install, use, access nor copy the Software.
5 | Article 1. Definitions
6 | “Affiliate” as used herein shall mean any entity that, directly or indirectly, through one or more intermediates, is controlled by, controls, or is under common control with InterDigital or The Licensee, as the case may be. For purposes of this definition only, the term “control” means the possession of the power to direct or cause the direction of the management and policies of an entity, whether by ownership of voting stock or partnership interest, by contract, or otherwise, including direct or indirect ownership of more than fifty percent (50%) of the voting interest in the entity in question.
7 | “Authorized Purpose” means any use of the Software for fundamental research work with the exclusion of any commercial use. A commercial use includes, without limitation, any sublicense granted on the Software against a fee whatever its nature, any use of the Software in a product that is offered (either free or for a price) to any third party, any use of the Software to provide a service to a third party and/or any use of the Software to create a competing product of the Software ("Purpose")
8 | “Documentation” means textual materials delivered by InterDigital to the Licensee pursuant to this License relating to the Software, in written or electronic format, including but not limited to, technical reference manuals, technical notes, user manuals, and application guides.
9 | “Effective Date” means the date Licensee first installs a copy of the Software on any computer.
10 |
11 | “Limited Period” means the life of the copyright owned by InterDigital on the Software in each and every country where such copyright would exist.
12 | “Intellectual Property Rights” means all copyrights, trademarks, trade secrets, patents and any other intellectual property rights recognized in any jurisdiction worldwide, including all applications and registrations with respect thereto.
13 | "Open Source Software" shall mean any software, including where appropriate, any and all modifications, derivative works, enhancements, upgrades, improvements, fixed bugs, and/or statically linked to the source code of such software, released under a free or open source software license that requires, as a condition of usage, copy, modification and/or redistribution of such software, that the party:
14 | • Redistribute the Open Source Software royalty-free; and/or
15 | • Redistribute the Open Source Software under the same license/distribution terms as those contained in the open source or free software license under which it was originally released; and/or
16 | • Release to the public, disclose or otherwise make available the source code of the Open Source Software.
17 | For purposes of this License, by means of example and without limitation, any software that is released or distributed under any of the following licenses shall be qualified as Open Source Software: (i) GNU General Public License (GPL); (ii) GNU Lesser/Library GPL (LGPL); (iii) the Artistic License; (iv) the Mozilla Public License; (v) the Common Public License; (vi) the Sun Community Source License (SCSL); (vii) the Sun Industry Standards Source License (SISSL); (viii) BSD License; (ix) MIT License; (x) Apache Software License; (xi) Open SSL License; (xii) IBM Public License; and (xiii) Open Software License.
18 | “Software” means the Software with which this license was downloaded, namely DialogSummary-VideoQA in object code.
19 | Article 2. License
20 | InterDigital grants Licensee a free, worldwide, non-exclusive, license to InterDigital’s copyright on the Software to download, use and reproduce solely for the Authorized Purpose for the Limited Period.
21 | Licensee shall not pay any royalty, license fee or maintenance fee, or other fee of any nature under this License.
22 | Article 3. Restrictions on use of the Software
23 | Licensee shall not have the right to correct, adapt, modify, reverse engineer, disassemble, decompile or/and otherwise perform or conduct any action leading to the transformation of the Software.
24 | Licensee shall not remove, obscure or modify any copyright, trademark or other proprietary rights notices, marks or labels contained on or within the Software, falsify or delete any author attributions, legal notices or other labels of the origin or source of the material.
25 | Licensee may reproduce and distribute copies of the Software in any medium provided that Licensee gives any other recipients of the Software a copy of this License.
26 | Article 4. Ownership
27 | Title to and ownership of the Software, the Documentation, and/or any Intellectual Property Right protecting the Software and/or the Documentation shall at all times remain with InterDigital. Licensee agrees that except for the limited rights granted to the Software as set forth in Section 2 above, in no event shall anything in this License grant, provide, or convey any other rights, privileges, immunities, or interest in or to any Intellectual Property Rights (including but not limited to patent rights) of InterDigital or any of its Affiliates, whether by implication, estoppel, or otherwise.
28 | Article 5. Publication/Communication
29 | Any publication or oral communication resulting from the use of the Software shall be elaborated in good faith and shall not be driven by a deliberate will to denigrate InterDigital or any of its products. In any publication and on any support joined to an oral communication (e.g., a PowerPoint presentation) relating to the Software, the following statement shall be inserted:
30 | “DialogSummary-VideoQA is an InterDigital product”
31 | In any publication, the latest publication about the software shall be properly cited. The latest publication currently is:
32 | "Deniz Engin, François Schnitzler, Ngoc Q. K. Duong, Yannis Avrithis. On the hidden treasure of dialog in video question answering. In Proc. ICCV, 2021.”
33 | In any oral communication relating to the Software and/or its use, the Licensee shall orally indicate that the Software is InterDigital’s property.
34 | Article 6. No Warranty - Disclaimer
35 | THE SOFTWARE AND DOCUMENTATION ARE PROVIDED TO LICENSEE ON AN “AS IS” BASIS. INTERDIGITAL MAKES NO WARRANTY THAT THE SOFTWARE WILL OPERATE ON ANY PARTICULAR HARDWARE, PLATFORM, OR ENVIRONMENT. THERE IS NO WARRANTY THAT THE OPERATION OF THE SOFTWARE SHALL BE UNINTERRUPTED, WITHOUT BUGS OR ERROR FREE. THE SOFTWARE AND DOCUMENTATION ARE PROVIDED HEREUNDER WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY IMPLIED LIABILITIES AND WARRANTIES OF NONINFRINGEMENT OF INTELLECTUAL PROPERTY, FREEDOM FROM INHERENT DEFECTS, CONFORMITY TO A SAMPLE OR MODEL, MERCHANTABILITY, FITNESS AND/OR SUITABILITY FOR A SPECIFIC OR GENERAL PURPOSE AND THOSE ARISING BY STATUTE OR BY LAW, OR FROM A CAUSE OF DEALING OR USAGE OF TRADE. ANY AND ALL SUCH IMPLIED WARRANTIES ARE FULLY DISCLAIMED BY INTERDIGITAL TO THE MAXIMUM EXTENT ALLOWED BY LAW, AND LICENSEE ACKNOWLEDGES THAT THIS DISCLAIMER OF ALL EXPRESS AND IMPLIED WARRANTIES BY INTERDIGITAL, AS WELL AS LICENSEE’S ACCEPTANCE AND ACKNOWLEDGEMENT OF THE SAME, IS A MATERIAL PART OF THE CONSIDERATION FOR THIS LICENSE.
36 | InterDigital shall not be obligated to perform or provide any modifications, derivative works, enhancements, upgrades, updates or improvements of the Software or Documentation, or to fix any bug that could arise.
37 | Licensee at all times uses the Software at its own cost, risk and responsibility. InterDigital shall not be liable for any damages that could accrue by or to Licensee as a result of its use of the Software, either in accordance with this License or not.
38 | InterDigital shall not be liable for any consequential or indirect losses, including any indirect loss of profits, revenues, business, and/or anticipated savings, whether or not in the contemplation of the Parties at the time of entering into this License unless expressly set out in this License, or arising from gross negligence, willful misconduct or fraud.
39 | Licensee agrees that it will defend, indemnify and hold harmless InterDigital and its Affiliates against any and all losses, damages, costs and expenses arising from a breach by the Licensee of any of its obligations or representations hereunder, including, without limitation, any third party claims, and/or any claims in connection with any such breach and/or any use of the Software, including any claim from third party arising from access, use, or any other activity in relation to this Software.
40 | Licensee shall not make any warranty, representation, or commitment on behalf of InterDigital to any other third party.
41 | Article 7. Open Source Software
42 | Licensee hereby represents, warrants, and covenants to InterDigital that Licensee’s use of the Software shall not result in the Contamination of all or any part of the Software, directly or indirectly, or of any Intellectual Property of InterDigital or its Affiliates.
43 | As used herein, “Contamination” shall mean that the licensing terms under which any Open Source Software, distinct from the Software, is released would also apply to the Software herein, by virtue of such Open Source Software being linked to, combined with, or otherwise connected to the Software.
44 | Licensee agree that some Open Source Software are included in the distribution. A list of such is provided in exhibit A with the relevant licenses applicable. For the avoidance of doubt, regarding such open source parts, the relevant license will apply exclusively.
45 | Article 8. No Future Contract Obligation
46 | Neither this License nor the furnishing of the Software, nor any other InterDigital information provided to Licensee, shall be construed to obligate either party to: (a) enter into any further agreement or negotiation concerning the deployment of the Software; (b) refrain from entering into any agreement or negotiation with any other third party regarding the same or any other subject matter; or (c) refrain from pursuing its business in whatever manner it elects even if this involves competing with the other party.
47 | Article 9. General Provisions
48 | 9.1 Severability. If any provision of this License shall be held to be in contravention of applicable law, this License shall be construed as if such provision were not a part thereof, and in all other respects the terms hereof shall remain in full force and effect.
49 | 9.2 Governing Law. Regardless of the place of execution, delivery, performance or any other aspect of this License, this License and all of the rights of the parties under this License shall be governed by, construed under and enforced in accordance with the substantive law of the State of Delaware, USA, without regard to conflicts of law principles. In case of a dispute that cannot be settled amicably, the state and federal courts located in New Castle County, Delaware, USA, shall have exclusive jurisdiction over such dispute, and each party hereby irrevocably waives any objection to the jurisdiction of such courts, including but not limited to objections of lack of in personam jurisdiction or based on principles of forum non conveniens.
50 | 9.3 Survival. The provisions of articles 1, 3, 4, 6, 7, 8, 9.1, 9.2 and 9.5 shall survive termination of this License.
51 | 9.4 Assignment. InterDigital may assign this license to any third Party. Licensee may not assign this agreement to any third party without InterDigital’s prior written approval.
52 | 9.5 Entire Agreement. This License constitutes the entire agreement between the parties hereto with respect to the subject matter hereof and supersedes any prior agreements or understanding.
53 |
54 |
55 | Exhibit A – Open source parts
56 |
57 | • Pytorch, https://github.com/pytorch/pytorch (View license)
58 | • Numpy, https://github.com/numpy/numpy (BSD-3-Clause License)
59 | • torchvision, https://github.com/pytorch/vision (BSD 3-Clause License)
60 | • pytorch-transformers, https://github.com/huggingface/transformers/ (Apache-2.0 License)
61 | • tqdm, https://github.com/tqdm/tqdm (View license)
62 | • pandas, https://github.com/pandas-dev/pandas (BSD-3-Clause License)
63 | • scikit-learn, https://github.com/scikit-learn/scikit-learn (BSD-3-Clause License)
64 | • ROLL-VideoQA, https://github.com/noagarcia/ROLL-VideoQA (MIT License)
65 | • Multi head attention, https://github.com/CyberZHG/torch-multi-head-attention/blob/master/torch_multi_head_attention/multi_head_attention.py (MIT License)
66 |
67 |
68 |
69 |
--------------------------------------------------------------------------------
/multi_head_attention.py:
--------------------------------------------------------------------------------
1 | """This code is taken from https://github.com/CyberZHG/torch-multi-head-attention/blob/master/torch_multi_head_attention/multi_head_attention.py"""
2 |
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | __all__ = ['MultiHeadAttention', 'ScaledDotProductAttention']
10 |
11 |
12 | class ScaledDotProductAttention(nn.Module):
13 |
14 | def forward(self, query, key, value, mask=None):
15 | dk = query.size()[-1]
16 | scores = query.matmul(key.transpose(-2, -1)) / math.sqrt(dk)
17 | if mask is not None:
18 | scores = scores.masked_fill(mask == 0, -1e9)
19 | attention = F.softmax(scores, dim=-1)
20 | return attention.matmul(value)
21 |
22 |
23 | class MultiHeadAttention(nn.Module):
24 |
25 | def __init__(self,
26 | in_features,
27 | head_num,
28 | bias=True,
29 | activation=F.relu):
30 | """Multi-head attention.
31 |
32 | :param in_features: Size of each input sample.
33 | :param head_num: Number of heads.
34 | :param bias: Whether to use the bias term.
35 | :param activation: The activation after each linear transformation.
36 | """
37 | super(MultiHeadAttention, self).__init__()
38 | if in_features % head_num != 0:
39 | raise ValueError('`in_features`({}) should be divisible by `head_num`({})'.format(in_features, head_num))
40 | self.in_features = in_features
41 | self.head_num = head_num
42 | self.activation = activation
43 | self.bias = bias
44 | self.linear_q = nn.Linear(in_features, in_features, bias)
45 | self.linear_k = nn.Linear(in_features, in_features, bias)
46 | self.linear_v = nn.Linear(in_features, in_features, bias)
47 | self.linear_o = nn.Linear(in_features, in_features, bias)
48 |
49 | def forward(self, q, k, v, mask=None):
50 | q, k, v = self.linear_q(q), self.linear_k(k), self.linear_v(v)
51 | if self.activation is not None:
52 | q = self.activation(q)
53 | k = self.activation(k)
54 | v = self.activation(v)
55 |
56 | q = self._reshape_to_batches(q)
57 | k = self._reshape_to_batches(k)
58 | v = self._reshape_to_batches(v)
59 | if mask is not None:
60 | mask = mask.repeat(self.head_num, 1, 1)
61 | y = ScaledDotProductAttention()(q, k, v, mask)
62 | y = self._reshape_from_batches(y)
63 |
64 | y = self.linear_o(y)
65 | if self.activation is not None:
66 | y = self.activation(y)
67 | return y
68 |
69 | @staticmethod
70 | def gen_history_mask(x):
71 | """Generate the mask that only uses history data.
72 |
73 | :param x: Input tensor.
74 | :return: The mask.
75 | """
76 | batch_size, seq_len, _ = x.size()
77 | return torch.tril(torch.ones(seq_len, seq_len)).view(1, seq_len, seq_len).repeat(batch_size, 1, 1)
78 |
79 | def _reshape_to_batches(self, x):
80 | batch_size, seq_len, in_feature = x.size()
81 | sub_dim = in_feature // self.head_num
82 | return x.reshape(batch_size, seq_len, self.head_num, sub_dim)\
83 | .permute(0, 2, 1, 3)\
84 | .reshape(batch_size * self.head_num, seq_len, sub_dim)
85 |
86 | def _reshape_from_batches(self, x):
87 | batch_size, seq_len, in_feature = x.size()
88 | batch_size //= self.head_num
89 | out_dim = in_feature * self.head_num
90 | return x.reshape(batch_size, self.head_num, seq_len, in_feature)\
91 | .permute(0, 2, 1, 3)\
92 | .reshape(batch_size, seq_len, out_dim)
93 |
94 | def extra_repr(self):
95 | return 'in_features={}, head_num={}, bias={}, activation={}'.format(
96 | self.in_features, self.head_num, self.bias, self.activation,
97 | )
--------------------------------------------------------------------------------
/stream_data_sample.py:
--------------------------------------------------------------------------------
1 | """Code by Noa Garcia and Yuta Nakashima"""
2 | import logging
3 | import math
4 | import os
5 | from abc import ABC
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import torch
10 | import torch.utils.data as data
11 |
12 | from utils import SCENE_BASED_STREAMS, EPISODE_BASED_STREAMS, clean_html, truncate_seq_pair_inv, load_knowit_data, \
13 | SCENE_SUMMARY_CSV, EPISODE_SUMMARY_CSV, TBBT_SUMMARIES_CSV, SCENES_DESCRIPTIONS_CSV
14 |
15 |
16 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
17 | datefmt='%m/%d/%Y %H:%M:%S',
18 | level=logging.INFO)
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | class DataSample(object):
23 |
24 | def __init__(self, qid, question, answer1, answer2, answer3, answer4, subtitles, scene_description, knowledge,
25 | label, summary):
26 | """
27 | Container for a single QA sample: question, four candidate answers, per-stream text inputs and the correct-answer label.
28 | :param qid:
29 | :param question:
30 | :param answer1:
31 | :param answer2:
32 | :param answer3:
33 | :param answer4:
34 | :param subtitles:
35 | :param scene_description:
36 | :param knowledge:
37 | :param label:
38 | :param summary:
39 | """
40 | self.qid = qid
41 | self.question = question
42 | self.subtitles = subtitles
43 | self.knowledge = knowledge
44 | self.label = label
45 | self.scene_description = scene_description
46 | self.answers = [
47 | answer1,
48 | answer2,
49 | answer3,
50 | answer4,
51 | ]
52 | self.summary = summary
53 |
54 |
55 | """Code by InterDigital"""
56 | class DataloaderFactory:
57 | @staticmethod
58 | def build(args, split, tokenizer):
59 | stream_name = args.train_name
60 | if stream_name in SCENE_BASED_STREAMS:
61 | return SceneInputBasedStreamData(args, split, tokenizer)
62 | elif stream_name in EPISODE_BASED_STREAMS:
63 | return EpisodeInputBasedStreamData(args, split, tokenizer)
64 | else:
65 | raise NotImplementedError
66 |
67 | """Code by Noa Garcia and Yuta Nakashima"""
68 | def get_qa_labels(df, index, row):
69 | question = row['question']
70 | answer1 = row['answer1']
71 | answer2 = row['answer2']
72 | answer3 = row['answer3']
73 | answer4 = row['answer4']
74 | label = int(df['idxCorrect'].iloc[index] - 1)
75 | return answer1, answer2, answer3, answer4, label, question
76 |
77 |
78 | class Dataloader(data.Dataset, ABC):
79 | def __init__(self, args, split, tokenizer):
80 | self.df = load_knowit_data(args, split)
81 | self.tokenizer = tokenizer
82 | self.split = split
83 | self.args = args
84 | self.max_seq_length = args.max_seq_length
85 | self.samples = self.get_data(self.df)
86 | self.num_samples = len(self.samples)
87 |
88 | def get_data(self, df):
89 | raise NotImplementedError
90 |
91 | def __len__(self):
92 | return self.num_samples
93 |
94 | """Code by InterDigital"""
95 | class EpisodeInputBasedStreamData(Dataloader):
96 | def __init__(self, args, split, tokenizer):
97 | if args.train_name == "plot":
98 | dfkg = pd.read_csv(os.path.join(args.data_dir, TBBT_SUMMARIES_CSV))
99 | self.recap_dict = dfkg.set_index('Episode').T.to_dict('list')
100 | elif args.train_name == "episode_dialog_summary":
101 | episode_summary_df = pd.read_csv(os.path.join(args.data_dir, EPISODE_SUMMARY_CSV),sep='\t')
102 | self.episode_summary_dict = episode_summary_df.set_index("episode_name").episode_summary.to_dict()
103 | else:
104 | raise NotImplementedError
105 |
106 | super().__init__(args, split, tokenizer)
107 | self.num_max_slices = args.num_max_slices
108 | self.stride = args.seq_stride
109 |
110 | logger.info('Data loader ready with {:d} samples'.format(self.num_samples))
111 |
112 | """Code by Noa Garcia and Yuta Nakashima"""
113 | def get_data(self, df):
114 | samples = []
115 | for index, row in df.iterrows():
116 | answer1, answer2, answer3, answer4, label, question = get_qa_labels(df, index, row)
117 | """Code by InterDigital"""
118 | if self.args.train_name == "episode_dialog_summary":
119 | episode = row.scene[:6]
120 | plot_summary = self.episode_summary_dict[episode]
121 |
122 | elif self.args.train_name == "plot":
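# Build a numeric episode id from the scene name (e.g. 's01e07' -> 107),
# matching the 'Episode' keys of the plot-summary table.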
123 | episode = row.scene[:6]
124 | season = episode[1:3]
125 | number = episode[4:6]
126 | idepi = int(str(int(season)) + number)
127 | plot_summary = self.recap_dict[idepi][0]
128 | else:
129 | raise NotImplementedError
130 | """Code by Noa Garcia and Yuta Nakashima"""
131 | samples.append(DataSample(qid=index, question=question, answer1=answer1, answer2=answer2, answer3=answer3,
132 | answer4=answer4, subtitles=None, scene_description=None, knowledge=plot_summary,
133 | label=label,
134 | summary=None))
135 | return samples
136 |
137 | def __getitem__(self, index):
138 | """
139 | Convert each sample into 4*num_max_slices BERT input sequences as:
140 |
141 | [CLS] + kg_part_1 + question + [SEP] + answer1 + [SEP]
142 | [CLS] + kg_part_1 + question + [SEP] + answer2 + [SEP]
143 | [CLS] + kg_part_1 + question + [SEP] + answer3 + [SEP]
144 | [CLS] + kg_part_1 + question + [SEP] + answer4 + [SEP]
145 |
146 | [CLS] + kg_part_2 + question + [SEP] + answer1 + [SEP]
147 | [CLS] + kg_part_2 + question + [SEP] + answer2 + [SEP]
148 | .
149 | .
150 | .
151 | [CLS] + kg_part_num_max_slices + question + [SEP] + answer4 + [SEP]
152 |
153 | :param index:
154 | :return: input ids, input mask, segment ids, question id and label tensors
155 | """
156 | sample = self.samples[index]
157 | question_tokens = self.tokenizer.tokenize(sample.question)
158 | all_knowledge_tokens = self.tokenizer.tokenize(sample.knowledge)
159 | list_answer_tokens = []
160 | for answer in sample.answers:
161 | answer_tokens = self.tokenizer.tokenize(answer)
162 | list_answer_tokens.append(answer_tokens)
163 |
164 | # Compute maximum window length for knowledge slices based on question and answer lengths
165 | max_qa_len = len(question_tokens) + max([len(a) for a in list_answer_tokens])
166 | len_extra_tokens = 3
167 | len_kg_window = self.max_seq_length - max_qa_len - len_extra_tokens
168 |
169 | # Slice knowledge according to window and stride
170 | list_knowledge_tokens = []
171 |
172 | num_kg_pieces = min(math.ceil((len(all_knowledge_tokens) - len_kg_window) / self.stride) + 1,
173 | self.num_max_slices)
174 | num_kg_pieces = max(num_kg_pieces, 1)
175 | for n in list(range(num_kg_pieces)):
176 | maxpos = min(len_kg_window + (self.stride * n), len(all_knowledge_tokens))
177 | tokens = all_knowledge_tokens[self.stride * n:maxpos]
178 | list_knowledge_tokens.append(tokens)
179 |
180 | # Transformer input features
181 | sample_input_ids = np.zeros((self.num_max_slices, len(sample.answers), self.max_seq_length))
182 | sample_input_mask = np.zeros((self.num_max_slices, len(sample.answers), self.max_seq_length))
183 | sample_segment_ids = np.zeros((self.num_max_slices, len(sample.answers), self.max_seq_length))
184 | for kg_index, knowledge_tokens in enumerate(list_knowledge_tokens):
185 | for answer_index, answer_tokens in enumerate(list_answer_tokens):
186 | """Code by InterDigital"""
187 | start_tokens = knowledge_tokens[:] + question_tokens[:]
188 | ending_tokens = answer_tokens
189 |
190 | """Code by Noa Garcia and Yuta Nakashima"""
191 | sequence_tokens = [self.tokenizer.cls_token] + start_tokens + [
192 | self.tokenizer.sep_token] + ending_tokens + [self.tokenizer.sep_token]
193 | segment_ids = [0] * (len(start_tokens) + 2) + [1] * (len(ending_tokens) + 1)
194 | input_ids = self.tokenizer.convert_tokens_to_ids(sequence_tokens)
195 | input_mask = [1] * len(input_ids)
196 |
197 | padding = [self.tokenizer.pad_token_id] * (self.max_seq_length - len(input_ids))
198 | input_ids += padding
199 | input_mask += padding
200 | segment_ids += padding
201 |
202 | sample_input_ids[kg_index, answer_index, :] = input_ids
203 | sample_input_mask[kg_index, answer_index, :] = input_mask
204 | sample_segment_ids[kg_index, answer_index, :] = segment_ids
205 |
206 | sample_input_ids = torch.tensor(sample_input_ids, dtype=torch.long)
207 | sample_input_mask = torch.tensor(sample_input_mask, dtype=torch.long)
208 | sample_segment_ids = torch.tensor(sample_segment_ids, dtype=torch.long)
209 | qid = torch.tensor(sample.qid, dtype=torch.long)
210 | label = torch.tensor(sample.label, dtype=torch.long)
211 | return sample_input_ids, sample_input_mask, sample_segment_ids, qid, label
212 |
213 | """Code by InterDigital"""
214 | class SceneInputBasedStreamData(Dataloader):
215 | def __init__(self, args, split, tokenizer):
216 | super().__init__(args, split, tokenizer)
217 | self.num_samples = len(self.samples)
218 | logger.info('Data loader ready with {:d} samples'.format(self.num_samples))
219 |
220 | def get_data(self, df):
221 | """
222 | Load data into list of DataSamples
223 | :param df:
224 | :return:
225 | """
226 | samples = []
227 |
228 | if self.args.train_name == "video":
229 | df_descriptions = pd.read_csv(os.path.join(self.args.data_dir, SCENES_DESCRIPTIONS_CSV),
230 | delimiter='\t')
231 | df_descriptions.replace(np.nan, '', inplace=True)
232 | elif self.args.train_name == "scene_dialog_summary":
233 | df_summaries = pd.read_csv(os.path.join(self.args.data_dir, SCENE_SUMMARY_CSV), sep="\t")
234 |
235 | """Code by Noa Garcia and Yuta Nakashima"""
236 | for index, row in df.iterrows():
237 | summary = None
238 | subtitles = None
239 | scene_description = None
240 | answer1, answer2, answer3, answer4, label, question = get_qa_labels(df, index, row)
241 |
242 | """Code by InterDigital"""
243 | if self.args.train_name == "dialog":
244 | subtitles = clean_html(row['subtitle'].replace('\n', ' ').replace(' - ', ' '))
245 | elif self.args.train_name == "scene_dialog_summary":
246 | scene_name = row['scene']
247 | summary = df_summaries[df_summaries.scene == scene_name].summary.values[0]
248 | elif self.args.train_name == "video":
249 | scene_name = row['scene']
250 | scene_description = ''
251 | if len(df_descriptions[df_descriptions['Scene'] == scene_name]['Description']) > 0:
252 | scene_description = df_descriptions[df_descriptions['Scene'] == scene_name]['Description'].values[0]
253 | else:
254 | raise NotImplementedError
255 |
256 | """Code by Noa Garcia and Yuta Nakashima"""
257 | samples.append(DataSample(qid=index, question=question, answer1=answer1, answer2=answer2, answer3=answer3,
258 | answer4=answer4, subtitles=subtitles, scene_description=scene_description,
259 | knowledge=None,
260 | label=label, summary=summary))
261 | return samples
262 |
263 | def __getitem__(self, index):
264 | """
265 | Convert each sample into 4 BERT input sequences as:
266 | [CLS] + subtitles + question + [SEP] + answer1 + [SEP]
267 | [CLS] + subtitles + question + [SEP] + answer2 + [SEP]
268 | [CLS] + subtitles + question + [SEP] + answer3 + [SEP]
269 | [CLS] + subtitles + question + [SEP] + answer4 + [SEP]
270 | :param index:
271 | :return:
272 | """
273 |
274 | sample = self.samples[index]
275 |
276 | """Code by InterDigital"""
277 | train_name = self.args.train_name
278 | if train_name == "dialog":
279 | text_tokens = self.tokenizer.tokenize(sample.subtitles)
280 | elif train_name == "scene_dialog_summary":
281 | text_tokens = self.tokenizer.tokenize(sample.summary)
282 | elif train_name == "video":
283 | text_tokens = self.tokenizer.tokenize(sample.scene_description)
284 | else:
285 | raise NotImplementedError
286 |
287 | """Code by Noa Garcia and Yuta Nakashima"""
288 | question_tokens = self.tokenizer.tokenize(sample.question)
289 | choice_features = []
290 | for answer_index, answer in enumerate(sample.answers):
291 | start_tokens = text_tokens[:] + question_tokens[:]
292 | ending_tokens = self.tokenizer.tokenize(answer)
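# Reserve three positions for the [CLS] and two [SEP] tokens before truncating.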
293 | truncate_seq_pair_inv(start_tokens, ending_tokens, self.max_seq_length - 3)
294 | tokens = [self.tokenizer.cls_token] + start_tokens + [self.tokenizer.sep_token] + ending_tokens + [
295 | self.tokenizer.sep_token]
296 | segment_ids = [0] * (len(start_tokens) + 2) + [1] * (len(ending_tokens) + 1)
297 | input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
298 | input_mask = [1] * len(input_ids)
299 |
300 | padding = [self.tokenizer.pad_token_id] * (self.max_seq_length - len(input_ids))
301 | input_ids += padding
302 | input_mask += padding
303 | segment_ids += padding
304 |
305 | assert len(input_ids) == self.max_seq_length
306 | assert len(input_mask) == self.max_seq_length
307 | assert len(segment_ids) == self.max_seq_length
308 |
309 | choice_features.append((tokens, input_ids, input_mask, segment_ids))
310 |
311 | input_ids = torch.tensor([data[1] for data in choice_features], dtype=torch.long)
312 | input_mask = torch.tensor([data[2] for data in choice_features], dtype=torch.long)
313 | segment_ids = torch.tensor([data[3] for data in choice_features], dtype=torch.long)
314 | qid = torch.tensor(sample.qid, dtype=torch.long)
315 | label = torch.tensor(sample.label, dtype=torch.long)
316 | return input_ids, input_mask, segment_ids, qid, label
317 |
--------------------------------------------------------------------------------
/stream_main_train.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """Code by Noa Garcia and Yuta Nakashima"""
3 | import argparse
4 | import json
5 | import logging
6 | import os
7 | import random
8 | import sys
9 |
10 | import numpy as np
11 | import torch
12 | from torch import nn
13 |
14 | from stream_data_sample import DataloaderFactory
15 | from train_stream import stream_training, stream_embeddings
16 | from utils import EPISODE_BASED_STREAMS, create_folder, str2bool
17 |
18 | from pytorch_transformers.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
19 | from pytorch_transformers.modeling_bert import BertPreTrainedModel, BertModel
20 | from pytorch_transformers.tokenization_bert import BertTokenizer
21 |
22 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
23 | datefmt='%m/%d/%Y %H:%M:%S',
24 | level=logging.INFO)
25 | logger = logging.getLogger(__name__)
26 |
27 | np.set_printoptions(threshold=sys.maxsize)
28 |
29 |
30 | def get_params():
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument("--data_dir", default='data/', type=str)
33 | parser.add_argument("--bert_model", default='bert-base-uncased', type=str)
34 | parser.add_argument("--do_lower_case", default=True, type=bool)
35 | parser.add_argument('--seed', type=int, default=181)
36 | parser.add_argument("--learning_rate", default=5e-5, type=float)
37 | parser.add_argument("--num_train_epochs", default=10.0, type=float)
38 | parser.add_argument("--patience", default=3.0, type=float)
39 | parser.add_argument("--warmup_proportion", default=0.1, type=float)
40 | parser.add_argument("--device", default='cuda', type=str, help="cuda, cpu")
41 | parser.add_argument("--batch_size", default=8, type=int)
42 | parser.add_argument("--eval_batch_size", default=32, type=int)
43 | parser.add_argument("--max_seq_length", type=int)
44 | parser.add_argument("--workers", default=8)
45 | parser.add_argument("--seq_stride", default=100, type=int)
46 | parser.add_argument("--num_max_slices", default=10, type=int)
47 | parser.add_argument("--train_name", type=str, required=True, help="dialog, video, summary, episode_summary, plot")
48 |
49 | """Code by InterDigital"""
50 | parser.add_argument("--mini_batch_size", default=None, type=int)
51 | parser.add_argument("--temporal_attention_temperature", default=2, type=float)
52 | parser.add_argument("--temporal_attention",default=True, type=str2bool)
53 | parser.add_argument("--stream_train_folder_path", default='Training/', type=str)
54 |
55 | args, unknown = parser.parse_known_args()
56 | return args
57 |
58 | """Code by Noa Garcia and Yuta Nakashima"""
59 | class StreamTransformer(BertPreTrainedModel):
60 |
61 | def __init__(self, config):
62 | super(StreamTransformer, self).__init__(config)
63 | self.args = args
64 | self.bert = BertModel(config)
65 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
66 | self.classifier = nn.Linear(config.hidden_size, 1)
67 | if self.args.train_name in EPISODE_BASED_STREAMS:
68 | self.hidden_size = config.hidden_size
69 | self.init_weights()
70 |
71 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
72 | position_ids=None, head_mask=None):
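# Episode-based streams feed inputs of shape (batch, num_slices, num_choices, seq_len);
# scene-based streams feed (batch, num_choices, seq_len).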
73 | if self.args.train_name in EPISODE_BASED_STREAMS:
74 | num_choices = input_ids.shape[2]
75 | num_slices = input_ids.shape[1]
76 | else:
77 | num_choices = input_ids.shape[1]
78 |
79 | flat_input_ids = input_ids.view(-1, input_ids.size(-1))
80 | flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
81 | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
82 | flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
83 | outputs = self.bert(flat_input_ids, position_ids=flat_position_ids, token_type_ids=flat_token_type_ids,
84 | attention_mask=flat_attention_mask, head_mask=head_mask)
85 | pooled_output = outputs[1]
86 | pooled_output = self.dropout(pooled_output)
87 | logits = self.classifier(pooled_output)
88 | if self.args.train_name in EPISODE_BASED_STREAMS:
89 | unpooled_reshaped_logits = logits.view(-1, num_slices, num_choices)
90 |
91 | """Code by InterDigital"""
92 | if self.args.temporal_attention:
93 | # temporal attention
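# Each slice's best answer logit, scaled by a temperature, is turned into a
# softmax weight over slices; per-slice answer logits are then pooled with
# these weights instead of taking a hard max over slices.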
94 | a = torch.max(unpooled_reshaped_logits, dim=2)[0].unsqueeze(-1)
95 | s = nn.Softmax(dim=1)(a / self.args.temporal_attention_temperature)
96 | reshaped_logits = torch.matmul(s.transpose(1, 2), unpooled_reshaped_logits).squeeze(1)
97 |
98 | else:
99 | """Code by Noa Garcia and Yuta Nakashima"""
100 | reshaped_logits = torch.max(unpooled_reshaped_logits, dim=1)[0]
101 |
102 | pooled_output_slices = pooled_output.view(-1, num_slices, self.hidden_size)
103 | outputs = (reshaped_logits,) + (pooled_output_slices,) + (unpooled_reshaped_logits,)
104 |
105 | else:
106 | reshaped_logits = logits.view(-1, num_choices)
107 | outputs = (reshaped_logits,) + outputs[1:]
108 |
109 | if labels is not None:
110 | loss_fct = nn.CrossEntropyLoss()
111 | loss = loss_fct(reshaped_logits, labels)
112 | outputs = (loss,) + outputs
113 |
114 | return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
115 |
116 |
117 | def pretrain_stream(args):
118 | # Create training and data directories
119 | base_model_path = os.path.join(args.stream_train_folder_path, args.train_name)
120 | base_embedding_path = os.path.join(base_model_path, 'embeddings')
121 |
122 | modeldir = create_folder(base_model_path)
123 | outdatadir = create_folder(base_embedding_path)
124 |
125 | with open(os.path.join(modeldir, "args.json"), 'w') as f:
126 | json.dump(vars(args), f)
127 |
128 | # Prepare GPUs
129 | n_gpu = torch.cuda.device_count()
130 | logger.info("device: {} n_gpu: {}".format(args.device, n_gpu))
131 | random.seed(args.seed)
132 | np.random.seed(args.seed)
133 | torch.manual_seed(args.seed)
134 | if n_gpu > 0:
135 | torch.cuda.manual_seed_all(args.seed)
136 |
137 | # Load BERT tokenizer
138 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
139 |
140 | # Do training if there is not already a model in modeldir
141 | if not os.path.isfile(os.path.join(modeldir, 'pytorch_model.bin')):
142 |
143 | # Prepare model
144 | model = StreamTransformer.from_pretrained(args.bert_model, cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE,
145 | 'distributed_{}'.format(-1)))
146 |
147 |
148 | model.to(args.device)
149 | if n_gpu > 1:
150 | model = torch.nn.DataParallel(model)
151 |
152 | # Load training data
153 | trainDataObject = DataloaderFactory.build(args, split='train', tokenizer=tokenizer)
154 | valDataObject = DataloaderFactory.build(args, split='val', tokenizer=tokenizer)
155 |
156 | # Start training
157 | logger.info('*** %s stream training ***' % args.train_name)
158 | stream_training(args, model, modeldir, n_gpu, trainDataObject, valDataObject)
159 |
160 | # For extracting stream embeddings, load trained weights
161 | model = StreamTransformer.from_pretrained(modeldir)
162 | model.to(args.device)
163 | if n_gpu > 1:
164 | model = torch.nn.DataParallel(model)
165 |
166 | # Get stream embeddings for each dataset split
167 | logger.info('*** Get %s stream embeddings for each data split ***' % args.train_name)
168 |
169 | """Code by InterDigital"""
170 | for split in ["train", "val", "test"]:
171 | data_object = DataloaderFactory.build(args, split=split, tokenizer=tokenizer)
172 | stream_embeddings(args, model, outdatadir, data_object, split=split)
173 | logger.info('*** Pretraining %s stream done!' % args.train_name)
174 |
175 | """Code by Noa Garcia and Yuta Nakashima"""
176 | if __name__ == "__main__":
177 | global args
178 | args = get_params()
179 |
180 | """Code by InterDigital"""
181 | logger.info("Arguments: %s" % json.JSONEncoder().encode(vars(args)))
182 |
183 | """Code by Noa Garcia and Yuta Nakashima"""
184 | pretrain_stream(args)
185 |
--------------------------------------------------------------------------------
/train_stream.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """Code by Noa Garcia and Yuta Nakashima"""
3 | import logging
4 | import os
5 | import sys
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import torch
10 | from torch.utils.data import DataLoader
11 | from tqdm import tqdm, trange
12 | from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule
13 |
14 | import utils
15 | from utils import EPISODE_BASED_STREAMS, KNOWIT_DATA_TEST_CSV
16 |
17 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
18 | datefmt='%m/%d/%Y %H:%M:%S',
19 | level=logging.INFO)
20 | logger = logging.getLogger(__name__)
21 |
22 | np.set_printoptions(threshold=sys.maxsize)
23 |
24 |
25 | def train_epoch(args, model, train_dataloader, optimizer, max_grad_norm, scheduler, n_gpu, epoch):
26 |
27 | """Code by InterDigital"""
28 | if args.mini_batch_size is not None:
29 | train_epoch_in_minibatch(args, model, train_dataloader, optimizer, max_grad_norm, scheduler, n_gpu, epoch)
30 | else:
31 | """Code by Noa Garcia and Yuta Nakashima"""
32 | losses = utils.AverageMeter()
33 | model.train()
34 | for step, batch in enumerate(tqdm(train_dataloader, desc="Train iter")):
35 | batch = tuple(t.to(args.device) for t in batch)
36 | input_ids, input_mask, segment_ids, qid, truelabel = batch[:5]
37 | outputs = model(input_ids, segment_ids, input_mask, truelabel)
38 | loss = outputs[0]
39 | if n_gpu > 1:
40 | loss = loss.mean() # mean() to average on multi-gpu.
41 | losses.update(loss.item(), input_ids.shape[0])
42 | loss.backward()
43 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
44 | optimizer.step()
45 | scheduler.step()
46 | optimizer.zero_grad()
47 |
48 | """Code by InterDigital"""
49 | def parse_chunk(batch, chunk_size):
50 | (input_ids, input_mask, segment_ids, qid, truelabel) = batch
52 | for i in range(input_ids.shape[0] // chunk_size):
53 | beginning = i * chunk_size
54 | end = (i + 1) * chunk_size
55 | yield (input_ids[beginning:end], input_mask[beginning:end], segment_ids[beginning:end],
56 | qid[beginning:end], truelabel[beginning:end])
57 |
58 |
59 | def train_epoch_in_minibatch(args, model, train_dataloader, optimizer, max_grad_norm, scheduler, n_gpu, epoch):
60 | losses = utils.AverageMeter()
61 | model.train()
62 | for step, batch in enumerate(tqdm(train_dataloader, desc="Train iter")):
63 | for input_ids, input_mask, segment_ids, qid, truelabel in parse_chunk(batch[:5], args.mini_batch_size):
64 | input_ids = input_ids.to(args.device)
65 | input_mask = input_mask.to(args.device)
66 | segment_ids = segment_ids.to(args.device)
67 | qid = qid.to(args.device)
68 | truelabel = truelabel.to(args.device)
69 |
70 | outputs = model(input_ids, segment_ids, input_mask, truelabel)
71 | loss = outputs[0]
72 |
73 | if n_gpu > 1:
74 | loss = loss.mean() # mean() to average on multi-gpu
75 | loss = loss / (args.batch_size // args.mini_batch_size)
76 | losses.update(loss.item(), input_ids.shape[0])
77 | loss.backward()
78 |
79 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
80 |
81 | optimizer.step()
82 | scheduler.step()
83 | optimizer.zero_grad()
84 |
85 | """Code by Noa Garcia and Yuta Nakashima"""
86 | def val_epoch(args, model, val_dataloader, n_gpu, epoch):
87 | losses = utils.AverageMeter()
88 | model.eval()
89 | for step, batch in enumerate(tqdm(val_dataloader, desc="Val iter")):
90 | batch = tuple(t.to(args.device) for t in batch)
91 | input_ids, input_mask, segment_ids, qid, truelabel = batch[:5]
92 | with torch.no_grad():
93 | outputs = model(input_ids, segment_ids, input_mask, truelabel)
94 | loss, logits = outputs[:2]
95 | if n_gpu > 1:
96 | loss = loss.mean() # mean() to average on multi-gpu.
97 | logits = logits.detach().cpu().numpy()
98 | truelabel = truelabel.detach().cpu()
99 | outputs = np.argmax(logits, axis=1)
100 | losses.update(loss.item(), input_ids.shape[0])
101 | if step == 0:
102 | label = truelabel.numpy()
103 | out = outputs
104 | else:
105 | label = np.concatenate((label, truelabel.numpy()), axis=0)
106 | out = np.concatenate((out, outputs), axis=0)
107 | acc = np.sum(out == label) / len(label)
108 |
109 | return acc
110 |
111 |
112 | def stream_training(args, model, modeldir, n_gpu, trainDataObject, valDataObject):
113 | # Load data
114 | train_dataloader = torch.utils.data.DataLoader(trainDataObject, batch_size=args.batch_size, shuffle=True,
115 | pin_memory=True, num_workers=args.workers)
116 | val_dataloader = torch.utils.data.DataLoader(valDataObject, batch_size=args.eval_batch_size, shuffle=False,
117 | pin_memory=True, num_workers=args.workers)
118 | num_train_optimization_steps = int(trainDataObject.num_samples / args.batch_size) * args.num_train_epochs
119 |
120 | # Optimizer
121 | num_warmup_steps = float(args.warmup_proportion) * float(num_train_optimization_steps)
122 | max_grad_norm = 1.0
123 | param_optimizer = list(model.named_parameters())
124 | param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
125 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
126 | optimizer_grouped_parameters = [
127 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
128 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
129 | ]
130 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
131 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_train_optimization_steps)
132 |
133 | # Start training
134 | logger.info("Num examples = %d", train_dataloader.__len__())
135 | logger.info("Batch size = %d", args.batch_size)
136 | logger.info("Num steps = %d", num_train_optimization_steps)
137 | pattrack = 0
138 | best_val = 0
139 | for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
140 | train_epoch(args, model, train_dataloader, optimizer, max_grad_norm, scheduler, n_gpu, epoch)
141 | current_val = val_epoch(args, model, val_dataloader, n_gpu, epoch)
142 |
143 | # Check patience
144 | is_best = current_val > best_val
145 | best_val = max(current_val, best_val)
146 | if not is_best:
147 | pattrack += 1
148 | else:
149 | pattrack = 0
150 | if pattrack >= args.patience:
151 | break
152 |
153 | # Save a trained model
154 | if is_best:
155 | model_to_save = model.module if hasattr(model,
156 | 'module') else model # Take care of distributed/parallel training
157 | model_to_save.save_pretrained(modeldir)
158 |
159 |
160 | def stream_embeddings(args, model, outdatadir, evalDataObject, split):
161 | # Load data
162 | eval_dataloader = torch.utils.data.DataLoader(evalDataObject, batch_size=args.eval_batch_size, shuffle=False,
163 | pin_memory=True, num_workers=args.workers)
164 |
165 | # Extract embeddings
166 | logger.info("Data split : %s", split)
167 | logger.info("Num examples = %d", eval_dataloader.__len__())
168 | logger.info("Batch size = %d", args.eval_batch_size)
169 | model.eval()
170 | for step, batch in enumerate(tqdm(eval_dataloader, desc="Iteration")):
171 | batch = tuple(t.to(args.device) for t in batch)
172 | input_ids, input_mask, segment_ids, qid, truelabel = batch[:5]
173 | with torch.no_grad():
174 | outputs = model(input_ids, segment_ids, input_mask, labels=truelabel)
175 |
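# Episode-based streams additionally return per-slice answer logits, which are
# collected and saved alongside the CLS embeddings.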
176 | if args.train_name in EPISODE_BASED_STREAMS:
177 | loss, logits, cls_out, logits_slice = outputs[:4]
178 | logits_slice = logits_slice.detach().cpu().numpy()
179 | if step == 0:
180 | stream_logits_slice = logits_slice
181 | else:
182 | stream_logits_slice = np.concatenate((stream_logits_slice, logits_slice), axis=0)
183 | else:
184 | loss, logits, cls_out = outputs[:3]
185 |
186 | qid = qid.detach().cpu().numpy()
187 | logits = logits.detach().cpu().numpy()
188 | cls_out = cls_out.detach().cpu().numpy()
189 | truelabel = truelabel.detach().cpu()
190 | outputs = np.argmax(logits, axis=1)
191 |
192 | if step == 0:
193 | label = truelabel.numpy()
194 | out = outputs
195 | index = qid
196 | stream_scores = logits
197 | stream_embeddings = cls_out
198 | else:
199 | label = np.concatenate((label, truelabel.numpy()), axis=0)
200 | out = np.concatenate((out, outputs), axis=0)
201 | index = np.concatenate((index, qid), axis=0)
202 | stream_scores = np.concatenate((stream_scores, logits), axis=0)
203 | stream_embeddings = np.concatenate((stream_embeddings, cls_out), axis=0)
204 |
205 | # Save embeddings
206 | if args.train_name in EPISODE_BASED_STREAMS:
207 | stream_embeddings = (stream_embeddings, stream_logits_slice)
208 | logger.info('Saving %s embeddings for %s stream...' % (split, args.train_name))
209 | utils.save_obj(stream_scores, os.path.join(outdatadir, '%s_stream_scores_%s.pckl' % (args.train_name, split)))
210 | utils.save_obj(stream_embeddings,
211 | os.path.join(outdatadir, '%s_stream_embeddings_%s.pckl' % (args.train_name, split)))
212 |
213 | # Print accuracy on the test set
214 | if split == 'test':
215 | df = pd.read_csv(os.path.join(args.data_dir, KNOWIT_DATA_TEST_CSV), delimiter='\t')
216 | utils.accuracy(df, out, label, index)
217 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """Code by Noa Garcia and Yuta Nakashima"""
2 | import logging
3 | import os
4 | import pickle
5 | import re
6 |
7 | import argparse
8 | import pandas as pd
9 |
10 |
11 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
12 | datefmt='%m/%d/%Y %H:%M:%S',
13 | level=logging.INFO)
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | def save_obj(obj, filename, verbose=True):
18 | f = open(filename, 'wb')
19 | pickle.dump(obj, f)
20 | f.close()
21 | if verbose:
22 | logger.info("Saved object to %s." % filename)
23 |
24 |
25 | def load_obj(filename, verbose=True):
26 | f = open(filename, 'rb')
27 | obj = pickle.load(f)
28 | f.close()
29 | if verbose:
30 | logger.info("Load object from %s." % filename)
31 | return obj
32 |
33 |
34 | class AverageMeter(object):
35 | """Computes and stores the average and current value"""
36 |
37 | def __init__(self):
38 | self.reset()
39 |
40 | def reset(self):
41 | self.val = 0
42 | self.avg = 0
43 | self.sum = 0
44 | self.count = 0
45 |
46 | def update(self, val, n=1):
47 | self.val = val
48 | self.sum += val * n
49 | self.count += n
50 | self.avg = self.sum / self.count
51 |
52 |
53 | def accuracy(df, out, label, index):
54 | qtypes = df['QType'].to_list()
55 |
56 | acc_total, acc_vis, acc_text, acc_tem, acc_know = 0, 0, 0, 0, 0
57 | num_vis, num_text, num_tem, num_know = 0, 0, 0, 0
58 |
59 | for o, l, i in zip(out, label, index):
60 |
61 | if o == l:
62 | acc_total += 1
63 |
64 | qtype = qtypes[i]
65 |
66 | if qtype == 'visual':
67 | num_vis += 1
68 | if o == l:
69 | acc_vis += 1
70 | elif qtype == 'textual':
71 | num_text += 1
72 | if o == l:
73 | acc_text += 1
74 | elif qtype == 'temporal':
75 | num_tem += 1
76 | if o == l:
77 | acc_tem += 1
78 | elif qtype == 'knowledge':
79 | num_know += 1
80 | if o == l:
81 | acc_know += 1
82 |
83 | acc_total = acc_total / len(out)
84 | acc_vis = acc_vis / num_vis
85 | acc_text = acc_text / num_text
86 | acc_tem = acc_tem / num_tem
87 | acc_know = acc_know / num_know
88 |
89 | logger.info('--- Accuracy')
90 | logger.info('Total: %.03f' % acc_total)
91 | logger.info('Visual : %.03f' % acc_vis)
92 | logger.info('Textual : %.03f' % acc_text)
93 | logger.info('Temporal : %.03f' % acc_tem)
94 | logger.info('Knowledge : %.03f' % acc_know)
95 | logger.info('------')
96 |
97 | return acc_total, acc_vis, acc_text, acc_tem, acc_know
98 |
99 |
100 | def accuracy_val(out, label):
101 | acc_total = 0
102 |
103 | for o, l in zip(out, label):
104 |
105 | if o == l:
106 | acc_total += 1
107 |
108 | acc_total = acc_total / len(out)
109 |
110 | logger.info('--- Accuracy')
111 | logger.info('Total: %.03f' % acc_total)
112 | logger.info('------')
113 |
114 | return acc_total
115 |
116 |
117 | """Code by InterDigital"""
118 | def make_dir_if_not_exists(path):
119 | if not os.path.exists(path):
120 | os.makedirs(path)
121 | return path
122 |
123 | SCENE_BASED_STREAMS = ["dialog", "video", "scene_dialog_summary"]
124 | EPISODE_BASED_STREAMS = ["plot", "episode_dialog_summary"]
125 |
126 | SCENE_SUMMARY_CSV = "scene_summary.csv"
127 | EPISODE_SUMMARY_CSV = "episode_summary.csv"
128 | TBBT_SUMMARIES_CSV = 'tbbt_summaries.csv'
129 | SCENES_DESCRIPTIONS_CSV = 'scenes_descriptions.csv'
130 | KNOWIT_DATA_TEST_CSV = 'knowit_data_test.csv'
131 |
132 | def create_folder_with_timestamp(path, load_pretrained_model_exists):
133 | """
134 | Makes the directory if a new one is needed, otherwise returns the given path
135 | :param path:
136 | :param load_pretrained_model_exists:
137 | :raise FileNotFoundError: if the given path does not exist when a pretrained model is to be loaded
138 | :return:
139 | """
140 | if not load_pretrained_model_exists:
141 | os.makedirs(path)
142 | elif not os.path.exists(path):
143 | raise FileNotFoundError
144 |
145 | return path
146 |
147 |
148 | def create_folder(path):
149 | """
150 | Makes directory if not exist, returns the given path
151 | :param path
152 | :return: path
153 | """
154 | if not os.path.exists(path):
155 | os.makedirs(path)
156 | return path
157 |
158 |
159 | def str2bool(v):
160 | if isinstance(v, bool):
161 | return v
162 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
163 | return True
164 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
165 | return False
166 | else:
167 | raise argparse.ArgumentTypeError('Boolean value expected.')
168 |
169 |
170 | """Code by Noa Garcia and Yuta Nakashima"""
171 | def clean_html(raw_html):
172 | """
173 | Removes HTML tags from raw_html
174 | :param raw_html:
175 | :return: cleaned text
176 | """
177 | cleanr = re.compile('<.*?>')
178 | cleantext = re.sub(cleanr, '', raw_html)
179 | return cleantext
180 |
181 |
182 | def truncate_seq_pair_inv(tokens_a, tokens_b, max_length):
183 | """
184 | Truncate pair of sequences if longer than max_length
185 |
186 | :param tokens_a:
187 | :param tokens_b:
188 | :param max_length:
189 | """
190 | while True:
191 | total_length = len(tokens_a) + len(tokens_b)
192 | if total_length <= max_length:
193 | break
194 | if len(tokens_a) > len(tokens_b):
195 | tokens_a.pop(0)
196 | else:
197 | tokens_b.pop()
198 |
199 |
200 | def load_knowit_data(args, split_name):
201 | assert split_name in ["train", "val", "test"]
202 | input_file = os.path.join(args.data_dir, 'knowit_data_' + split_name + '.csv')
203 | df = pd.read_csv(input_file, delimiter='\t')
204 | logger.info('Loaded file %s.' % input_file)
205 | return df
206 |
207 |
208 |
--------------------------------------------------------------------------------