├── LICENSE
├── README.md
├── args.py
├── data.py
├── dp.pyx
├── model.py
├── setup.py
└── train.py

/LICENSE:
--------------------------------------------------------------------------------
BSD 3-Clause License

Copyright (c) 2019, DmZhukov
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Cross-task weakly supervised learning from instructional videos

## About
This is an implementation of the paper "Cross-task weakly supervised learning from instructional videos" by D. Zhukov, J.-B. Alayrac, R. G. Cinbis, D. Fouhey, I. Laptev and J. Sivic [[arXiv](https://arxiv.org/abs/1903.08225)].

Please consider citing the paper if you use our code or data:
> @INPROCEEDINGS{Zhukov2019,
> author = {Zhukov, Dimitri and Alayrac, Jean-Baptiste and Cinbis, Ramazan Gokberk and Fouhey, David and Laptev, Ivan and Sivic, Josef},
> title = {Cross-task weakly supervised learning from instructional videos},
> booktitle = {CVPR},
> year = {2019},
> }

## CrossTask dataset
The CrossTask dataset contains instructional videos collected for 83 different tasks.
For each task we provide an ordered list of steps with manual descriptions.
The dataset is divided into two parts: 18 primary and 65 related tasks.
Videos for the primary tasks are collected manually and provided with annotations for temporal step boundaries.
Videos for the related tasks are collected automatically and have no annotations.

Tasks, video URLs and annotations are provided [here](https://www.di.ens.fr/~dzhukov/crosstask/crosstask_release.zip). See readme.txt for details.

Features are available [here](https://www.di.ens.fr/~dzhukov/crosstask/crosstask_features.zip) (30 GB). Features for each video are provided in a NumPy array with one 3200-dimensional feature per second. The feature vector is a concatenation of RGB I3D features (columns 0-1023), ResNet-152 features (columns 1024-3071) and audio VGG features (columns 3072-3199).
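For example, the three modalities can be recovered from a loaded feature array by slicing these column ranges (a minimal sketch; the video id below is a placeholder):

```python
import numpy as np

# one .npy file per video, shaped (T, 3200) for a T-second video
X = np.load('crosstask_features/VIDEO_ID.npy')  # VIDEO_ID is a placeholder

i3d    = X[:, :1024]      # RGB I3D features, columns 0-1023
resnet = X[:, 1024:3072]  # ResNet-152 features, columns 1024-3071
audio  = X[:, 3072:]      # audio VGG features, columns 3072-3199
```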
Temporal constraints, extracted from narration, are available [here](https://www.di.ens.fr/~dzhukov/crosstask/crosstask_constraints.zip).

**Update 30/06/2019:** added videos_val.csv with the validation set from the paper, removed extra lines from the constraints.

**Update 14/02/2022:** Use [this](https://www.rocq.inria.fr/cluster-willow/dzhukov/missing_videos.tar.gz) link to download the videos that are no longer available on YouTube. Subtitles for the videos are available [here](https://www.rocq.inria.fr/cluster-willow/dzhukov/crosstask-subtitles.tar.gz).

## Code
The provided code can be used to train and evaluate the component model proposed in the paper on the CrossTask dataset.
It was tested with Python 3.7, PyTorch 1.0, NumPy 1.16 and Cython 0.29.

1. Clone the repository
```bash
git clone https://github.com/DmZhukov/CrossTask.git
cd CrossTask
```
2. Download and unpack the dataset
```bash
wget https://www.di.ens.fr/~dzhukov/crosstask/crosstask_release.zip
wget https://www.di.ens.fr/~dzhukov/crosstask/crosstask_features.zip
wget https://www.di.ens.fr/~dzhukov/crosstask/crosstask_constraints.zip
unzip '*.zip'
```
3. Compile the Cython code
```bash
python setup.py build_ext --inplace
```
4. Run training
```bash
python train.py
```
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--primary_path',
        type=str,
        default='crosstask_release/tasks_primary.txt',
        help='list of primary tasks')
    parser.add_argument(
        '--related_path',
        type=str,
        default='crosstask_release/tasks_related.txt',
        help='list of related tasks')
    parser.add_argument(
        '--annotation_path',
        type=str,
        default='crosstask_release/annotations',
        help='path to annotations')
    parser.add_argument(
        '--video_csv_path',
        type=str,
        default='crosstask_release/videos.csv',
        help='path to video csv')
    parser.add_argument(
        '--val_csv_path',
        type=str,
        default='crosstask_release/videos_val.csv',
        help='path to validation csv')
    parser.add_argument(
        '--features_path',
        type=str,
        default='crosstask_features',
        help='path to features')
    parser.add_argument(
        '--constraints_path',
        type=str,
        default='crosstask_constraints',
        help='path to constraints')
    parser.add_argument(
        '--n_train',
        type=int,
        default=30,
        help='videos per task for training')
    parser.add_argument(
        '--lr',
        type=float,
        default=1e-5,
        help='learning rate')
    parser.add_argument(
        '-q',
        type=float,
        default=0.7,
        help='regularization parameter (dropout rate)')
    parser.add_argument(
        '--epochs',
        type=int,
        default=30,
        help='number of training epochs')
    parser.add_argument(
        '--pretrain_epochs',
        type=int,
        default=30,
        help='number of pre-training epochs')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=1,
        help='batch size')
    parser.add_argument(
        '--num_workers',
        type=int,
        default=8,
        help='number of dataloader workers')
    parser.add_argument(
        '--use_related',
        type=int,
        default=1,
        help='1 to use related tasks during training, 0 to use primary tasks only')
    parser.add_argument(
        '--use_gpu',
        type=int,
        default=0,
        help='1 to run on GPU, 0 to run on CPU')
    parser.add_argument(
        '-d',
        type=int,
        default=3200,
        help='dimension of the feature vectors')
    parser.add_argument(
        '--lambd',
        type=float,
        default=1e4,
        help='penalty coefficient for temporal constraints; set to 0 to train without temporal constraints')
    parser.add_argument(
        '--share',
        type=str,
        default='words',
        help='level of component sharing between tasks: words, task_words, steps or none')
    args = parser.parse_args()
    return args
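Since `parse_args` reads `sys.argv` directly, the configuration can also be inspected or overridden programmatically; a small sketch (the flag values are illustrative, not tuned recommendations):

```python
import sys
from args import parse_args

# parse_args() consumes sys.argv, so substitute it to simulate a CLI call,
# e.g. primary tasks only, GPU on, step-level sharing.
sys.argv = ['train.py', '--use_related', '0', '--use_gpu', '1', '--share', 'steps']
args = parse_args()
print(args.lr, args.lambd, args.share)  # -> 1e-05 10000.0 steps
```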
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import os
import numpy as np
import torch as th
from torch.utils.data import Dataset
import math

def read_task_info(path):
    """Reads task titles, URLs, step counts and step lists from a task file."""
    titles = {}
    urls = {}
    n_steps = {}
    steps = {}
    with open(path,'r') as f:
        idx = f.readline()
        while idx != '':
            idx = idx.strip()
            titles[idx] = f.readline().strip()
            urls[idx] = f.readline().strip()
            n_steps[idx] = int(f.readline().strip())
            steps[idx] = f.readline().strip().split(',')
            next(f)  # skip the blank line between task records
            idx = f.readline()
    return {'title': titles, 'url': urls, 'n_steps': n_steps, 'steps': steps}

def get_vids(path):
    """Reads a {task: [video ids]} mapping from a csv of (task, vid, url) rows."""
    task_vids = {}
    with open(path,'r') as f:
        for line in f:
            task, vid, url = line.strip().split(',')
            if task not in task_vids:
                task_vids[task] = []
            task_vids[task].append(vid)
    return task_vids

def read_assignment(T, K, path):
    """Reads a (T x K) binary step-assignment matrix from an annotation csv."""
    Y = np.zeros([T, K], dtype=np.uint8)
    with open(path,'r') as f:
        for line in f:
            step, start, end = line.strip().split(',')
            start = int(math.floor(float(start)))
            end = int(math.ceil(float(end)))
            step = int(step) - 1
            Y[start:end, step] = 1
    return Y

def random_split(task_vids, test_tasks, n_train):
    """Samples n_train videos per test task for training and holds out the rest.
    Tasks that are not test tasks (or have too few videos) go entirely to training."""
    train_vids = {}
    test_vids = {}
    for task, vids in task_vids.items():
        if task in test_tasks and len(vids) > n_train:
            train_vids[task] = np.random.choice(vids, n_train, replace=False).tolist()
            test_vids[task] = [vid for vid in vids if vid not in train_vids[task]]
        else:
            train_vids[task] = vids
    return train_vids, test_vids

def get_A(task_steps, share="words"):
    """Builds the (M x K) step-to-component matrix A[task] for every task and
    returns the matrices together with the component vocabulary size M."""
    if share == 'words':
        # share words across all tasks
        task_step_comps = {task: [step.split(' ') for step in steps] for task, steps in task_steps.items()}
    elif share == 'task_words':
        # share words within the same task only
        task_step_comps = {task: [[task+'_'+tok for tok in step.split(' ')] for step in steps] for task, steps in task_steps.items()}
    elif share == 'steps':
        # share whole step descriptions across tasks
        task_step_comps = {task: [[step] for step in steps] for task, steps in task_steps.items()}
    else:
        # no sharing
        task_step_comps = {task: [[task+'_'+step] for step in steps] for task, steps in task_steps.items()}
    vocab = []
    for task, steps in task_step_comps.items():
        for step in steps:
            vocab.extend(step)
    vocab = {comp: m for m, comp in enumerate(set(vocab))}
    M = len(vocab)
    A = {}
    for task, steps in task_step_comps.items():
        K = len(steps)
        a = th.zeros(M, K)
        for k, step in enumerate(steps):
            a[[vocab[comp] for comp in step], k] = 1
        a /= a.sum(dim=0)  # normalize each step's column over its components
        A[task] = a
    return A, M

class CrossTaskDataset(Dataset):
    def __init__(self, task_vids, n_steps, features_path, constraints_path):
        super(CrossTaskDataset, self).__init__()
        self.vids = []
        for task, vids in task_vids.items():
            self.vids.extend([(task, vid) for vid in vids])
        self.n_steps = n_steps
        self.features_path = features_path
        self.constraints_path = constraints_path

    def __len__(self):
        return len(self.vids)

    def __getitem__(self, idx):
        task, vid = self.vids[idx]
        X = th.tensor(np.load(os.path.join(self.features_path, vid+'.npy')), dtype=th.float)
        cnst_path = os.path.join(self.constraints_path, task+'_'+vid+'.csv')
        # C[t,k] = 1 where the narration constraints forbid step k at time t
        C = th.tensor(1-read_assignment(X.size()[0], self.n_steps[task], cnst_path), dtype=th.float)
        return {'vid': vid, 'task': task, 'X': X, 'C': C}
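To see how the `share` argument changes the component vocabulary, here is a toy example (made-up step descriptions, not from the dataset):

```python
from data import get_A

task_steps = {'t1': ['pour water', 'stir mixture'],
              't2': ['pour milk',  'stir mixture']}
for share in ['words', 'task_words', 'steps', 'none']:
    A, M = get_A(task_steps, share=share)
    print(share, M, tuple(A['t1'].shape))
# 'words' shares 5 word components across tasks, while 'none' yields
# 4 fully task-specific step components; A['t1'] is always (M, 2).
```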
--------------------------------------------------------------------------------
/dp.pyx:
--------------------------------------------------------------------------------
import numpy as np
cimport numpy as np
cimport cython

NP_FLOAT = np.float64
NP_INT = np.int32

ctypedef np.float64_t NP_FLOAT_t
ctypedef np.int32_t NP_INT_t

cdef int get_step(int k):
    # Even extended labels are background; odd label k corresponds to step (k+1)//2.
    return 0 if k % 2 == 0 else (k + 1) // 2

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef dp(float[:,:] Y, float[:,:] C, int exactly_one=True, bg_cost=0):
    # Viterbi-style decoding of a minimum-cost ordered step assignment.
    # C is a (T x K) cost matrix; the result is written into Y in place.
    # The label set is extended to 2K+1 states, interleaving background
    # states with the K step states to allow gaps between steps.
    cdef int T = Y.shape[0]
    cdef int K = Y.shape[1]
    cdef int K_ext = 2*K + 1

    cdef NP_FLOAT_t[:,:] L = -np.ones([T+1, K_ext], dtype=NP_FLOAT)  # accumulated costs
    cdef NP_INT_t[:,:] P = -np.ones([T+1, K_ext], dtype=NP_INT)      # back-pointers
    L[0,0] = 0
    P[0,0] = 0

    cdef int opt_label
    cdef double opt_value
    cdef int j, t, s, k
    cdef NP_FLOAT_t[:] Lt
    cdef NP_INT_t[:] Pt
    for t in range(1, T+1):
        Lt = L[t-1,:]
        Pt = P[t-1,:]
        for k in range(K_ext):
            s = get_step(k)

            opt_label = -1

            # stay in the same state (step states may only stay if not exactly_one)
            j = k
            if (opt_label==-1 or opt_value>Lt[j]) and Pt[j]!=-1 and (s==0 or not exactly_one):
                opt_label = j
                opt_value = Lt[j]

            # advance from the previous state
            j = k-1
            if j>=0 and (opt_label==-1 or opt_value>Lt[j]) and Pt[j]!=-1:
                opt_label = j
                opt_value = Lt[j]

            # jump from the previous step, skipping the background state in between
            if s!=0:
                j = k-2
                if j>=0 and (opt_label==-1 or opt_value>Lt[j]) and Pt[j]!=-1:
                    opt_label = j
                    opt_value = Lt[j]

            if s!=0:
                L[t,k] = opt_value + C[t-1, s-1]
            else:
                L[t,k] = opt_value + bg_cost
            P[t,k] = opt_label

    # backtrack from the cheaper of the two admissible final states
    for t in range(T):
        for k in range(K):
            Y[t,k] = 0
    if (L[T,K_ext-1] < L[T,K_ext-2] or (P[T,K_ext-2]==-1)):
        k = K_ext-1
    else:
        k = K_ext-2
    for t in range(T, 0, -1):
        s = get_step(k)
        if s > 0:
            Y[t-1, s-1] = 1
        k = P[t,k]
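After building the extension (step 3 in the README), the decoder can be called directly; a minimal sketch on random costs:

```python
import numpy as np
from dp import dp

T, K = 10, 3
costs = np.random.rand(T, K).astype(np.float32)  # lower cost = better fit
Y = np.zeros((T, K), dtype=np.float32)
dp(Y, costs, exactly_one=True)  # fills Y in place
# Y now assigns each of the K steps to exactly one frame, in temporal order,
# with background (cost bg_cost per frame) filling the remaining frames.
print(Y.nonzero())
```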
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import torch.nn as nn

class Model(nn.Module):
    def __init__(self, d, M, A, q):
        super(Model, self).__init__()
        self.fc = nn.Linear(d, M)  # d-dimensional features -> M component scores
        self.m = nn.Dropout(p=q)   # dropout with rate q for regularization
        self.A = A                 # task -> (M x K) step-to-component matrix

    def forward(self, x, task):
        # (T x d) features -> (T x M) component scores -> (T x K) step scores
        return self.fc(self.m(x)).matmul(self.A[task])
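The model is a single linear layer from features to shared components, composed with the fixed step-to-component matrices; a shape check with toy dimensions (the values below are arbitrary):

```python
import torch as th
from model import Model

d, M, K = 3200, 5, 3
A = {'t1': th.rand(M, K)}  # toy step-to-component matrix
net = Model(d, M, A, q=0.7)
x = th.rand(100, d)        # 100 seconds of features
print(net(x, 't1').shape)  # torch.Size([100, 3]) step scores
```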
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from distutils.core import setup
from Cython.Build import cythonize
import numpy

setup(
    name = 'dp',
    ext_modules = cythonize("dp.pyx"),
    include_dirs=[numpy.get_include()]
)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

from model import Model
from data import *
from args import parse_args
from dp import dp
import torch as th
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

class Loss(nn.Module):
    def __init__(self, lambd):
        super(Loss, self).__init__()
        self.lambd = lambd
        self.lsm = nn.LogSoftmax(dim=1)

    def forward(self, O, Y, C):
        # Cross-entropy between the assignment Y and the softmaxed scores O,
        # plus a penalty of lambd for every assigned (time, step) pair that
        # violates the narration constraints (C[t,k] = 1 where forbidden).
        return (Y*(self.lambd * C - self.lsm(O))).mean(dim=0).sum()

def uniform_assignment(T, K):
    # place the k-th step at the center of the k-th of K equal segments
    stepsize = float(T) / K
    y = th.zeros(T, K)
    for k in range(K):
        t = round(stepsize*(k+0.5))
        y[t,k] = 1
    return y

def get_recalls(Y_true, Y_pred):
    # recall = fraction of annotated steps whose predicted frame
    # falls inside a ground-truth interval for that step
    step_match = {task: 0 for task in Y_true.keys()}
    step_total = {task: 0 for task in Y_true.keys()}
    for task, ys_true in Y_true.items():
        ys_pred = Y_pred[task]
        for vid in set(ys_pred.keys()).intersection(set(ys_true.keys())):
            y_true = ys_true[vid]
            y_pred = ys_pred[vid]
            step_total[task] += (y_true.sum(axis=0)>0).sum()
            step_match[task] += (y_true*y_pred).sum()
    recalls = {task: step_match[task] / n for task, n in step_total.items()}
    return recalls
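# Worked example (illustrative): with y_true = [[1,0],[1,0],[0,1]] and
# y_pred = [[1,0],[0,0],[0,1]], two steps are annotated (step_total = 2),
# both predicted frames land inside their intervals (step_match = 2),
# and the recall for that task is 2/2 = 1.0.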
args = parse_args()

task_vids = get_vids(args.video_csv_path)
val_vids = get_vids(args.val_csv_path)
# exclude validation videos from the training/test pool
task_vids = {task: [vid for vid in vids if task not in val_vids or vid not in val_vids[task]] for task, vids in task_vids.items()}

primary_info = read_task_info(args.primary_path)
test_tasks = set(primary_info['steps'].keys())
if args.use_related:
    related_info = read_task_info(args.related_path)
    task_steps = {**primary_info['steps'], **related_info['steps']}
    n_steps = {**primary_info['n_steps'], **related_info['n_steps']}
else:
    task_steps = primary_info['steps']
    n_steps = primary_info['n_steps']
all_tasks = set(n_steps.keys())
task_vids = {task: vids for task, vids in task_vids.items() if task in all_tasks}

A, M = get_A(task_steps, share=args.share)

if args.use_gpu:
    A = {task: a.cuda() for task, a in A.items()}

train_vids, test_vids = random_split(task_vids, test_tasks, args.n_train)

trainset = CrossTaskDataset(train_vids, n_steps, args.features_path, args.constraints_path)
trainloader = DataLoader(trainset,
    batch_size = args.batch_size,
    num_workers = args.num_workers,
    shuffle = True,
    drop_last = True,
    collate_fn = lambda batch: batch,  # keep variable-length samples as a list
    )
testset = CrossTaskDataset(test_vids, n_steps, args.features_path, args.constraints_path)
testloader = DataLoader(testset,
    batch_size = args.batch_size,
    num_workers = args.num_workers,
    shuffle = False,
    drop_last = False,
    collate_fn = lambda batch: batch,
    )

net = Model(args.d, M, A, args.q).cuda() if args.use_gpu else Model(args.d, M, A, args.q)
optimizer = optim.Adam(net.parameters(), lr=args.lr)
loss_fn = Loss(args.lambd)

# initialize with a uniform step assignment
Y = {}
for batch in trainloader:
    for sample in batch:
        task = sample['task']
        vid = sample['vid']
        K = n_steps[task]
        T = sample['X'].shape[0]
        if task not in Y:
            Y[task] = {}
        y = uniform_assignment(T, K)
        Y[task][vid] = y.cuda() if args.use_gpu else y

def train_epoch(pretrain=False):
    cumloss = 0.
    for batch in trainloader:
        for sample in batch:
            vid = sample['vid']
            task = sample['task']
            X = sample['X'].cuda() if args.use_gpu else sample['X']
            C = sample['C'].cuda() if args.use_gpu else sample['C']
            if pretrain:
                # pick a random assignment that satisfies the constraints:
                # random costs, with +1 added at forbidden positions so the
                # DP avoids them whenever possible
                O = np.random.rand(X.size()[0], n_steps[task]) + C.cpu().numpy()
                y = np.zeros(Y[task][vid].shape, dtype=np.float32)
                dp(y, O.astype(np.float32), exactly_one=True)
                Y[task][vid].data = th.tensor(y, dtype=th.float).cuda() if args.use_gpu else th.tensor(y, dtype=th.float)
            else:
                # update the assignment: estimate the loss after one gradient
                # step on the model parameters (F), differentiate it w.r.t.
                # the assignment y, and decode a new assignment with the DP
                O = net(X, task)
                y = Y[task][vid].requires_grad_(True)
                loss = loss_fn(O, y, C)
                param_grads = th.autograd.grad(loss, net.parameters(), create_graph=True, only_inputs=True)
                F = loss
                for g in param_grads:
                    F -= 0.5*args.lr*(g**2).sum()
                Y_grad = th.autograd.grad(F, [y], only_inputs=True)
                y = np.zeros(Y[task][vid].size(), dtype=np.float32)
                dp(y, Y_grad[0].cpu().numpy())
                Y[task][vid].requires_grad_(False)
                Y[task][vid].data = th.tensor(y, dtype=th.float).cuda() if args.use_gpu else th.tensor(y, dtype=th.float)

            # update the model parameters
            O = net(X, task)
            loss = loss_fn(O, Y[task][vid], C)
            loss.backward()
            cumloss += loss.item()
            optimizer.step()
            net.zero_grad()
    return cumloss

def evaluate():
    net.eval()
    lsm = nn.LogSoftmax(dim=1)
    Y_pred = {}
    Y_true = {}
    for batch in testloader:
        for sample in batch:
            vid = sample['vid']
            task = sample['task']
            X = sample['X'].cuda() if args.use_gpu else sample['X']
            O = lsm(net(X, task))
            # decode the most likely ordered assignment from negative log-probabilities
            y = np.zeros(O.size(), dtype=np.float32)
            dp(y, -O.detach().cpu().numpy())
            if task not in Y_pred:
                Y_pred[task] = {}
            Y_pred[task][vid] = y
            annot_path = os.path.join(args.annotation_path, task+'_'+vid+'.csv')
            if os.path.exists(annot_path):
                if task not in Y_true:
                    Y_true[task] = {}
                Y_true[task][vid] = read_assignment(*y.shape, annot_path)
    recalls = get_recalls(Y_true, Y_pred)
    for task, rec in recalls.items():
        print('Task {0}. Recall = {1:0.3f}'.format(task, rec))
    avg_recall = np.mean(list(recalls.values()))
    print('Recall: {0:0.3f}'.format(avg_recall))
    net.train()

print('Training...')
net.train()
for epoch in range(args.pretrain_epochs):
    cumloss = train_epoch(pretrain=True)
    print('Epoch {0}. Loss={1:0.2f}'.format(epoch+1, cumloss))
for epoch in range(args.epochs):
    cumloss = train_epoch()
    print('Epoch {0}. Loss={1:0.2f}'.format(args.pretrain_epochs+epoch+1, cumloss))

print('Evaluating...')
evaluate()
--------------------------------------------------------------------------------