├── .gitignore ├── mixup.py ├── environment.yml ├── egtea.py ├── models_av.py ├── epic_kitchens.py ├── embeddings.py ├── models_lm.py ├── utils.py ├── corpus.py ├── test_av.py ├── README.md ├── test_av_lm.py ├── train_lm.py ├── train_av.py ├── transformers.py └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea/ 2 | /experiments/ 3 | /wandb/ 4 | /__pycache__/ 5 | scores/ 6 | 7 | *.pt 8 | *.sh -------------------------------------------------------------------------------- /mixup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | 6 | def soft_cross_entropy(pred, soft_targets): 7 | logsoftmax = torch.nn.LogSoftmax(dim=1) 8 | return torch.mean(torch.sum(- soft_targets * logsoftmax(pred), 1)) 9 | 10 | 11 | def mixup_data(x, y, alpha=1.0): 12 | '''Returns mixed inputs, pairs of targets, and lambda''' 13 | if alpha > 0: 14 | lam = np.random.beta(alpha, alpha) 15 | else: 16 | lam = 1 17 | 18 | batch_size = x.size(0) 19 | 20 | index = torch.randperm(batch_size).cuda() 21 | mixed_x = lam * x + (1 - lam) * x[index, :] 22 | if isinstance(y, dict): 23 | y_a = {} 24 | y_b = {} 25 | y_a['verb'], y_b['verb'] = y['verb'], y['verb'][index] 26 | y_a['noun'], y_b['noun'] = y['noun'], y['noun'][index] 27 | else: 28 | y_a, y_b = y, y[index] 29 | return mixed_x, y_a, y_b, lam 30 | 31 | 32 | def mixup_data_and_targets(x, y, alpha=1.0): 33 | '''Returns mixed inputs, pairs of targets, and lambda''' 34 | if alpha > 0: 35 | lam = np.random.beta(alpha, alpha) 36 | else: 37 | lam = 1 38 | 39 | batch_size = x.size(0) 40 | 41 | index = torch.randperm(batch_size).cuda() 42 | mixed_x = lam * x + (1 - lam) * x[index, :] 43 | if isinstance(y, dict): 44 | mixed_y = {} 45 | y['verb'] = F.one_hot(y['verb'], num_classes=97).float() 46 | y['noun'] = F.one_hot(y['noun'], num_classes=300).float() 47 | y['verb'] = (1 - 0.05) * y['verb'] + (0.05 / y['verb'].shape[1]) 48 | y['noun'] = (1 - 0.05) * y['noun'] + (0.05 / y['noun'].shape[1]) 49 | mixed_y['verb'] = lam * y['verb'] + (1 - lam) * y['verb'][index] 50 | mixed_y['noun'] = lam * y['noun'] + (1 - lam) * y['noun'][index] 51 | else: 52 | y = F.one_hot(y, num_classes=10).float() 53 | mixed_y = lam * y + (1 - lam) * y[index] 54 | return mixed_x, mixed_y 55 | 56 | 57 | def mixup_criterion(criterion, pred, y_a, y_b, lam, weights=None): 58 | loss_a = criterion(pred, y_a) 59 | if weights is not None: 60 | loss_a = loss_a * weights 61 | loss_a = loss_a.sum(1) 62 | loss_a = loss_a.mean() 63 | loss_b = criterion(pred, y_b) 64 | if weights is not None: 65 | loss_b = loss_b * weights 66 | loss_b = loss_b.sum(1) 67 | loss_b = loss_b.mean() 68 | return lam * loss_a + (1 - lam) * loss_b 69 | 70 | 71 | def mixup_accuracy(output, target_a, target_b, lam, topk=(1,)): 72 | """Computes the precision@k for the specified values of k""" 73 | maxk = max(topk) 74 | batch_size = target_a.size(0) 75 | 76 | _, pred = output.topk(maxk, 1, True, True) 77 | pred = pred.t() 78 | correct = lam * pred.eq(target_a.view(1, -1).expand_as(pred)) \ 79 | + (1 - lam) * pred.eq(target_b.view(1, -1).expand_as(pred)) 80 | res = [] 81 | for k in topk: 82 | correct_k = correct[:k].view(-1).to(torch.float32).sum(0) 83 | res.append(float(correct_k.mul_(100.0 / batch_size))) 84 | return tuple(res) 85 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: mtcn 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - blas=1.0=mkl 9 | - bzip2=1.0.8 10 | - ca-certificates=2021.4.13 11 | - certifi=2020.12.5 12 | - cudatoolkit=10.2.89 13 | - cycler=0.10.0 14 | - dbus=1.13.18 15 | - expat=2.3.0 16 | - ffmpeg=4.3 17 | - fontconfig=2.13.1 18 | - freetype=2.10.4 19 | - glib=2.68.1 20 | - gmp=6.2.1 21 | - gnutls=3.6.15 22 | - gst-plugins-base=1.14.0 23 | - gstreamer=1.14.0 24 | - h5py=2.10.0 25 | - hdf5=1.10.6 26 | - icu=58.2 27 | - intel-openmp=2021.2.0 28 | - joblib=1.0.1 29 | - jpeg=9b 30 | - kiwisolver=1.3.1 31 | - lame=3.100 32 | - lcms2=2.12 33 | - ld_impl_linux-64=2.33.1 34 | - libblas=3.9.0 35 | - libcblas=3.9.0 36 | - libffi=3.3 37 | - libgcc-ng=9.1.0 38 | - libgfortran-ng=7.3.0 39 | - libiconv=1.15 40 | - libidn2=2.3.0 41 | - libpng=1.6.37 42 | - libstdcxx-ng=9.1.0 43 | - libtasn1=4.16.0 44 | - libtiff=4.1.0 45 | - libunistring=0.9.10 46 | - libuuid=1.0.3 47 | - libuv=1.40.0 48 | - libxcb=1.14 49 | - libxml2=2.9.10 50 | - lz4-c=1.9.3 51 | - matplotlib=3.3.4 52 | - matplotlib-base=3.3.4 53 | - mkl=2021.2.0 54 | - mkl-service=2.3.0 55 | - mkl_fft=1.3.0 56 | - mkl_random=1.2.1 57 | - ncurses=6.2 58 | - nettle=3.7.2 59 | - ninja=1.10.2 60 | - numpy=1.20.1 61 | - numpy-base=1.20.1 62 | - olefile=0.46 63 | - openh264=2.1.0 64 | - openssl=1.1.1k 65 | - pandas=1.2.4 66 | - pcre=8.44 67 | - pillow=8.2.0 68 | - pip=21.0.1 69 | - pyparsing=2.4.7 70 | - pyqt=5.9.2 71 | - python=3.8.8 72 | - python-dateutil=2.8.1 73 | - python_abi=3.8 74 | - pytorch=1.8.1=py3.8_cuda10.2_cudnn7.6.5_0 75 | - pytz=2021.1 76 | - qt=5.9.7 77 | - readline=8.1 78 | - scikit-learn=0.23.2 79 | - scipy=1.6.2 80 | - setuptools=52.0.0 81 | - sip=4.19.13 82 | - six=1.15.0 83 | - sqlite=3.35.4 84 | - threadpoolctl=2.1.0 85 | - tk=8.6.10 86 | - torchaudio=0.8.1=py38 87 | - torchvision=0.9.1=py38_cu102 88 | - tornado=6.1 89 | - typing_extensions=3.7.4.3 90 | - wheel=0.36.2 91 | - xz=5.2.5 92 | - zlib=1.2.11 93 | - zstd=1.4.9 94 | - pip: 95 | - chardet==4.0.0 96 | - click==7.1.2 97 | - configparser==5.0.2 98 | - docker-pycreds==0.4.0 99 | - docopt==0.6.2 100 | - einops==0.3.0 101 | - epic-kitchens==1.7.1 102 | - gitdb==4.0.7 103 | - gitpython==3.1.14 104 | - gulpio==540.66 105 | - idna==2.10 106 | - jinja2==2.11.3 107 | - markupsafe==1.1.1 108 | - opencv-python==4.5.2.52 109 | - pathtools==0.1.2 110 | - promise==2.3 111 | - protobuf==3.16.0 112 | - psutil==5.8.0 113 | - pyyaml==5.4.1 114 | - requests==2.25.1 115 | - sentry-sdk==1.1.0 116 | - sh==1.14.1 117 | - shortuuid==1.0.1 118 | - smmap==4.0.0 119 | - subprocess32==3.5.4 120 | - tensorboardx==2.2 121 | - tqdm==4.60.0 122 | - urllib3==1.26.4 123 | - wandb==0.10.30 124 | - word2vec==0.11.1 125 | -------------------------------------------------------------------------------- /egtea.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import pandas as pd 4 | import numpy as np 5 | import h5py 6 | 7 | 8 | class Egtea(data.Dataset): 9 | def __init__(self, 10 | hdf5_path, 11 | labels_pickle, 12 | visual_feature_dim=2304, 13 | audio_feature_dim=None, 14 | window_len=5, 15 | num_clips=10, 16 | clips_mode='random', 17 | labels_mode='center_action'): 18 | self.hdf5_dataset = None 19 | self.hdf5_path = hdf5_path 20 | self.df_labels = pd.read_pickle(labels_pickle) 21 | self.visual_feature_dim = visual_feature_dim 22 | self.window_len = window_len 23 | self.num_clips = num_clips 24 | assert clips_mode in ['all', 'random'], \ 25 | "Labels mode not supported. Choose from ['all', 'random']" 26 | assert labels_mode in ['all', 'center_action'], \ 27 | "Labels mode not supported. Choose from ['all', 'center_action']" 28 | self.clips_mode = clips_mode 29 | self.labels_mode = labels_mode 30 | 31 | def __getitem__(self, index): 32 | if self.hdf5_dataset is None: 33 | self.hdf5_dataset = h5py.File(self.hdf5_path, 'r') 34 | num_clips = self.num_clips if self.clips_mode == 'all' else 1 35 | data = torch.zeros((self.window_len * num_clips, self.visual_feature_dim)) 36 | 37 | clip_name = self.df_labels.iloc[index]['clip_name'] 38 | video_name = self.df_labels.iloc[index]['video_name'] 39 | df_idx = self.df_labels.iloc[index].name 40 | df_sorted_video = self.df_labels[self.df_labels['video_name'] == video_name].sort_values('start_frame') 41 | idx = df_sorted_video.index.get_loc(df_idx) 42 | start = idx - self.window_len // 2 43 | end = idx + self.window_len // 2 + 1 44 | sequence_range = np.clip(np.arange(start, end), 0, df_sorted_video.shape[0] - 1) 45 | sequence_clip_names = df_sorted_video.iloc[sequence_range]['clip_name'].tolist() 46 | 47 | if self.clips_mode == 'random': 48 | for i in range(self.window_len): 49 | clip_idx = np.random.randint(self.num_clips) 50 | data[i] = torch.from_numpy( 51 | self.hdf5_dataset['visual_features/' + str(sequence_clip_names[i])][clip_idx]) 52 | else: 53 | for j in range(self.num_clips): 54 | for i in range(self.window_len): 55 | data[i + j * self.window_len] = torch.from_numpy( 56 | self.hdf5_dataset['visual_features/' + sequence_clip_names[i]][j]) 57 | 58 | if self.labels_mode == "all": 59 | label = torch.from_numpy(df_sorted_video.iloc[sequence_range]['action_idx'].values) 60 | # Concatenate the labels of the center action in the end to be classified by the summary embedding 61 | label = torch.cat([label, label[self.window_len // 2].unsqueeze(0)]) 62 | else: 63 | # Center action 64 | label = torch.tensor(df_sorted_video.iloc[idx]['action_idx']) 65 | 66 | return data, label, clip_name 67 | 68 | def __len__(self): 69 | return self.df_labels.shape[0] 70 | -------------------------------------------------------------------------------- /models_av.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from embeddings import FeatureEmbedding 4 | from transformers import TransformerEncoder, TransformerEncoderLayer 5 | 6 | 7 | class MTCN_AV(nn.Module): 8 | def __init__(self, 9 | num_class, 10 | seq_len=5, 11 | num_clips=10, 12 | visual_input_dim=2304, 13 | audio_input_dim=2304, 14 | d_model=512, 15 | dim_feedforward=2048, 16 | nhead=8, 17 | num_layers=6, 18 | dropout=0.1, 19 | classification_mode='summary', 20 | audio=True): 21 | super(MTCN_AV, self).__init__() 22 | self.num_class = num_class 23 | self.seq_len = seq_len 24 | self.num_clips = num_clips 25 | self.visual_input_dim = visual_input_dim 26 | self.audio_input_dim = audio_input_dim 27 | self.d_model = d_model 28 | self.dim_feedforward = dim_feedforward 29 | self.nhead = nhead 30 | self.num_layers = num_layers 31 | self.dropout = dropout 32 | print("Building Transformer with {}-D, {} heads, and {} layers".format(self.d_model, 33 | self.nhead, 34 | self.num_layers)) 35 | assert classification_mode in ['all', 'summary'], \ 36 | "Classification mode not supported. Choose from ['all', 'summary']" 37 | self.classification_mode = classification_mode 38 | print("Classification mode: {}".format(self.classification_mode)) 39 | self.audio = audio 40 | self._create_model() 41 | 42 | def _create_model(self): 43 | self.feature_embedding = FeatureEmbedding(self.seq_len, 44 | self.num_clips, 45 | self.visual_input_dim, 46 | self.audio_input_dim, 47 | self.d_model, 48 | self.audio, 49 | not isinstance(self.num_class, (list, tuple))) 50 | encoder_layer = TransformerEncoderLayer(d_model=self.d_model, 51 | nhead=self.nhead, 52 | dim_feedforward=self.dim_feedforward, 53 | dropout=self.dropout) 54 | self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=self.num_layers) 55 | if isinstance(self.num_class, (list, tuple)): 56 | self.fc_verb = nn.Linear(self.d_model, self.num_class[0]) 57 | self.fc_noun = nn.Linear(self.d_model, self.num_class[1]) 58 | else: 59 | self.fc_action = nn.Linear(self.d_model, self.num_class) 60 | 61 | def forward(self, inputs, extract_attn_weights=False): 62 | # Project audio and visual features to lower dim and add positional, modality, and summary embeddings 63 | x = self.feature_embedding(inputs) 64 | if extract_attn_weights: 65 | x, attn_weights = self.transformer_encoder(x) 66 | x = x.transpose(0, 1).contiguous() 67 | else: 68 | x, _ = self.transformer_encoder(x) 69 | x = x.transpose(0, 1).contiguous() 70 | if isinstance(self.num_class, (list, tuple)): 71 | if self.classification_mode == 'all': 72 | output_verb_av = self.fc_verb(x[:, :-2, :]).transpose(1, 2).contiguous() 73 | output_noun_av = self.fc_noun(x[:, :-2, :]).transpose(1, 2).contiguous() 74 | output_verb_ve = self.fc_verb(x[:, -2, :]).unsqueeze(2) 75 | output_noun_no = self.fc_noun(x[:, -1, :]).unsqueeze(2) 76 | output_verb = torch.cat([output_verb_av, output_verb_ve], dim=2) 77 | output_noun = torch.cat([output_noun_av, output_noun_no], dim=2) 78 | else: 79 | output_verb = self.fc_verb(x[:, -2, :]) 80 | output_noun = self.fc_noun(x[:, -1, :]) 81 | output = (output_verb, output_noun) 82 | else: 83 | if self.classification_mode == 'all': 84 | output = self.fc_action(x).transpose(1, 2).contiguous() 85 | else: 86 | output = self.fc_action(x[:, -1, :]) 87 | if extract_attn_weights: 88 | return output, attn_weights 89 | else: 90 | return output 91 | -------------------------------------------------------------------------------- /epic_kitchens.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import pandas as pd 4 | import numpy as np 5 | import h5py 6 | 7 | 8 | class EpicKitchens(data.Dataset): 9 | def __init__(self, 10 | hdf5_path, 11 | labels_pickle, 12 | visual_feature_dim=2304, 13 | audio_feature_dim=2304, 14 | window_len=5, 15 | num_clips=10, 16 | clips_mode='random', 17 | labels_mode='center_action'): 18 | self.hdf5_dataset = None 19 | self.hdf5_path = hdf5_path 20 | self.df_labels = pd.read_pickle(labels_pickle) 21 | self.visual_feature_dim = visual_feature_dim 22 | self.audio_feature_dim = audio_feature_dim 23 | self.window_len = window_len 24 | self.num_clips = num_clips 25 | assert clips_mode in ['all', 'random'], \ 26 | "Labels mode not supported. Choose from ['all', 'random']" 27 | assert labels_mode in ['all', 'center_action'], \ 28 | "Labels mode not supported. Choose from ['all', 'center_action']" 29 | self.clips_mode = clips_mode 30 | self.labels_mode = labels_mode 31 | 32 | def __getitem__(self, index): 33 | if self.hdf5_dataset is None: 34 | self.hdf5_dataset = h5py.File(self.hdf5_path, 'r') 35 | num_clips = self.num_clips if self.clips_mode == 'all' else 1 36 | data = torch.zeros((2 * self.window_len * num_clips, max(self.visual_feature_dim, self.audio_feature_dim))) 37 | 38 | narration_id = self.df_labels.iloc[index].name 39 | video_id = self.df_labels.iloc[index]['video_id'] 40 | df_sorted_video = self.df_labels[self.df_labels['video_id'] == video_id].sort_values('start_timestamp') 41 | idx = df_sorted_video.index.get_loc(narration_id) 42 | start = idx - self.window_len // 2 43 | end = idx + self.window_len // 2 + 1 44 | sequence_range = np.clip(np.arange(start, end), 0, df_sorted_video.shape[0] - 1) 45 | sequence_narration_ids = df_sorted_video.iloc[sequence_range].index.tolist() 46 | 47 | if self.clips_mode == 'random': 48 | for i in range(self.window_len): 49 | clip_idx = np.random.randint(self.num_clips) 50 | data[i][:self.visual_feature_dim] = torch.from_numpy( 51 | self.hdf5_dataset['visual_features/' + str(sequence_narration_ids[i])][clip_idx]) 52 | data[self.window_len + i][:self.audio_feature_dim] = torch.from_numpy( 53 | self.hdf5_dataset['audio_features/' + str(sequence_narration_ids[i])][clip_idx]) 54 | else: 55 | for i in range(self.window_len): 56 | for j in range(self.num_clips): 57 | data[i * self.num_clips + j][:self.visual_feature_dim] = torch.from_numpy( 58 | self.hdf5_dataset['visual_features/' + str(sequence_narration_ids[i])][j]) 59 | data[self.window_len * self.num_clips + i * self.num_clips + j][:self.audio_feature_dim] = torch.from_numpy( 60 | self.hdf5_dataset['audio_features/' + str(sequence_narration_ids[i])][j]) 61 | 62 | if self.labels_mode == "all": 63 | verbs = torch.from_numpy(df_sorted_video.iloc[sequence_range]['verb_class'].values) \ 64 | if 'verb_class' in df_sorted_video.columns else torch.full((self.window_len,), -1) 65 | nouns = torch.from_numpy(df_sorted_video.iloc[sequence_range]['noun_class'].values) \ 66 | if 'noun_class' in df_sorted_video.columns else torch.full((self.window_len,), -1) 67 | # Replicate sequence of labels x2, 1 for video sequence and 1 audio sequence 68 | verbs = verbs.repeat(2) 69 | nouns = nouns.repeat(2) 70 | # Concatenate the labels of the center action in the end to be classified by the summary embedding 71 | verbs = torch.cat([verbs, verbs[self.window_len // 2].unsqueeze(0)]) 72 | nouns = torch.cat([nouns, nouns[self.window_len // 2].unsqueeze(0)]) 73 | label = {'verb': verbs, 'noun': nouns} 74 | else: 75 | # Center action 76 | verb = torch.tensor(df_sorted_video.iloc[idx]['verb_class']) \ 77 | if 'verb_class' in df_sorted_video.columns else torch.full((1,), -1) 78 | noun = torch.tensor(df_sorted_video.iloc[idx]['noun_class']) \ 79 | if 'noun_class' in df_sorted_video.columns else torch.full((1,), -1) 80 | label = {'verb': verb, 'noun': noun} 81 | 82 | return data, label, narration_id 83 | 84 | def __len__(self): 85 | return self.df_labels.shape[0] 86 | -------------------------------------------------------------------------------- /embeddings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.init import normal_ 4 | 5 | 6 | class FeatureEmbedding(nn.Module): 7 | def __init__(self, seq_len, num_clips, visual_input_dim, audio_input_dim, d_model, audio, embed_actions): 8 | super(FeatureEmbedding, self).__init__() 9 | self.seq_len = seq_len 10 | self.num_clips = num_clips 11 | self.visual_input_dim = visual_input_dim 12 | self.audio_input_dim = audio_input_dim 13 | self.visual_projection = nn.Linear(visual_input_dim, d_model) 14 | self.visual_relu = nn.ReLU() 15 | if audio: 16 | self.audio_projection = nn.Linear(audio_input_dim, d_model) 17 | self.audio_relu = nn.ReLU() 18 | self.num_cls_embeddings = 1 if embed_actions else 2 19 | self.positional_embedding = nn.Parameter(torch.empty((1, seq_len + self.num_cls_embeddings, d_model), requires_grad=True)) 20 | normal_(self.positional_embedding, std=0.001) 21 | # When there is no audio (EGTEA), there is no need for modality embeddings 22 | # as there are only visual inputs, so there is no need for discrimination 23 | # between visual/audio inputs. 24 | if audio: 25 | self.visual_embedding = nn.Parameter(torch.empty((1, 1, d_model), requires_grad=True)) 26 | normal_(self.visual_embedding, std=0.001) 27 | self.audio_embedding = nn.Parameter(torch.empty((1, 1, d_model), requires_grad=True)) 28 | normal_(self.audio_embedding, std=0.001) 29 | if not embed_actions: 30 | self.verb_embedding = nn.Parameter(torch.empty((1, 1, d_model), requires_grad=True)) 31 | normal_(self.verb_embedding, std=0.001) 32 | self.noun_embedding = nn.Parameter(torch.empty((1, 1, d_model), requires_grad=True)) 33 | normal_(self.noun_embedding, std=0.001) 34 | else: 35 | self.action_embedding = nn.Parameter(torch.empty((1, 1, d_model), requires_grad=True)) 36 | normal_(self.action_embedding, std=0.001) 37 | self.dropout = nn.Dropout(p=0.5) 38 | self.dropout_v = nn.Dropout(p=0.5) 39 | self.dropout_a = nn.Dropout(p=0.5) 40 | self.audio = audio 41 | self.embed_actions = embed_actions 42 | 43 | def forward(self, inputs): 44 | # Project audio and visual features to a lower dim 45 | vis_embed = self.dropout_v(inputs[:, :self.seq_len * self.num_clips, :self.visual_input_dim]) 46 | if self.audio: 47 | aud_embed = self.dropout_a(inputs[:, self.seq_len * self.num_clips:, :self.audio_input_dim]) 48 | vis_embed = self.visual_projection(vis_embed) 49 | vis_embed = self.visual_relu(vis_embed) 50 | if self.audio: 51 | aud_embed = self.audio_projection(aud_embed) 52 | aud_embed = self.audio_relu(aud_embed) 53 | if self.audio: 54 | # Tag audio-visual inputs with positional and modality embeddings 55 | vis_embed = vis_embed + \ 56 | self.positional_embedding[:, :-self.num_cls_embeddings, :].repeat_interleave(self.num_clips, dim=1) + \ 57 | self.visual_embedding 58 | aud_embed = aud_embed + \ 59 | self.positional_embedding[:, :-self.num_cls_embeddings, :].repeat_interleave(self.num_clips, dim=1) + \ 60 | self.audio_embedding 61 | else: 62 | # Tag visual inputs with positional embeddings 63 | vis_embed = vis_embed + \ 64 | self.positional_embedding[:, :-self.num_cls_embeddings, :].repeat_interleave(self.num_clips, dim=1) 65 | if not self.embed_actions: 66 | # Tag verb/noun embeddings with positional embeddings 67 | verb_embed = self.verb_embedding + self.positional_embedding[:, -2, :] 68 | noun_embed = self.noun_embedding + self.positional_embedding[:, -1, :] 69 | verb_embed = verb_embed.expand(vis_embed.shape[0], -1, -1) 70 | noun_embed = noun_embed.expand(vis_embed.shape[0], -1, -1) 71 | else: 72 | # Tag action embedding with positional embeddings 73 | action_embed = self.action_embedding + self.positional_embedding[:, -1, :] 74 | action_embed = action_embed.expand(vis_embed.shape[0], -1, -1) 75 | if self.audio: 76 | if not self.embed_actions: 77 | seq = torch.cat([vis_embed, aud_embed, verb_embed, noun_embed], 1) 78 | else: 79 | seq = torch.cat([vis_embed, aud_embed, action_embed], 1) 80 | else: 81 | if not self.embed_actions: 82 | seq = torch.cat([vis_embed, verb_embed, noun_embed], 1) 83 | else: 84 | seq = torch.cat([vis_embed, action_embed], 1) 85 | seq = self.dropout(seq) 86 | seq = seq.transpose(0, 1).contiguous() 87 | return seq 88 | -------------------------------------------------------------------------------- /models_lm.py: -------------------------------------------------------------------------------- 1 | # Model 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from transformers import TransformerEncoder, TransformerEncoderLayer 7 | 8 | 9 | class PositionalEncoding(nn.Module): 10 | """Inject some information about the relative or absolute position of the tokens 11 | in the sequence. The positional encodings have the same dimension as 12 | the embeddings, so that the two can be summed. Here, we use sine and cosine 13 | functions of different frequencies. 14 | .. math:: 15 | \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) 16 | \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) 17 | \text{where pos is the word position and i is the embed idx) 18 | Args: 19 | d_model: the embed dim (required). 20 | dropout: the dropout value (default=0.1). 21 | max_len: the max. length of the incoming sequence (default=5000). 22 | Examples: 23 | >>> pos_encoder = PositionalEncoding(d_model) 24 | """ 25 | 26 | def __init__(self, d_model, dropout=0.5, max_len=5000): 27 | super(PositionalEncoding, self).__init__() 28 | self.dropout = nn.Dropout(p=dropout) 29 | 30 | pe = torch.zeros(max_len, d_model) 31 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 32 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 33 | pe[:, 0::2] = torch.sin(position * div_term) 34 | pe[:, 1::2] = torch.cos(position * div_term) 35 | pe = pe.unsqueeze(0).transpose(0, 1) 36 | self.register_buffer('pe', pe) 37 | 38 | def forward(self, x): 39 | r"""Inputs of forward function 40 | Args: 41 | x: the sequence fed to the positional encoder model (required). 42 | Shape: 43 | x: [sequence length, batch size, embed dim] 44 | output: [sequence length, batch size, embed dim] 45 | Examples: 46 | >>> output = pos_encoder(x) 47 | """ 48 | 49 | x = x + self.pe[:x.size(0), :] 50 | return self.dropout(x) 51 | 52 | 53 | class MTCN_LM(nn.Module): 54 | def __init__(self, 55 | num_class, 56 | d_model=512, 57 | dim_feedforward=512, 58 | nhead=8, 59 | num_layers=4, 60 | dropout=0.1): 61 | super(MTCN_LM, self).__init__() 62 | self.num_class = num_class 63 | self.d_model = d_model 64 | self.dim_feedforward = dim_feedforward 65 | self.nhead = nhead 66 | self.num_layers = num_layers 67 | self.dropout = dropout 68 | 69 | self._create_model() 70 | 71 | def _create_model(self): 72 | self.pos_encoder = PositionalEncoding(self.d_model, dropout=0.1) 73 | encoder_layer = TransformerEncoderLayer(d_model=self.d_model, 74 | nhead=self.nhead, 75 | dim_feedforward=self.dim_feedforward, 76 | dropout=self.dropout) 77 | self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=self.num_layers) 78 | 79 | # (ntokens[0] + 1) and (ntokens[1] + 1) are MASK token for verb and noun respectively 80 | if isinstance(self.num_class, (list, tuple)): 81 | self.verb_encoder = nn.Embedding(self.num_class[0] + 1, self.d_model // 2) 82 | self.noun_encoder = nn.Embedding(self.num_class[1] + 1, self.d_model // 2) 83 | self.decoder = nn.Linear(self.d_model, self.num_class[0] + self.num_class[1]) 84 | else: 85 | self.num_class = int(self.num_class) 86 | self.encoder = nn.Embedding(self.num_class + 1, self.d_model) 87 | self.decoder = nn.Linear(self.d_model, self.num_class) 88 | print("Building Transformer with {}-D, {} heads, and {} layers".format(self.d_model, 89 | self.nhead, 90 | self.num_layers)) 91 | 92 | def forward(self, verb_input, noun_input=None): 93 | if isinstance(self.num_class, (list, tuple)): 94 | verb_src = self.verb_encoder(verb_input) 95 | noun_src = self.noun_encoder(noun_input) 96 | src = torch.cat([verb_src, noun_src], dim=-1) 97 | else: 98 | # For this option, the noun_input should be None 99 | assert noun_input == None 100 | src = self.encoder(verb_input) 101 | 102 | src *= math.sqrt(self.d_model) 103 | output = self.pos_encoder(src) 104 | output, _ = self.transformer_encoder(output) 105 | output = self.decoder(output) 106 | 107 | return output 108 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | import shutil 5 | from pathlib import Path 6 | 7 | 8 | def accuracy(output, target, topk=(1,)): 9 | """Computes the precision@k for the specified values of k""" 10 | maxk = max(topk) 11 | batch_size = target.size(0) 12 | 13 | _, pred = output.topk(maxk, 1, True, True) 14 | pred = pred.t() 15 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 16 | 17 | res = [] 18 | for k in topk: 19 | correct_k = correct[:k].reshape(-1).to(torch.float32).sum(0) 20 | res.append(float(correct_k.mul_(100.0 / batch_size))) 21 | return tuple(res) 22 | 23 | 24 | def multitask_accuracy(outputs, labels, topk=(1,)): 25 | """ 26 | Args: 27 | outputs: tuple(torch.FloatTensor), each tensor should be of shape 28 | [batch_size, class_count], class_count can vary on a per task basis, i.e. 29 | outputs[i].shape[1] can be different to outputs[j].shape[j]. 30 | labels: tuple(torch.LongTensor), each tensor should be of shape [batch_size] 31 | topk: tuple(int), compute accuracy at top-k for the values of k specified 32 | in this parameter. 33 | Returns: 34 | tuple(float), same length at topk with the corresponding accuracy@k in. 35 | """ 36 | max_k = int(np.max(topk)) 37 | task_count = len(outputs) 38 | batch_size = labels[0].size(0) 39 | all_correct = torch.zeros(max_k, batch_size).type(torch.ByteTensor) 40 | if torch.cuda.is_available(): 41 | all_correct = all_correct.cuda(device=0) 42 | for output, label in zip(outputs, labels): 43 | _, max_k_idx = output.topk(max_k, dim=1, largest=True, sorted=True) 44 | # Flip batch_size, class_count as .view doesn't work on non-contiguous 45 | max_k_idx = max_k_idx.t() 46 | correct_for_task = max_k_idx.eq(label.view(1, -1).expand_as(max_k_idx)) 47 | all_correct.add_(correct_for_task) 48 | 49 | accuracies = [] 50 | for k in topk: 51 | all_tasks_correct = torch.ge(all_correct[:k].float().sum(0), task_count) 52 | accuracy_at_k = float(all_tasks_correct.float().sum(0) * 100.0 / batch_size) 53 | accuracies.append(accuracy_at_k) 54 | return tuple(accuracies) 55 | 56 | 57 | def save_checkpoint(state, is_best, output_dir, filename='checkpoint.pyth'): 58 | weights_dir = output_dir / Path('models') 59 | if not weights_dir.exists(): 60 | weights_dir.mkdir(parents=True) 61 | torch.save(state, weights_dir / filename) 62 | if is_best: 63 | shutil.copyfile(weights_dir / filename, 64 | weights_dir / 'model_best.pyth') 65 | 66 | 67 | class AverageMeter(object): 68 | """Computes and stores the average and current value""" 69 | def __init__(self): 70 | self.reset() 71 | 72 | def reset(self): 73 | self.val = 0 74 | self.avg = 0 75 | self.sum = 0 76 | self.count = 0 77 | 78 | def update(self, val, n=1): 79 | self.val = val 80 | self.sum += val * n 81 | self.count += n 82 | self.avg = self.sum / self.count 83 | 84 | 85 | ########################################################## 86 | ## BEAM SEARCH FUNCTIONS ## 87 | ########################################################## 88 | 89 | def beam_search_decoder(predictions, top_k = 3): 90 | #start with an empty sequence with zero score 91 | output_sequences = [([], 0)] 92 | 93 | #looping through all the predictions 94 | for token_probs in predictions: 95 | new_sequences = [] 96 | #append new tokens to old sequences and re-score 97 | for old_seq, old_score in output_sequences: 98 | for char_index in range(len(token_probs)): 99 | new_seq = old_seq + [char_index] 100 | #considering log-likelihood for scoring 101 | #new_score = old_score + math.log(token_probs[char_index]) 102 | new_score = old_score + token_probs[char_index] 103 | new_sequences.append((new_seq, new_score)) 104 | 105 | #sort all new sequences in the de-creasing order of their score 106 | output_sequences = sorted(new_sequences, key = lambda val: val[1], reverse = True) 107 | 108 | #select top-k based on score 109 | # *Note- best sequence is with the highest score 110 | output_sequences = output_sequences[:top_k] 111 | 112 | return output_sequences 113 | 114 | 115 | def get_topk(verb_sequence, noun_sequence, beam_size): 116 | # Conduct beam search on verb and noun individually - for Epic-kitchens 117 | verb_output = beam_search_decoder(verb_sequence, beam_size) 118 | noun_output = beam_search_decoder(noun_sequence, beam_size) 119 | 120 | return verb_output, noun_output 121 | 122 | 123 | def get_topk_action(action_sequence, beam_size): 124 | # Conduct beam search on action - for EGTEA 125 | action_output = beam_search_decoder(action_sequence, beam_size) 126 | 127 | return action_output 128 | 129 | 130 | def get_lmscore(verb_seq, noun_seq, model, num_gram, ntokens): 131 | # Calculate the LM score of the sequence for epic 132 | verb_score, noun_score = 0, 0 133 | verb_input = verb_seq.repeat(1, num_gram) 134 | noun_input = noun_seq.repeat(1, num_gram) 135 | verb_input[range(num_gram), range(num_gram)] = ntokens[0] 136 | noun_input[range(num_gram), range(num_gram)] = ntokens[1] 137 | 138 | with torch.no_grad(): 139 | output = model(verb_input, noun_input) 140 | verb_output = torch.nn.functional.log_softmax(output[..., :ntokens[0]], dim=-1) 141 | noun_output = torch.nn.functional.log_softmax(output[..., ntokens[0]:], dim=-1) 142 | verb_score = torch.sum(verb_output[range(num_gram), range(num_gram), verb_seq.reshape(-1)]).item() 143 | noun_score = torch.sum(noun_output[range(num_gram), range(num_gram), noun_seq.reshape(-1)]).item() 144 | 145 | return verb_score, noun_score 146 | 147 | 148 | def get_lmscore_action(action_seq, model, num_gram, ntokens): 149 | # Calculate the LM score of the sequence for egtea 150 | action_score = 0 151 | action_input = action_seq.repeat(1, num_gram) 152 | action_input[range(num_gram), range(num_gram)] = ntokens 153 | 154 | with torch.no_grad(): 155 | output = model(action_input, None) 156 | action_output = torch.nn.functional.log_softmax(output, dim=-1) 157 | action_score = torch.sum(action_output[range(num_gram), range(num_gram), action_seq.reshape(-1)]).item() 158 | 159 | return action_score 160 | -------------------------------------------------------------------------------- /corpus.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import torch 4 | import random 5 | 6 | from torch.utils.data import Dataset 7 | 8 | class Dictionary(object): 9 | def __init__(self): 10 | self.word2idx = {} 11 | self.idx2word = [] 12 | self.idx2count = {} 13 | 14 | def add_word(self, word): 15 | if word not in self.word2idx: 16 | self.idx2word.append(word) 17 | self.word2idx[word] = len(self.idx2word) - 1 18 | self.idx2count[len(self.idx2word) - 1] = 0 19 | return self.word2idx[word] 20 | 21 | def add_count(self, idx): 22 | self.idx2count[idx] += 1 23 | 24 | def __len__(self): 25 | return len(self.idx2word) 26 | 27 | 28 | class EpicCorpus(Dataset): 29 | def __init__(self, pickle_file, csvfiles, num_class, num_gram, train=True): 30 | self.verb_dict, self.noun_dict = Dictionary(), Dictionary() 31 | verb_csv, noun_csv = csvfiles[0], csvfiles[1] 32 | self.num_class = num_class 33 | self.num_gram = num_gram 34 | self.train = train 35 | 36 | assert num_gram >= 2 37 | 38 | # Update verb & noun dictionary, note that last token is '' token 39 | with open(verb_csv, 'r') as f: 40 | lines = f.readlines() 41 | for line in lines[1:]: 42 | idx, word = int(line.split(',')[0]), line.split(',')[1] 43 | self.verb_dict.add_word(word) 44 | self.verb_dict.add_word('') 45 | 46 | with open(noun_csv, 'r') as f: 47 | lines = f.readlines() 48 | for line in lines[1:]: 49 | idx, word = int(line.split(',')[0]), line.split(',')[1] 50 | self.noun_dict.add_word(word) 51 | self.noun_dict.add_word('') 52 | self.verbs, self.nouns = self.tokenize(pd.read_pickle(pickle_file)) 53 | 54 | def tokenize(self, df_labels): 55 | """Tokenizes a epic-kitchens file.""" 56 | # Parse the pandas file 57 | video_ids = sorted(list(set(df_labels['video_id']))) 58 | verb_idss, noun_idss = [], [] 59 | 60 | for video_id in video_ids: 61 | df_video = df_labels[df_labels['video_id'] == video_id] 62 | df_video = df_video.sort_values(by='start_frame') 63 | verb_class = list(df_video['verb_class']) 64 | noun_class = list(df_video['noun_class']) 65 | 66 | for verbidx in verb_class: 67 | self.verb_dict.add_count(verbidx) 68 | for nounidx in noun_class: 69 | self.noun_dict.add_count(nounidx) 70 | 71 | assert len(verb_class) == len(noun_class) 72 | for ii in range(len(verb_class) - self.num_gram + 1): 73 | verb_temp = [] 74 | noun_temp = [] 75 | for j in range(self.num_gram): 76 | verb_temp.append(verb_class[ii + j]) 77 | noun_temp.append(noun_class[ii + j]) 78 | verb_idss.append(torch.tensor(verb_temp).type(torch.int64)) 79 | noun_idss.append(torch.tensor(noun_temp).type(torch.int64)) 80 | 81 | verb_ids = torch.stack(verb_idss, dim=0) 82 | noun_ids = torch.stack(noun_idss, dim=0) 83 | 84 | assert verb_ids.shape[0] == noun_ids.shape[0] 85 | return verb_ids, noun_ids 86 | 87 | def __len__(self): 88 | return len(self.verbs) 89 | 90 | def __getitem__(self, index): 91 | verb, noun = self.verbs[index], self.nouns[index] 92 | verb_input, noun_input = verb.clone().detach(), noun.clone().detach() 93 | 94 | if self.train: 95 | verb_mask_pos = np.random.choice(list(range(self.num_gram))) 96 | noun_mask_pos = verb_mask_pos 97 | 98 | verb_input[verb_mask_pos] = self.verb_dict.word2idx[''] 99 | noun_input[noun_mask_pos] = self.noun_dict.word2idx[''] 100 | 101 | else: 102 | # For evaluating, test only the centre action 103 | mask_pos = self.num_gram // 2 104 | verb_input[mask_pos] = self.verb_dict.word2idx[''] 105 | noun_input[mask_pos] = self.noun_dict.word2idx[''] 106 | 107 | data = {'verb_input': verb_input, 'verb_target': verb, 'noun_input': noun_input, 'noun_target' : noun} 108 | return data 109 | 110 | 111 | class EgteaCorpus(Dataset): 112 | def __init__(self, pickle_file, csvfiles, num_class, num_gram, train=True): 113 | self.action_dict = Dictionary() 114 | self.num_class = int(num_class) 115 | self.num_gram = num_gram 116 | self.train = train 117 | action_csv = csvfiles[0] 118 | 119 | assert num_gram >= 2 120 | 121 | # Update action dictionary, note that last token is '' token 122 | with open(action_csv, 'r') as f: 123 | lines = f.readlines() 124 | for line in lines: 125 | idx, word = int(line.split(',')[0]), line.split(',')[1] 126 | self.action_dict.add_word(word) 127 | self.action_dict.add_word('') 128 | self.actions = self.tokenize(pd.read_pickle(pickle_file)) 129 | 130 | def tokenize(self, df_labels): 131 | """Tokenizes a epic-kitchens file.""" 132 | # Parse the pandas file 133 | video_ids = sorted(list(set(df_labels['video_name']))) 134 | action_idss = [] 135 | 136 | for video_id in video_ids: 137 | df_video = df_labels[df_labels['video_name'] == video_id] 138 | df_video = df_video.sort_values(by='start_frame') 139 | action_class = list(df_video['action_idx']) 140 | 141 | for actionidx in action_class: 142 | self.action_dict.add_count(actionidx) 143 | 144 | for ii in range(len(action_class) - self.num_gram + 1): 145 | action_temp = [] 146 | for j in range(self.num_gram): 147 | action_temp.append(action_class[ii + j]) 148 | action_idss.append(torch.tensor(action_temp).type(torch.int64)) 149 | 150 | action_ids = torch.stack(action_idss, dim=0) 151 | 152 | return action_ids 153 | 154 | def __len__(self): 155 | return len(self.actions) 156 | 157 | def __getitem__(self, index): 158 | action = self.actions[index] 159 | action_input = action.clone().detach() 160 | 161 | if self.train: 162 | mask_pos = np.random.choice(list(range(self.num_gram))) 163 | action_input[mask_pos] = self.action_dict.word2idx[''] 164 | 165 | else: 166 | # For evaluating, test only the centre action 167 | mask_pos = self.num_gram // 2 168 | action_input[mask_pos] = self.action_dict.word2idx[''] 169 | 170 | data = {'input': action_input, 'target': action} 171 | return data 172 | 173 | -------------------------------------------------------------------------------- /test_av.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from pathlib import Path 4 | import numpy as np 5 | import torch 6 | import torch.nn.parallel 7 | import torch.optim 8 | from sklearn.metrics import confusion_matrix, accuracy_score 9 | 10 | from epic_kitchens import EpicKitchens 11 | from egtea import Egtea 12 | from models_av import MTCN_AV 13 | import pickle 14 | 15 | _DATASETS = {'epic': EpicKitchens, 'egtea': Egtea} 16 | _NUM_CLASSES = {'epic-55': [125, 352], 'epic-100': [97, 300], 'egtea': 106} 17 | 18 | 19 | def eval_video(data, net, device): 20 | data = data.to(device) 21 | # For EGTEA, we feed each of 10 clips of each action in the sequence independently 22 | # to the audio-visual transformer and average their predictions, while for EPIC-KITCHENS 23 | # we feed all 10 clips for each action in the sequence simultaneously to the audio-visual transformer 24 | if args.dataset == 'egtea': 25 | data = data.view(10, -1, data.shape[2]) 26 | if args.extract_attn_weights: 27 | rst, attn_weights = net(data, extract_attn_weights=args.extract_attn_weights) 28 | else: 29 | rst = net(data, extract_attn_weights=args.extract_attn_weights) 30 | if args.dataset == 'egtea': 31 | rst = torch.mean(rst, dim=0) 32 | 33 | if not isinstance(_NUM_CLASSES[args.dataset], list): 34 | if args.extract_attn_weights: 35 | return rst.cpu().numpy().squeeze(), attn_weights 36 | else: 37 | return rst.cpu().numpy().squeeze() 38 | else: 39 | if args.extract_attn_weights: 40 | return {'verb': rst[0].cpu().numpy().squeeze(), 41 | 'noun': rst[1].cpu().numpy().squeeze()},\ 42 | attn_weights 43 | else: 44 | return {'verb': rst[0].cpu().numpy().squeeze(), 45 | 'noun': rst[1].cpu().numpy().squeeze()} 46 | 47 | 48 | def evaluate_model(): 49 | 50 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu", 0) 51 | net = MTCN_AV(_NUM_CLASSES[args.dataset], 52 | seq_len=args.seq_len, 53 | num_clips=10 if 'epic' in args.dataset else 1, 54 | visual_input_dim=args.visual_input_dim, 55 | audio_input_dim=args.audio_input_dim if args.dataset.split('-')[0] == 'epic' else None, 56 | d_model=args.d_model, 57 | dim_feedforward=args.dim_feedforward, 58 | nhead=args.nhead, 59 | num_layers=args.num_layers, 60 | dropout=args.dropout, 61 | classification_mode='summary', 62 | audio=not args.dataset == 'egtea') 63 | 64 | checkpoint = torch.load(args.checkpoint) 65 | print("model epoch {} best prec@1: {}".format(checkpoint['epoch'], checkpoint['best_prec1'])) 66 | 67 | net.load_state_dict(checkpoint['state_dict']) 68 | 69 | dataset = _DATASETS[args.dataset.split('-')[0]] 70 | test_loader = torch.utils.data.DataLoader( 71 | dataset(args.test_hdf5_path, 72 | args.test_pickle, 73 | visual_feature_dim=args.visual_input_dim, 74 | audio_feature_dim=args.audio_input_dim if args.dataset.split('-')[0] == 'epic' else None, 75 | window_len=args.seq_len, 76 | num_clips=10, 77 | clips_mode='all',), 78 | batch_size=1, shuffle=False, 79 | num_workers=args.workers, pin_memory=True) 80 | 81 | net = net.to(device) 82 | with torch.no_grad(): 83 | net.eval() 84 | results = [] 85 | if args.extract_attn_weights: 86 | attention_weights_dict = {} 87 | total_num = len(test_loader.dataset) 88 | 89 | proc_start_time = time.time() 90 | for i, (data, label, narration_id) in enumerate(test_loader): 91 | if args.extract_attn_weights: 92 | rst, attn_weights = eval_video(data, net, device) 93 | else: 94 | rst = eval_video(data, net, device) 95 | if not isinstance(_NUM_CLASSES[args.dataset], list): 96 | label_ = label.item() 97 | else: 98 | label_ = {k: v.item() for k, v in label.items()} 99 | results.append((rst, label_, narration_id)) 100 | if args.extract_attn_weights: 101 | attention_weights_dict[narration_id[0]] = attn_weights 102 | cnt_time = time.time() - proc_start_time 103 | print('video {} done, total {}/{}, average {} sec/video'.format( 104 | i, i + 1, total_num, float(cnt_time) / (i + 1))) 105 | if args.extract_attn_weights: 106 | return results, attention_weights_dict 107 | else: 108 | return results 109 | 110 | 111 | def print_accuracy(scores, labels): 112 | 113 | video_pred = [np.argmax(score) for score in scores] 114 | cf = confusion_matrix(labels, video_pred).astype(float) 115 | cls_cnt = cf.sum(axis=1) 116 | cls_hit = np.diag(cf) 117 | cls_cnt[cls_hit == 0] = 1 # to avoid divisions by zero 118 | cls_acc = cls_hit / cls_cnt 119 | 120 | acc = accuracy_score(labels, video_pred) 121 | 122 | print('Accuracy {:.02f}%'.format(acc * 100)) 123 | print('Average Class Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100)) 124 | 125 | 126 | def save_scores(results, output): 127 | 128 | save_dict = {} 129 | if not isinstance(_NUM_CLASSES[args.dataset], list): 130 | scores = np.array([result[0] for result in results]) 131 | labels = np.array([result[1] for result in results]) 132 | save_dict['scores'] = scores 133 | save_dict['labels'] = labels 134 | else: 135 | keys = results[0][0].keys() 136 | save_dict = {k + '_output': np.array([result[0][k] for result in results]) for k in keys} 137 | save_dict['narration_id'] = np.array([result[2] for result in results]) 138 | 139 | with open(output, 'wb') as f: 140 | pickle.dump(save_dict, f) 141 | 142 | 143 | def main(): 144 | 145 | parser = argparse.ArgumentParser(description=('Test Audio-Visual Transformer on Sequence ' + 146 | 'of actions from untrimmed video')) 147 | parser.add_argument('--test_hdf5_path', type=Path) 148 | parser.add_argument('--test_pickle', type=Path) 149 | parser.add_argument('--dataset', choices=['epic-55', 'epic-100', 'egtea']) 150 | parser.add_argument('--checkpoint', type=Path) 151 | parser.add_argument('--seq_len', type=int, default=5) 152 | parser.add_argument('--visual_input_dim', type=int, default=2304) 153 | parser.add_argument('--audio_input_dim', type=int, default=2304) 154 | parser.add_argument('--d_model', type=int, default=512) 155 | parser.add_argument('--dim_feedforward', type=int, default=2048) 156 | parser.add_argument('--nhead', type=int, default=8) 157 | parser.add_argument('--num_layers', type=int, default=6) 158 | parser.add_argument('--dropout', type=float, default=0.1) 159 | parser.add_argument('--window_len', type=int, default=60) 160 | parser.add_argument('--extract_attn_weights', action='store_true') 161 | parser.add_argument('--output_dir', type=Path) 162 | parser.add_argument('--split') 163 | parser.add_argument('-j', '--workers', default=40, type=int, metavar='N', 164 | help='number of data loading workers (default: 4)') 165 | 166 | global args 167 | args = parser.parse_args() 168 | 169 | if args.extract_attn_weights: 170 | results, attention_weights_dict = evaluate_model() 171 | else: 172 | results = evaluate_model() 173 | if ('test' not in args.split and 'epic' in args.dataset) or 'epic' not in args.dataset: 174 | if isinstance(_NUM_CLASSES[args.dataset], list): 175 | keys = results[0][0].keys() 176 | for task in keys: 177 | print('Evaluation of {}'.format(task.upper())) 178 | print_accuracy([result[0][task] for result in results], 179 | [result[1][task] for result in results]) 180 | else: 181 | print_accuracy([result[0] for result in results], 182 | [result[1] for result in results]) 183 | 184 | output_dir = args.output_dir / Path('scores') 185 | if not output_dir.exists(): 186 | output_dir.mkdir(parents=True) 187 | save_scores(results, output_dir / Path(args.split+'.pkl')) 188 | 189 | if args.extract_attn_weights: 190 | attention_output_dir = args.output_dir / Path('attention') 191 | if not attention_output_dir.exists(): 192 | attention_output_dir.mkdir(parents=True) 193 | attention_output_file = attention_output_dir / Path(args.split+'.pkl') 194 | with open(attention_output_file, 'wb') as f: 195 | pickle.dump(attention_weights_dict, f) 196 | 197 | 198 | if __name__ == '__main__': 199 | main() 200 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multimodal Temporal Context Network (MTCN) 2 | 3 | This repository implements the model proposed in the paper: 4 | 5 | Evangelos Kazakos, Jaesung Huh, Arsha Nagrani, Andrew Zisserman, Dima Damen, **With a Little Help from my Temporal Context: Multimodal Egocentric Action Recognition**, BMVC, 2021 6 | 7 | [Project webpage](https://ekazakos.github.io/MTCN-project/) 8 | 9 | [arXiv paper](https://arxiv.org/abs/2111.01024) 10 | 11 | ## Citing 12 | When using this code, kindly reference: 13 | 14 | ``` 15 | @INPROCEEDINGS{kazakos2021MTCN, 16 | author={Kazakos, Evangelos and Huh, Jaesung and Nagrani, Arsha and Zisserman, Andrew and Damen, Dima}, 17 | booktitle={British Machine Vision Conference (BMVC)}, 18 | title={With a Little Help from my Temporal Context: Multimodal Egocentric Action Recognition}, 19 | year={2021}} 20 | ``` 21 | 22 | ## NOTE 23 | Although we train MTCN using visual SlowFast features extracted from a model trained with video clips of 2s, at Table 3 of our paper and Table 1 of Appendix (Table 6 in the arXiv version) where we compare MTCN with SOTA, the results of SlowFast are from [1] where the model is trained with video clips of 1s. In the following table, we provide the results of SlowFast trained with 2s, for a direct comparison as we use this model to extract the visual features. 24 | 25 | ![alt text](https://ekazakos.github.io/files/slowfast.jpeg) 26 | 27 | ## Requirements 28 | 29 | Project's requirements can be installed in a separate conda environment by running the following command in your terminal: ```$ conda env create -f environment.yml```. 30 | 31 | ## Features 32 | 33 | The extracted features for each dataset can be downloaded using the following links: 34 | 35 | ### EPIC-KITCHENS-100: 36 | 37 | * [Train](https://www.dropbox.com/s/yb9jtzq24cd2hnl/audiovisual_slowfast_features_train.hdf5?dl=0) 38 | * [Val](https://www.dropbox.com/s/8yeb84ewd2meib8/audiovisual_slowfast_features_val.hdf5?dl=0) 39 | * [Test](https://www.dropbox.com/s/6vifpn3qurkyf96/audiovisual_slowfast_features_test.hdf5?dl=0) 40 | 41 | ### EGTEA: 42 | 43 | * [Train-split1](https://www.dropbox.com/s/6hr994w3kkvbtv0/visual_slowfast_features_train_split1.hdf5?dl=0) 44 | * [Test-split1](https://www.dropbox.com/s/03aa8hmflv7depe/visual_slowfast_features_test_split1.hdf5?dl=0) 45 | 46 | ## Pretrained models 47 | 48 | We provide pretrained models for EPIC-KITCHENS-100: 49 | 50 | * Audio-visual transformer [link](https://www.dropbox.com/s/vqe7esmqqwsebo6/mtcn_av_sf_epic-kitchens-100.pyth?dl=0) 51 | * Language model [link](https://www.dropbox.com/s/80lcnvsoq4y7tux/mtcn_lm_epic-kitchens-100.pyth?dl=0) 52 | 53 | ## Ground-truth 54 | 55 | * The ground-truth of EPIC-KITCHENS-100 can be found at [this repository](https://github.com/epic-kitchens/epic-kitchens-100-annotations) 56 | * The ground-truth of EGTEA, processed by us to be in a cleaner format, can be downloaded from the following links: [[Train-split1]](https://www.dropbox.com/s/8zxdsi13v7oy106/train_split1.pkl?dl=0) [[Test-split1]](https://www.dropbox.com/s/50bkljl71njyj46/test_split1.pkl?dl=0) [[Action mapping]](https://www.dropbox.com/s/cg0pagu2px0f6k0/actions_egtea.csv?dl=0) 57 | 58 | ## Train 59 | 60 | ### EPIC-KITCHENS-100 61 | To train the audio-visual transformer on EPIC-KITCHENS-100, run: 62 | 63 | ``` 64 | python train_av.py --dataset epic-100 --train_hdf5_path /path/to/epic-kitchens-100/features/audiovisual_slowfast_features_train.hdf5 65 | --val_hdf5_path /path/to/epic-kitchens-100/features/audiovisual_slowfast_features_val.hdf5 66 | --train_pickle /path/to/epic-kitchens-100-annotations/EPIC_100_train.pkl 67 | --val_pickle /path/to/epic-kitchens-100-annotations/EPIC_100_validation.pkl 68 | --batch-size 32 --lr 0.005 --optimizer sgd --epochs 100 --lr_steps 50 75 --output_dir /path/to/output_dir 69 | --num_layers 4 -j 8 --classification_mode all --seq_len 9 70 | ``` 71 | 72 | To train the language model on EPIC-KITCHENS-100, run: 73 | ``` 74 | python train_lm.py --dataset epic-100 --train_pickle /path/to/epic-kitchens-100-annotations/EPIC_100_train.pkl 75 | --val_pickle /path/to/epic-kitchens-100-annotations/EPIC_100_validation.pkl 76 | --verb_csv /path/to/epic-kitchens-100-annotations/EPIC_100_verb_classes.csv 77 | --noun_csv /path/to/epic-kitchens-100-annotations/EPIC_100_noun_classes.csv 78 | --batch-size 64 --lr 0.001 --optimizer adam --epochs 100 --lr_steps 50 75 --output_dir /path/to/output_dir 79 | --num_layers 4 -j 8 --num_gram 9 --dropout 0.1 80 | ``` 81 | 82 | ### EGTEA 83 | To train the visual-only transformer on EGTEA (EGTEA does not have audio), run: 84 | 85 | ``` 86 | python train_av.py --dataset egtea --train_hdf5_path /path/to/egtea/features/visual_slowfast_features_train_split1.hdf5 87 | --val_hdf5_path /path/to/egtea/features/visual_slowfast_features_test_split1.hdf5 88 | --train_pickle /path/to/EGTEA_annotations/train_split1.pkl --val_pickle /path/to/EGTEA_annotations/test_split1.pkl 89 | --batch-size 32 --lr 0.001 --optimizer sgd --epochs 50 --lr_steps 25 38 --output_dir /path/to/output_dir 90 | --num_layers 4 -j 8 --classification_mode all --seq_len 9 91 | ``` 92 | 93 | To train the language model on EGTEA, 94 | ``` 95 | python train_lm.py --dataset egtea --train_pickle /path/to/EGTEA_annotations/train_split1.pkl 96 | --val_pickle /path/to/EGTEA_annotations/test_split1.pkl 97 | --action_csv /path/to/EGTEA_annotations/actions_egtea.csv 98 | --batch-size 64 --lr 0.001 --optimizer adam --epochs 50 --lr_steps 25 38 --output_dir /path/to/output_dir 99 | --num_layers 4 -j 8 --num_gram 9 --dropout 0.1 100 | ``` 101 | 102 | ## Test 103 | 104 | ### EPIC-KITCHENS-100 105 | To test the audio-visual transformer on EPIC-KITCHENS-100, run: 106 | 107 | ``` 108 | python test_av.py --dataset epic-100 --test_hdf5_path /path/to/epic-kitchens-100/features/audiovisual_slowfast_features_val.hdf5 109 | --test_pickle /path/to/epic-kitchens-100-annotations/EPIC_100_validation.pkl 110 | --checkpoint /path/to/av_model/av_checkpoint.pyth --seq_len 9 --num_layers 4 --output_dir /path/to/output_dir 111 | --split validation 112 | ``` 113 | 114 | To obtain scores of the model on the test set, simply use ```--test_hdf5_path /path/to/epic-kitchens-100/features/audiovisual_slowfast_features_test.hdf5```, 115 | ```--test_pickle /path/to/epic-kitchens-100-annotations/EPIC_100_test_timestamps.pkl``` 116 | and ```--split test``` instead. Since the labels for the test set are not available the script will simply save the scores 117 | without computing the accuracy of the model. 118 | 119 | To evaluate your model on the validation set, follow the instructions in [this link](https://github.com/epic-kitchens/C1-Action-Recognition). 120 | In the same link, you can find instructions for preparing the scores of the model for submission in the evaluation server and obtain results 121 | on the test set. 122 | 123 | Finally, to filter out improbable sequences using LM, run: 124 | 125 | ``` 126 | python test_av_lm.py --dataset epic-100 127 | --test_pickle /path/to/epic-kitchens-100-annotations/EPIC_100_validation.pkl 128 | --test_scores /path/to/audio-visual-results.pkl 129 | --checkpoint /path/to/lm_model/lm_checkpoint.pyth 130 | --num_gram 9 --split validation 131 | ``` 132 | Note that, ```--test_scores /path/to/audio-visual-results.pkl``` are the scores predicted from the audio-visual transformer. To obtain scores on the test set, use ```--test_pickle /path/to/epic-kitchens-100-annotations/EPIC_100_test_timestamps.pkl``` 133 | and ```--split test``` instead. 134 | 135 | Since we are providing the trained models for EPIC-KITCHENS-100, `av_checkpoint.pyth` and `lm_checkpoint.pyth` in the test scripts above could be either the provided pretrained models or `model_best.pyth` that is the your own trained model. 136 | 137 | ### EGTEA 138 | 139 | To test the visual-only transformer on EGTEA, run: 140 | 141 | ``` 142 | python test_av.py --dataset egtea --test_hdf5_path /path/to/egtea/features/visual_slowfast_features_test_split1.hdf5 143 | --test_pickle /path/to/EGTEA_annotations/test_split1.pkl 144 | --checkpoint /path/to/v_model/model_best.pyth --seq_len 9 --num_layers 4 --output_dir /path/to/output_dir 145 | --split test_split1 146 | ``` 147 | 148 | To filter out improbable sequences using LM, run: 149 | ``` 150 | python test_av_lm.py --dataset egtea 151 | --test_pickle /path/to/EGTEA_annotations/test_split1.pkl 152 | --test_scores /path/to/visual-results.pkl 153 | --checkpoint /path/to/lm_model/model_best.pyth 154 | --num_gram 9 --split test_split1 155 | ``` 156 | 157 | In each case, you can extract attention weights by simply including ```--extract_attn_weights``` at the input arguments of the test script. 158 | 159 | ## References 160 | [1] Dima Damen, Hazel Doughty, Giovanni Maria Farinella, , Antonino Furnari, Jian Ma,Evangelos Kazakos, Davide Moltisanti, Jonathan Munro, Toby Perrett, Will Price, andMichael Wray, **Rescaling Egocentric Vision: Collection Pipeline and Challenges for EPIC-KITCHENS-100**, IJCV, 2021 161 | 162 | ## License 163 | 164 | The code is published under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License, found [here](https://creativecommons.org/licenses/by-nc-sa/4.0/). 165 | -------------------------------------------------------------------------------- /test_av_lm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from sklearn.metrics import confusion_matrix, accuracy_score 7 | from scipy.special import log_softmax 8 | from collections import OrderedDict 9 | import pandas as pd 10 | import pickle 11 | 12 | from models_lm import MTCN_LM 13 | from utils import get_topk, get_topk_action, get_lmscore, get_lmscore_action 14 | 15 | _NUM_CLASSES = {'epic-55': [125, 352], 'epic-100': [97, 300], 'egtea': 106} 16 | 17 | 18 | def eval_epicvideos(av_results, df_labels, model, device, ntokens): 19 | # Beam search for verbs and nouns 20 | verb_scores = av_results['verb_output'].tolist() 21 | noun_scores = av_results['noun_output'].tolist() 22 | 23 | df_labels['verb_scores'] = verb_scores 24 | df_labels['noun_scores'] = noun_scores 25 | df_labels['verb_pred'] = "" 26 | df_labels['noun_pred'] = "" 27 | 28 | results = [] 29 | 30 | video_ids = sorted(list(set(df_labels['video_id']))) 31 | 32 | for video_num, video_id in enumerate(video_ids): 33 | print("Processing video [{}/{}] ....".format(video_num + 1, len(video_ids))) 34 | df_video = df_labels[df_labels['video_id'] == video_id] 35 | df_video = df_video.sort_values(by='start_frame') 36 | 37 | for ii in range(len(df_video)): 38 | row = df_video.iloc[ii] 39 | verb_score, noun_score = torch.FloatTensor(row['verb_scores']).unsqueeze_(0), torch.FloatTensor(row['noun_scores']).unsqueeze_(0) 40 | 41 | narration_id = row.name 42 | 43 | verb_score = F.log_softmax(verb_score, dim=-1) 44 | noun_score = F.log_softmax(noun_score, dim=-1) 45 | 46 | if ii < args.num_gram // 2 or ii >= (len(df_video) - args.num_gram // 2): 47 | # Use the audio-visual output for corner actions 48 | verb_pred = verb_score.cpu().numpy().reshape(-1) 49 | noun_pred = noun_score.cpu().numpy().reshape(-1) 50 | else: 51 | verb_sequence = log_softmax(np.array(list(df_video['verb_scores'][ii - args.num_gram // 2: ii + args.num_gram // 2 + 1])), axis=-1) 52 | noun_sequence = log_softmax(np.array(list(df_video['noun_scores'][ii - args.num_gram // 2: ii + args.num_gram // 2 + 1])), axis=-1) 53 | 54 | verb_candidates, noun_candidates = get_topk(verb_sequence, noun_sequence, args.beam_size) 55 | 56 | verb_pred = verb_score.cpu().numpy().reshape(-1) 57 | noun_pred = noun_score.cpu().numpy().reshape(-1) 58 | verb_dict = {} 59 | noun_dict = {} 60 | 61 | # Beam search 62 | for jj in range(args.beam_size): 63 | for kk in range(args.beam_size): 64 | verb_input, verb_avscore = verb_candidates[jj] 65 | noun_input, noun_avscore = noun_candidates[kk] 66 | verb_input = torch.LongTensor(verb_input).unsqueeze_(1).to(device) 67 | noun_input = torch.LongTensor(noun_input).unsqueeze_(1).to(device) 68 | verb_lmscore, noun_lmscore = get_lmscore(verb_input, noun_input, model, args.num_gram, ntokens) 69 | 70 | # LM fusion with hyperparameter alpha 71 | verb_score = (1 - args.alpha) * verb_avscore + args.alpha * verb_lmscore 72 | noun_score = (1 - args.alpha) * noun_avscore + args.alpha * noun_lmscore 73 | 74 | verb_center = verb_candidates[jj][0][args.num_gram // 2] 75 | noun_center = noun_candidates[kk][0][args.num_gram // 2] 76 | 77 | if verb_center not in verb_dict: 78 | verb_dict[verb_center] = verb_score 79 | if noun_center not in noun_dict: 80 | noun_dict[noun_center] = noun_score 81 | if verb_dict[verb_center] < verb_score: 82 | verb_dict[verb_center] = verb_score 83 | if noun_dict[noun_center] < noun_score: 84 | noun_dict[noun_center] = noun_score 85 | 86 | verb_dict = OrderedDict([(k,v) for k, v in sorted(verb_dict.items(), key=lambda item: item[1], reverse=False)]) 87 | noun_dict = OrderedDict([(k,v) for k, v in sorted(noun_dict.items(), key=lambda item: item[1], reverse=False)]) 88 | verb_max = np.max(verb_pred) 89 | noun_max = np.max(noun_pred) 90 | 91 | c = 0.1 92 | for jj, (key, item) in enumerate(verb_dict.items()): 93 | verb_pred[key] = verb_max + c * (jj + 1) 94 | for jj, (key, item) in enumerate(noun_dict.items()): 95 | noun_pred[key] = noun_max + c * (jj + 1) 96 | 97 | df_labels.at[narration_id, 'verb_pred'] = verb_pred 98 | df_labels.at[narration_id, 'noun_pred'] = noun_pred 99 | 100 | for ii in range(len(df_labels)): 101 | row = df_labels.iloc[ii] 102 | rst_ = {'verb': row['verb_pred'], 'noun' : row['noun_pred']} 103 | labels_ = {'verb' : row['verb_class'], 'noun' : row['noun_class']} if args.split != 'test' else {} 104 | narration_id = row.name 105 | results.append((rst_, labels_, narration_id)) 106 | 107 | return results 108 | 109 | 110 | def eval_egteavideos(av_results, df_labels, model, device, ntokens): 111 | # Beam search for actions 112 | action_scores = av_results['scores'].tolist() 113 | action_classes = av_results['labels'].tolist() 114 | 115 | df_labels['action_scores'] = action_scores 116 | df_labels['action_class'] = action_classes 117 | df_labels['action_pred'] = "" 118 | 119 | results = [] 120 | 121 | video_ids = sorted(list(set(df_labels['video_name']))) 122 | 123 | for video_num, video_id in enumerate(video_ids): 124 | print("Processing video [{}/{}] ....".format(video_num + 1, len(video_ids))) 125 | df_video = df_labels[df_labels['video_name'] == video_id] 126 | df_video = df_video.sort_values(by='start_frame') 127 | 128 | for ii in range(len(df_video)): 129 | row = df_video.iloc[ii] 130 | action_score = torch.FloatTensor(row['action_scores']).unsqueeze_(0) 131 | action_score = F.log_softmax(action_score, dim=-1) 132 | narration_id = row.name 133 | 134 | if ii < args.num_gram // 2 or ii >= (len(df_video) - args.num_gram // 2): 135 | # Use the audio-visual output for corner actions 136 | action_pred = action_score.cpu().numpy().reshape(-1) 137 | else: 138 | action_sequence = log_softmax(np.array(list(df_video['action_scores'][ii - args.num_gram // 2: ii + args.num_gram // 2 + 1])), axis=-1) 139 | action_candidates = get_topk_action(action_sequence, args.beam_size) 140 | 141 | action_pred = action_score.cpu().numpy().reshape(-1) 142 | action_dict = {} 143 | 144 | # Beam search 145 | for jj in range(args.beam_size): 146 | action_input, action_avscore = action_candidates[jj] 147 | action_input = torch.LongTensor(action_input).unsqueeze_(1).to(device) 148 | action_lmscore = get_lmscore_action(action_input, model, args.num_gram, ntokens) 149 | 150 | # LM fusion with hyperparameter alpha 151 | action_score = (1 - args.alpha) * action_avscore + args.alpha * action_lmscore 152 | action_center = action_candidates[jj][0][args.num_gram // 2] 153 | 154 | if action_center not in action_dict: 155 | action_dict[action_center] = action_score 156 | if action_dict[action_center] < action_score: 157 | action_dict[action_center] = action_score 158 | 159 | action_dict = OrderedDict([(k,v) for k, v in sorted(action_dict.items(), key=lambda item: item[1], reverse=False)]) 160 | action_max = np.max(action_pred) 161 | 162 | c = 0.1 163 | for jj, (key, item) in enumerate(action_dict.items()): 164 | action_pred[key] = action_max + c * (jj + 1) 165 | 166 | df_labels.at[narration_id, 'action_pred'] = action_pred 167 | 168 | for ii in range(len(df_labels)): 169 | row = df_labels.iloc[ii] 170 | rst_ = row['action_pred'] 171 | labels_ = row['action_class'] 172 | narration_id = row.name 173 | results.append((rst_, labels_, narration_id)) 174 | 175 | return results 176 | 177 | 178 | def print_accuracy(scores, labels): 179 | # Printing accuracy and average per-class accuracy 180 | video_pred = [np.argmax(score) for score in scores] 181 | cf = confusion_matrix(labels, video_pred).astype(float) 182 | cls_cnt = cf.sum(axis=1) 183 | cls_hit = np.diag(cf) 184 | cls_cnt[cls_hit == 0] = 1 # to avoid divisions by zero 185 | cls_acc = cls_hit / cls_cnt 186 | 187 | acc = accuracy_score(labels, video_pred) 188 | 189 | print('Accuracy {:.02f}%'.format(acc * 100)) 190 | print('Average Class Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100)) 191 | 192 | 193 | def save_scores(results, output): 194 | # Save the scores as a pickle format 195 | save_dict = {} 196 | if not isinstance(_NUM_CLASSES[args.dataset], list): 197 | scores = np.array([result[0] for result in results]) 198 | labels = np.array([result[1] for result in results]) 199 | save_dict['scores'] = scores 200 | save_dict['labels'] = labels 201 | else: 202 | keys = results[0][0].keys() 203 | save_dict = {k + '_output': np.array([result[0][k] for result in results]) for k in keys} 204 | save_dict['narration_id'] = np.array([result[2] for result in results]) 205 | 206 | with open(output, 'wb') as f: 207 | pickle.dump(save_dict, f) 208 | 209 | 210 | def main(): 211 | parser = argparse.ArgumentParser(description=('Fuse the MTCN output scores and LM scores')) 212 | parser.add_argument('--test_pickle', type=Path) 213 | parser.add_argument('--test_scores', type=Path) 214 | parser.add_argument('--checkpoint', type=Path) 215 | parser.add_argument('--dataset', choices=['epic-55', 'epic-100', 'egtea']) 216 | parser.add_argument('--num_gram', default=9, type=int) 217 | parser.add_argument('--d_model', type=int, default=512) 218 | parser.add_argument('--dim_feedforward', type=int, default=512) 219 | parser.add_argument('--nhead', type=int, default=8) 220 | parser.add_argument('--num_layers', type=int, default=4) 221 | # ------------------------------ BEAM SEARCH ---------------------------------- 222 | parser.add_argument('--alpha', type=float, default=0.15) 223 | parser.add_argument('--beam_size', type=int, default=10) 224 | # ------------------------------ OUTPUT ---------------------------------- 225 | parser.add_argument('--output_dir', type=Path, default='scores') 226 | parser.add_argument('--split', type=str, default='result') 227 | 228 | global args 229 | args = parser.parse_args() 230 | 231 | np.random.seed(0) 232 | torch.manual_seed(0) 233 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 234 | 235 | ntokens = _NUM_CLASSES[args.dataset] 236 | 237 | # Load model 238 | model = MTCN_LM(ntokens, 239 | args.d_model, 240 | args.dim_feedforward, 241 | args.nhead, 242 | args.num_layers) 243 | 244 | model.load_state_dict(torch.load(args.checkpoint)['state_dict']) 245 | model = model.to(device) 246 | model.eval() 247 | 248 | # For beam search 249 | if args.alpha == 0: 250 | # You don't need a beam search for this 251 | args.beam_size = 1 252 | 253 | # Load the audio-visual output 254 | with open(args.test_scores, 'rb') as f: 255 | av_results = pickle.load(f) 256 | 257 | df_labels = pd.read_pickle(args.test_pickle) 258 | 259 | if args.dataset.split('-')[0] == 'epic': 260 | results = eval_epicvideos(av_results, df_labels, model, device, ntokens) 261 | else: 262 | results = eval_egteavideos(av_results, df_labels, model, device, ntokens) 263 | 264 | print("ALPHA : {}, BEAM_SIZE : {}".format(args.alpha, args.beam_size)) 265 | 266 | # Print accuracy 267 | if ('test' not in args.split and 'epic' in args.dataset) or 'epic' not in args.dataset: 268 | if isinstance(ntokens, list): 269 | keys = results[0][0].keys() 270 | for task in keys: 271 | print('Evaluation of {}'.format(task.upper())) 272 | print_accuracy([result[0][task] for result in results], 273 | [result[1][task] for result in results]) 274 | else: 275 | print_accuracy([result[0] for result in results], 276 | [result[1] for result in results]) 277 | 278 | # Save the scores file 279 | output_dir = args.output_dir / Path('scores') 280 | if not output_dir.exists(): 281 | output_dir.mkdir(parents=True) 282 | save_scores(results, output_dir / Path(args.split + '.pkl')) 283 | 284 | 285 | if __name__ == '__main__': 286 | main() -------------------------------------------------------------------------------- /train_lm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import wandb 4 | import numpy as np 5 | import pandas as pd 6 | import random 7 | from pathlib import Path 8 | 9 | from corpus import EpicCorpus, EgteaCorpus 10 | from models_lm import MTCN_LM 11 | from utils import accuracy, multitask_accuracy, save_checkpoint, AverageMeter 12 | 13 | import torch 14 | from torch.optim.lr_scheduler import MultiStepLR 15 | import torch.nn.functional as F 16 | 17 | _NUM_CLASSES = {'epic-55': [125, 352], 'epic-100': [97, 300], 'egtea': 106} 18 | _CORPUS = {'epic': EpicCorpus, 'egtea': EgteaCorpus} 19 | 20 | parser = argparse.ArgumentParser(description=('Train language model from sequence of actions')) 21 | 22 | # ------------------------------ Dataset ------------------------------- 23 | parser.add_argument('--train_pickle', type=Path) 24 | parser.add_argument('--val_pickle', type=Path) 25 | parser.add_argument('--verb_csv', type=Path, help='verb csv file if epic') 26 | parser.add_argument('--noun_csv', type=Path, help='noun csv file if epic') 27 | parser.add_argument('--action_csv', type=Path, help='action csv file if egtea') 28 | parser.add_argument('--dataset', choices=['epic-55', 'epic-100', 'egtea']) 29 | # ------------------------------ Model --------------------------------- 30 | parser.add_argument('--num_gram', type=int, default=9) 31 | parser.add_argument('--d_model', type=int, default=512) 32 | parser.add_argument('--dim_feedforward', type=int, default=512) 33 | parser.add_argument('--nhead', type=int, default=8) 34 | parser.add_argument('--num_layers', type=int, default=4) 35 | parser.add_argument('--dropout', type=float, default=0.1) 36 | # ------------------------------ Train ---------------------------------- 37 | parser.add_argument('--epochs', default=50, type=int, metavar='N', 38 | help='number of total epochs to run') 39 | parser.add_argument('-b', '--batch-size', default=128, type=int, 40 | metavar='N', help='mini-batch size (default: 256)') 41 | # ------------------------------ Optimizer ------------------------------ 42 | parser.add_argument('--optimizer', choices=['sgd', 'adam'], default='adam') 43 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 44 | metavar='LR', help='initial learning rate') 45 | parser.add_argument('--lr_steps', default=[25, 37], type=float, nargs="+", 46 | metavar='LRSteps', help='epochs to decay learning rate by 10') 47 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 48 | help='momentum') 49 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 50 | metavar='W', help='weight decay (default: 5e-4)') 51 | parser.add_argument('--clip-gradient', '--gd', default=5, type=float, 52 | metavar='W', help='gradient norm clipping') 53 | # ------------------------------ Misc ------------------------------------ 54 | parser.add_argument('--output_dir', type=Path) 55 | parser.add_argument('--disable_wandb_log', action='store_true') 56 | parser.add_argument('-j', '--workers', default=1, type=int, metavar='N', 57 | help='number of data loading workers (default: 4)') 58 | parser.add_argument('--print-freq', '-p', default=600, type=int, 59 | metavar='N', help='print frequency (default: 10)') 60 | 61 | args = parser.parse_args() 62 | 63 | best_prec1 = 0 64 | training_iterations = 0 65 | 66 | if not args.output_dir.exists(): 67 | args.output_dir.mkdir(parents=True) 68 | 69 | 70 | def main(): 71 | global args, best_prec1 72 | 73 | random.seed(0) 74 | np.random.seed(0) 75 | torch.manual_seed(0) 76 | 77 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 78 | model = MTCN_LM(_NUM_CLASSES[args.dataset], 79 | args.d_model, 80 | args.dim_feedforward, 81 | args.nhead, 82 | args.num_layers, 83 | dropout=args.dropout) 84 | model = model.to(device) 85 | 86 | if not args.disable_wandb_log: 87 | wandb.init(project='MTCN', config=args) 88 | wandb.watch(model) 89 | 90 | if args.dataset.split('-')[0] == 'epic': 91 | csvfiles = [args.verb_csv, args.noun_csv] 92 | else: 93 | csvfiles = [args.action_csv] 94 | 95 | train_corpus = _CORPUS[args.dataset.split('-')[0]](args.train_pickle, csvfiles, _NUM_CLASSES[args.dataset], args.num_gram, train=True) 96 | val_corpus = _CORPUS[args.dataset.split('-')[0]](args.val_pickle, csvfiles, _NUM_CLASSES[args.dataset], args.num_gram, train=False) 97 | 98 | train_loader = torch.utils.data.DataLoader( 99 | train_corpus, 100 | batch_size=args.batch_size, 101 | shuffle=True, num_workers=args.workers, 102 | pin_memory=True) 103 | 104 | val_loader = torch.utils.data.DataLoader( 105 | val_corpus, 106 | batch_size=args.batch_size, 107 | shuffle=False, 108 | num_workers=1, 109 | pin_memory=False) 110 | 111 | criterion = torch.nn.NLLLoss() 112 | 113 | # Optimizer and scheduler 114 | if args.optimizer == 'sgd': 115 | optimizer = torch.optim.SGD(model.parameters(), 116 | lr=args.lr, 117 | weight_decay=args.weight_decay) 118 | else: 119 | optimizer = torch.optim.Adam(model.parameters(), 120 | lr=args.lr, 121 | weight_decay=args.weight_decay) 122 | 123 | scheduler = MultiStepLR(optimizer, args.lr_steps, gamma=0.1) 124 | 125 | # Training loop 126 | for epoch in range(1, args.epochs): 127 | train(train_loader, model, criterion, epoch, optimizer, device) 128 | # evaluate on validation set 129 | prec1 = validate(val_loader, model, criterion, device) 130 | # remember best prec@1 and save checkpoint 131 | is_best = prec1 > best_prec1 132 | best_prec1 = max(prec1, best_prec1) 133 | save_checkpoint({ 134 | 'epoch': epoch + 1, 135 | 'state_dict': model.state_dict(), 136 | 'best_prec1': best_prec1, 137 | }, is_best, args.output_dir) 138 | scheduler.step() 139 | 140 | 141 | def validate(val_loader, model, criterion, device, name=''): 142 | global training_iterations 143 | is_multitask = isinstance(model.num_class, list) 144 | ntokens = val_loader.dataset.num_class 145 | 146 | with torch.no_grad(): 147 | batch_time = AverageMeter() 148 | losses = AverageMeter() 149 | top1 = AverageMeter() 150 | top5 = AverageMeter() 151 | if is_multitask: 152 | verb_losses = AverageMeter() 153 | noun_losses = AverageMeter() 154 | verb_top1 = AverageMeter() 155 | verb_top5 = AverageMeter() 156 | noun_top1 = AverageMeter() 157 | noun_top5 = AverageMeter() 158 | 159 | # switch to evaluate mode 160 | model.eval() 161 | 162 | end = time.time() 163 | for batch, data in enumerate(val_loader): 164 | for key, item in data.items(): 165 | data[key] = torch.transpose(item.to(device), 0, 1) 166 | 167 | if not is_multitask: 168 | output = model(data['input']) 169 | output = output.view(-1, ntokens) 170 | batch_size = output.size(0) 171 | output = F.log_softmax(output, dim=-1) 172 | 173 | loss = criterion(output, data['target'].reshape(-1)) 174 | 175 | # Evaluate accuracies - Calculate accuracy only for masked positions 176 | output = output[data['input'].reshape(-1) == ntokens] 177 | target = data['target'][data['input'] == ntokens] 178 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 179 | else: 180 | output = model(data['verb_input'], data['noun_input']) 181 | output = output.view(-1, ntokens[0] + ntokens[1]) 182 | batch_size = output.size(0) 183 | verb_output = F.log_softmax(output[..., :ntokens[0]], dim=-1) 184 | noun_output = F.log_softmax(output[..., ntokens[0]:], dim=-1) 185 | 186 | loss_verb = criterion(verb_output, data['verb_target'].reshape(-1)) 187 | loss_noun = criterion(noun_output, data['noun_target'].reshape(-1)) 188 | loss = 0.5 * (loss_verb + loss_noun) 189 | verb_losses.update(loss_verb.item(), batch_size) 190 | noun_losses.update(loss_noun.item(), batch_size) 191 | 192 | # Evaluate accuracies - Calculate accuracy only for masked positions 193 | verb_output = verb_output[data['verb_input'].reshape(-1) == ntokens[0]] 194 | noun_output = noun_output[data['noun_input'].reshape(-1) == ntokens[1]] 195 | verb_target = data['verb_target'][data['verb_input'] == ntokens[0]] 196 | noun_target = data['noun_target'][data['noun_input'] == ntokens[1]] 197 | 198 | verb_prec1, verb_prec5 = accuracy(verb_output, verb_target, topk=(1, 5)) 199 | verb_top1.update(verb_prec1, batch_size) 200 | verb_top5.update(verb_prec5, batch_size) 201 | 202 | noun_prec1, noun_prec5 = accuracy(noun_output, noun_target, topk=(1, 5)) 203 | noun_top1.update(noun_prec1, batch_size) 204 | noun_top5.update(noun_prec5, batch_size) 205 | 206 | prec1, prec5 = multitask_accuracy((verb_output, noun_output), 207 | (verb_target, noun_target), 208 | topk=(1, 5)) 209 | 210 | losses.update(loss.item(), batch_size) 211 | top1.update(prec1, batch_size) 212 | top5.update(prec5, batch_size) 213 | 214 | # measure elapsed time 215 | batch_time.update(time.time() - end) 216 | end = time.time() 217 | 218 | # Logging 219 | if not is_multitask: 220 | if not args.disable_wandb_log: 221 | wandb.log( 222 | { 223 | "Val/loss": losses.avg, 224 | "Val/Top1_acc": top1.avg, 225 | "Val/Top5_acc": top5.avg, 226 | "val_step": training_iterations, 227 | }, 228 | ) 229 | 230 | message = ('Testing Results: ' 231 | 'Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} ' 232 | 'Loss {loss.avg:.5f}').format(top1=top1, 233 | top5=top5, 234 | loss=losses) 235 | else: 236 | if not args.disable_wandb_log: 237 | wandb.log( 238 | { 239 | "Val/loss": losses.avg, 240 | "Val/Top1_acc": top1.avg, 241 | "Val/Top5_acc": top5.avg, 242 | "Val/verb/loss": verb_losses.avg, 243 | "Val/verb/Top1_acc": verb_top1.avg, 244 | "Val/verb/Top5_acc": verb_top5.avg, 245 | "Val/noun/loss": noun_losses.avg, 246 | "Val/noun/Top1_acc": noun_top1.avg, 247 | "Val/noun/Top5_acc": noun_top5.avg, 248 | "val_step": training_iterations, 249 | }, 250 | ) 251 | 252 | message = ("Testing Results: " 253 | "{name} Verb Prec@1 {verb_top1.avg:.3f} Verb Prec@5 {verb_top5.avg:.3f} " 254 | "{name} Noun Prec@1 {noun_top1.avg:.3f} Noun Prec@5 {noun_top5.avg:.3f} " 255 | "{name} Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} " 256 | "{name} Verb Loss {verb_loss.avg:.5f} " 257 | "{name} Noun Loss {noun_loss.avg:.5f} " 258 | "{name} Loss {loss.avg:.5f}").format(verb_top1=verb_top1, verb_top5=verb_top5, 259 | noun_top1=noun_top1, noun_top5=noun_top5, 260 | top1=top1, top5=top5, 261 | name=name, 262 | verb_loss=verb_losses, 263 | noun_loss=noun_losses, 264 | loss=losses) 265 | print(message) 266 | 267 | return top1.avg 268 | 269 | 270 | def train(train_loader, model, criterion, epoch, optimizer, device): 271 | global training_iterations 272 | is_multitask = isinstance(model.num_class, list) 273 | batch_time = AverageMeter() 274 | data_time = AverageMeter() 275 | losses = AverageMeter() 276 | if is_multitask: 277 | verb_losses = AverageMeter() 278 | noun_losses = AverageMeter() 279 | 280 | # switch to train mode 281 | model.train() 282 | 283 | end = time.time() 284 | 285 | ntokens = train_loader.dataset.num_class 286 | 287 | for i, data in enumerate(train_loader): 288 | for key, item in data.items(): 289 | data[key] = torch.transpose(item.to(device), 0, 1) 290 | batch_size = data[key].size(0) 291 | 292 | ## Scheduled sampling - uncomment this if you want to use it 293 | data = scheduled_sampling(model, data, device, ntokens, p=0.2) if args.dataset == 'epic' else data 294 | 295 | if not is_multitask: 296 | output = model(data['input']) 297 | output = output.view(-1, ntokens) 298 | batch_size = output.size(0) 299 | 300 | output = F.log_softmax(output, dim=-1) 301 | 302 | loss= criterion(output, data['target'].reshape(-1)) 303 | else: 304 | output = model(data['verb_input'], data['noun_input']) 305 | output = output.view(-1, ntokens[0] + ntokens[1]) 306 | batch_size = output.size(0) 307 | 308 | verb_output = F.log_softmax(output[..., :ntokens[0]], dim=-1) 309 | noun_output = F.log_softmax(output[..., ntokens[0]:], dim=-1) 310 | 311 | loss_verb = criterion(verb_output, data['verb_target'].reshape(-1)) 312 | loss_noun = criterion(noun_output, data['noun_target'].reshape(-1)) 313 | loss = 0.5 * (loss_verb + loss_noun) 314 | verb_losses.update(loss_verb.item(), batch_size) 315 | noun_losses.update(loss_noun.item(), batch_size) 316 | losses.update(loss.item(), batch_size) 317 | 318 | # Compute gradient and do SGD step 319 | optimizer.zero_grad() 320 | loss.backward() 321 | 322 | optimizer.step() 323 | 324 | training_iterations += 1 325 | 326 | # measure elapsed time 327 | batch_time.update(time.time() - end) 328 | end = time.time() 329 | 330 | # Logging 331 | if i % args.print_freq == 0: 332 | if not is_multitask: 333 | if not args.disable_wandb_log: 334 | wandb.log( 335 | { 336 | "Train/loss": losses.avg, 337 | "Train/epochs": epoch, 338 | "Train/lr": optimizer.param_groups[-1]['lr'], 339 | "train_step": training_iterations, 340 | }, 341 | ) 342 | 343 | message = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t' + 344 | 'Time {batch_time.avg:.3f} ({batch_time.avg:.3f})\t' + 345 | 'Data {data_time.avg:.3f} ({data_time.avg:.3f})\t' + 346 | 'Loss {loss.avg:.4f} ({loss.avg:.4f})\t' 347 | ).format( 348 | epoch, i, len(train_loader), batch_time=batch_time, 349 | data_time=data_time, loss=losses, 350 | lr=optimizer.param_groups[-1]['lr']) 351 | else: 352 | if not args.disable_wandb_log: 353 | wandb.log( 354 | { 355 | "Train/loss": losses.avg, 356 | "Train/epochs": epoch, 357 | "Train/lr": optimizer.param_groups[-1]['lr'], 358 | "Train/verb/loss": verb_losses.avg, 359 | "Train/noun/loss": noun_losses.avg, 360 | "train_step": training_iterations, 361 | }, 362 | ) 363 | message = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t' + 364 | 'Time {batch_time.avg:.3f} ({batch_time.avg:.3f})\t' + 365 | 'Data {data_time.avg:.3f} ({data_time.avg:.3f})\t' + 366 | 'Loss {loss.avg:.4f} ({loss.avg:.4f})\t' + 367 | 'Verb Loss {verb_loss.avg:.4f} ({verb_loss.avg:.4f})\t' + 368 | 'Noun Loss {noun_loss.avg:.4f} ({noun_loss.avg:.4f})\t' # + 369 | ).format( 370 | epoch, i, len(train_loader), batch_time=batch_time, 371 | data_time=data_time, loss=losses, verb_loss=verb_losses, 372 | noun_loss=noun_losses, 373 | lr=optimizer.param_groups[-1]['lr']) 374 | 375 | print(message) 376 | 377 | 378 | def scheduled_sampling(model, data, device, ntokens, p=0.2): 379 | # This functino returns the scheduled sampling output with a certain probability p 380 | if random.uniform(0,1) < p: 381 | randomlist = torch.LongTensor([np.random.randint(0, args.num_gram - 1, size=2) for p in range(0, batch_size)]).to(device) 382 | temp_verbinput = torch.clone(data['verb_target']) 383 | temp_nouninput = torch.clone(data['noun_target']) 384 | batch_size = data['verb_target'].size(0) 385 | for ii in range(batch_size): 386 | temp_verbinput[randomlist[ii], ii] = ntokens[0] 387 | temp_nouninput[randomlist[ii], ii] = ntokens[1] 388 | 389 | with torch.no_grad(): 390 | output_temp = model(temp_verbinput, temp_nouninput) 391 | 392 | verb_temp, noun_temp = [], [] 393 | for ii in range(batch_size): 394 | verb_temp.append(torch.max(output_temp[randomlist[ii], ii, :ntokens[0]], dim=-1)[1]) 395 | noun_temp.append(torch.max(output_temp[randomlist[ii], ii, ntokens[0]:], dim=-1)[1]) 396 | 397 | for ii in range(batch_size): 398 | data['verb_input'][randomlist[ii], ii] = verb_temp[ii] 399 | data['noun_input'][randomlist[ii], ii] = noun_temp[ii] 400 | 401 | return data 402 | 403 | 404 | if __name__ == '__main__': 405 | main() 406 | -------------------------------------------------------------------------------- /train_av.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import time 4 | import wandb 5 | import torch 6 | from torch.optim.lr_scheduler import MultiStepLR 7 | import numpy as np 8 | from models_av import MTCN_AV 9 | from epic_kitchens import EpicKitchens 10 | from egtea import Egtea 11 | from mixup import mixup_data, mixup_criterion 12 | from utils import accuracy, multitask_accuracy, save_checkpoint, AverageMeter 13 | 14 | 15 | _DATASETS = {'epic': EpicKitchens, 'egtea': Egtea} 16 | _NUM_CLASSES = {'epic-55': [125, 352], 'epic-100': [97, 300], 'egtea': 106} 17 | 18 | parser = argparse.ArgumentParser(description=('Train Audio-Visual Transformer on Sequence ' + 19 | 'of actions from untrimmed video')) 20 | 21 | # ------------------------------ Dataset ------------------------------- 22 | parser.add_argument('--train_hdf5_path', type=Path) 23 | parser.add_argument('--val_hdf5_path', type=Path) 24 | parser.add_argument('--train_pickle', type=Path) 25 | parser.add_argument('--val_pickle', type=Path) 26 | parser.add_argument('--dataset', choices=['epic-55', 'epic-100', 'egtea']) 27 | # ------------------------------ Model --------------------------------- 28 | parser.add_argument('--seq_len', type=int, default=5) 29 | parser.add_argument('--visual_input_dim', type=int, default=2304) 30 | parser.add_argument('--audio_input_dim', type=int, default=2304) 31 | parser.add_argument('--d_model', type=int, default=512) 32 | parser.add_argument('--dim_feedforward', type=int, default=2048) 33 | parser.add_argument('--nhead', type=int, default=8) 34 | parser.add_argument('--num_layers', type=int, default=6) 35 | parser.add_argument('--classification_mode', choices=['summary', 'all'], default='summary') 36 | parser.add_argument('--dropout', type=float, default=0.1) 37 | # ------------------------------ Train ---------------------------------- 38 | parser.add_argument('--epochs', default=50, type=int, metavar='N', 39 | help='number of total epochs to run') 40 | parser.add_argument('-b', '--batch-size', default=128, type=int, 41 | metavar='N', help='mini-batch size (default: 256)') 42 | # ------------------------------ Optimizer ------------------------------ 43 | parser.add_argument('--optimizer', choices=['sgd', 'adam'], default='adam') 44 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 45 | metavar='LR', help='initial learning rate') 46 | parser.add_argument('--lr_steps', default=[25, 40], type=float, nargs="+", 47 | metavar='LRSteps', help='epochs to decay learning rate by 10') 48 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 49 | help='momentum') 50 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 51 | metavar='W', help='weight decay (default: 5e-4)') 52 | # ------------------------------ Misc ------------------------------------ 53 | parser.add_argument('--output_dir', type=Path) 54 | parser.add_argument('--disable_wandb_log', action='store_true') 55 | parser.add_argument('-j', '--workers', default=40, type=int, metavar='N', 56 | help='number of data loading workers (default: 4)') 57 | parser.add_argument('--print-freq', '-p', default=20, type=int, 58 | metavar='N', help='print frequency (default: 10)') 59 | 60 | args = parser.parse_args() 61 | 62 | best_prec1 = 0 63 | training_iterations = 0 64 | 65 | if not args.output_dir.exists(): 66 | args.output_dir.mkdir(parents=True) 67 | 68 | 69 | def main(): 70 | global args, best_prec1 71 | 72 | np.random.seed(0) 73 | torch.manual_seed(0) 74 | 75 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu", 0) 76 | model = MTCN_AV(_NUM_CLASSES[args.dataset], 77 | seq_len=args.seq_len, 78 | num_clips=1, 79 | visual_input_dim=args.visual_input_dim, 80 | audio_input_dim=args.audio_input_dim if args.dataset.split('-')[0] == 'epic' else None, 81 | d_model=args.d_model, 82 | dim_feedforward=args.dim_feedforward, 83 | nhead=args.nhead, 84 | num_layers=args.num_layers, 85 | dropout=args.dropout, 86 | classification_mode=args.classification_mode, 87 | audio=not args.dataset == 'egtea') 88 | model = model.to(device) 89 | 90 | if not args.disable_wandb_log: 91 | wandb.init(project='MTCN', config=args) 92 | wandb.watch(model) 93 | 94 | dataset = _DATASETS[args.dataset.split('-')[0]] 95 | train_loader = torch.utils.data.DataLoader( 96 | dataset(args.train_hdf5_path, 97 | args.train_pickle, 98 | visual_feature_dim=args.visual_input_dim, 99 | audio_feature_dim=args.audio_input_dim if args.dataset.split('-')[0] == 'epic' else None, 100 | window_len=args.seq_len, 101 | num_clips=10, 102 | clips_mode='random', 103 | labels_mode='all' if args.classification_mode == 'all' else 'center_action',), 104 | batch_size=args.batch_size, shuffle=True, 105 | num_workers=args.workers, pin_memory=True) 106 | 107 | val_loader = torch.utils.data.DataLoader( 108 | dataset(args.val_hdf5_path, 109 | args.val_pickle, 110 | visual_feature_dim=args.visual_input_dim, 111 | audio_feature_dim=args.audio_input_dim if args.dataset.split('-')[0] == 'epic' else None, 112 | window_len=args.seq_len, 113 | num_clips=10, 114 | clips_mode='random', 115 | labels_mode='all' if args.classification_mode == 'all' else 'center_action',), 116 | batch_size=args.batch_size, shuffle=False, 117 | num_workers=args.workers, pin_memory=True) 118 | 119 | if not args.classification_mode == 'all': 120 | criterion = torch.nn.CrossEntropyLoss() 121 | else: 122 | criterion = torch.nn.CrossEntropyLoss(reduction='none') 123 | 124 | if args.optimizer == 'sgd': 125 | optimizer = torch.optim.SGD(model.parameters(), 126 | args.lr, 127 | momentum=args.momentum, 128 | weight_decay=args.weight_decay, 129 | nesterov=True) 130 | else: 131 | optimizer = torch.optim.Adam(model.parameters(), 132 | lr=args.lr, 133 | weight_decay=args.weight_decay) 134 | scheduler = MultiStepLR(optimizer, args.lr_steps, gamma=0.1) 135 | 136 | for epoch in range(args.epochs): 137 | train(train_loader, model, criterion, optimizer, epoch, device) 138 | # evaluate on validation set 139 | prec1 = validate(val_loader, model, criterion, device) 140 | # remember best prec@1 and save checkpoint 141 | is_best = prec1 > best_prec1 142 | best_prec1 = max(prec1, best_prec1) 143 | save_checkpoint({ 144 | 'epoch': epoch + 1, 145 | 'state_dict': model.state_dict(), 146 | 'best_prec1': best_prec1, 147 | }, is_best, args.output_dir) 148 | scheduler.step() 149 | 150 | 151 | def train(train_loader, model, criterion, optimizer, epoch, device): 152 | global training_iterations 153 | is_multitask = isinstance(model.num_class, list) 154 | batch_time = AverageMeter() 155 | data_time = AverageMeter() 156 | losses = AverageMeter() 157 | if is_multitask: 158 | verb_losses = AverageMeter() 159 | noun_losses = AverageMeter() 160 | if args.classification_mode == 'all': 161 | if 'epic' in args.dataset: 162 | weights = torch.tensor(2 * args.seq_len * [0.1] + [0.9]).unsqueeze(0).cuda(device=0) 163 | else: 164 | weights = torch.tensor(args.seq_len * [0.1] + [0.9]).unsqueeze(0).cuda(device=0) 165 | else: 166 | weights = None 167 | 168 | # switch to train mode 169 | model.train() 170 | 171 | end = time.time() 172 | 173 | for i, (input, target, _) in enumerate(train_loader): 174 | # measure data loading time 175 | data_time.update(time.time() - end) 176 | 177 | input = input.to(device) 178 | input, target_a, target_b, lam = mixup_data(input, target, alpha=0.2) 179 | # compute output 180 | output = model(input) 181 | batch_size = input.size(0) 182 | if not is_multitask: 183 | target_a = target_a.to(device) 184 | target_b = target_b.to(device) 185 | loss = mixup_criterion(criterion, output, target_a, target_b, lam, weights=weights) 186 | else: 187 | target_a = {k: v.to(device) for k, v in target_a.items()} 188 | target_b = {k: v.to(device) for k, v in target_b.items()} 189 | loss_verb = mixup_criterion(criterion, output[0], target_a['verb'], target_b['verb'], lam, weights=weights) 190 | loss_noun = mixup_criterion(criterion, output[1], target_a['noun'], target_b['noun'], lam, weights=weights) 191 | loss = 0.5 * (loss_verb + loss_noun) 192 | verb_losses.update(loss_verb.item(), batch_size) 193 | noun_losses.update(loss_noun.item(), batch_size) 194 | losses.update(loss.item(), batch_size) 195 | # compute gradient and do SGD step 196 | optimizer.zero_grad() 197 | 198 | loss.backward() 199 | 200 | optimizer.step() 201 | 202 | training_iterations += 1 203 | 204 | # measure elapsed time 205 | batch_time.update(time.time() - end) 206 | end = time.time() 207 | 208 | if i % args.print_freq == 0: 209 | if not is_multitask: 210 | if not args.disable_wandb_log: 211 | wandb.log( 212 | { 213 | "Train/loss": losses.avg, 214 | "Train/epochs": epoch, 215 | "Train/lr": optimizer.param_groups[-1]['lr'], 216 | "train_step": training_iterations, 217 | }, 218 | ) 219 | 220 | message = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t' + 221 | 'Time {batch_time.avg:.3f} ({batch_time.avg:.3f})\t' + 222 | 'Data {data_time.avg:.3f} ({data_time.avg:.3f})\t' + 223 | 'Loss {loss.avg:.4f} ({loss.avg:.4f})\t' 224 | ).format( 225 | epoch, i, len(train_loader), batch_time=batch_time, 226 | data_time=data_time, loss=losses, 227 | lr=optimizer.param_groups[-1]['lr']) 228 | else: 229 | if not args.disable_wandb_log: 230 | wandb.log( 231 | { 232 | "Train/loss": losses.avg, 233 | "Train/epochs": epoch, 234 | "Train/lr": optimizer.param_groups[-1]['lr'], 235 | "Train/verb/loss": verb_losses.avg, 236 | "Train/noun/loss": noun_losses.avg, 237 | "train_step": training_iterations, 238 | }, 239 | ) 240 | message = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t' + 241 | 'Time {batch_time.avg:.3f} ({batch_time.avg:.3f})\t' + 242 | 'Data {data_time.avg:.3f} ({data_time.avg:.3f})\t' + 243 | 'Loss {loss.avg:.4f} ({loss.avg:.4f})\t' + 244 | 'Verb Loss {verb_loss.avg:.4f} ({verb_loss.avg:.4f})\t' + 245 | 'Noun Loss {noun_loss.avg:.4f} ({noun_loss.avg:.4f})\t' # + 246 | ).format( 247 | epoch, i, len(train_loader), batch_time=batch_time, 248 | data_time=data_time, loss=losses, verb_loss=verb_losses, 249 | noun_loss=noun_losses, 250 | lr=optimizer.param_groups[-1]['lr']) 251 | 252 | print(message) 253 | 254 | 255 | def validate(val_loader, model, criterion, device, name=''): 256 | global training_iterations 257 | is_multitask = isinstance(model.num_class, list) 258 | with torch.no_grad(): 259 | batch_time = AverageMeter() 260 | losses = AverageMeter() 261 | top1 = AverageMeter() 262 | top5 = AverageMeter() 263 | if is_multitask: 264 | verb_losses = AverageMeter() 265 | noun_losses = AverageMeter() 266 | verb_top1 = AverageMeter() 267 | verb_top5 = AverageMeter() 268 | noun_top1 = AverageMeter() 269 | noun_top5 = AverageMeter() 270 | if args.classification_mode == 'all': 271 | if 'epic' in args.dataset: 272 | weights = torch.tensor(2 * args.seq_len * [0.1] + [0.9]).unsqueeze(0).cuda(device=0) 273 | else: 274 | weights = torch.tensor(args.seq_len * [0.1] + [0.9]).unsqueeze(0).cuda(device=0) 275 | else: 276 | weights = None 277 | # switch to evaluate mode 278 | model.eval() 279 | 280 | end = time.time() 281 | for i, (input, target, _) in enumerate(val_loader): 282 | 283 | input = input.to(device) 284 | 285 | # compute output 286 | output = model(input) 287 | batch_size = input.size(0) 288 | if not is_multitask: 289 | target = target.to(device) 290 | loss = criterion(output, target) 291 | if weights is not None: 292 | loss = loss * weights 293 | loss = loss.sum(1) 294 | loss = loss.mean() 295 | output = output if len(output.shape) == 2 else output[:, :, -1] 296 | target = target if len(target.shape) == 1 else target[:, -1] 297 | # measure accuracy and record loss 298 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 299 | else: 300 | target = {k: v.to(device) for k, v in target.items()} 301 | loss_verb = criterion(output[0], target['verb']) 302 | if weights is not None: 303 | loss_verb = loss_verb * weights 304 | loss_verb = loss_verb.sum(1) 305 | loss_verb = loss_verb.mean() 306 | loss_noun = criterion(output[1], target['noun']) 307 | if weights is not None: 308 | loss_noun = loss_noun * weights 309 | loss_noun = loss_noun.sum(1) 310 | loss_noun = loss_noun.mean() 311 | loss = 0.5 * (loss_verb + loss_noun) 312 | verb_losses.update(loss_verb.item(), batch_size) 313 | noun_losses.update(loss_noun.item(), batch_size) 314 | 315 | verb_output = output[0] if len(output[0].shape) == 2 else output[0][:, :, -1] 316 | noun_output = output[1] if len(output[1].shape) == 2 else output[1][:, :, -1] 317 | verb_target = target['verb'] if len(target['verb'].shape) == 1 else target['verb'][:, -1] 318 | noun_target = target['noun'] if len(target['noun'].shape) == 1 else target['noun'][:, -1] 319 | verb_prec1, verb_prec5 = accuracy(verb_output, verb_target, topk=(1, 5)) 320 | verb_top1.update(verb_prec1, batch_size) 321 | verb_top5.update(verb_prec5, batch_size) 322 | 323 | noun_prec1, noun_prec5 = accuracy(noun_output, noun_target, topk=(1, 5)) 324 | noun_top1.update(noun_prec1, batch_size) 325 | noun_top5.update(noun_prec5, batch_size) 326 | 327 | prec1, prec5 = multitask_accuracy((verb_output, noun_output), 328 | (verb_target, noun_target), 329 | topk=(1, 5)) 330 | 331 | losses.update(loss.item(), batch_size) 332 | top1.update(prec1, batch_size) 333 | top5.update(prec5, batch_size) 334 | 335 | # measure elapsed time 336 | batch_time.update(time.time() - end) 337 | end = time.time() 338 | 339 | if not is_multitask: 340 | if not args.disable_wandb_log: 341 | wandb.log( 342 | { 343 | "Val/loss": losses.avg, 344 | "Val/Top1_acc": top1.avg, 345 | "Val/Top5_acc": top5.avg, 346 | "val_step": training_iterations, 347 | }, 348 | ) 349 | 350 | message = ('Testing Results: ' 351 | 'Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} ' 352 | 'Loss {loss.avg:.5f}').format(top1=top1, 353 | top5=top5, 354 | loss=losses) 355 | else: 356 | if not args.disable_wandb_log: 357 | wandb.log( 358 | { 359 | "Val/loss": losses.avg, 360 | "Val/Top1_acc": top1.avg, 361 | "Val/Top5_acc": top5.avg, 362 | "Val/verb/loss": verb_losses.avg, 363 | "Val/verb/Top1_acc": verb_top1.avg, 364 | "Val/verb/Top5_acc": verb_top5.avg, 365 | "Val/noun/loss": noun_losses.avg, 366 | "Val/noun/Top1_acc": noun_top1.avg, 367 | "Val/noun/Top5_acc": noun_top5.avg, 368 | "val_step": training_iterations, 369 | }, 370 | ) 371 | 372 | message = ("Testing Results: " 373 | "{name} Verb Prec@1 {verb_top1.avg:.3f} Verb Prec@5 {verb_top5.avg:.3f} " 374 | "{name} Noun Prec@1 {noun_top1.avg:.3f} Noun Prec@5 {noun_top5.avg:.3f} " 375 | "{name} Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} " 376 | "{name} Verb Loss {verb_loss.avg:.5f} " 377 | "{name} Noun Loss {noun_loss.avg:.5f} " 378 | "{name} Loss {loss.avg:.5f}").format(verb_top1=verb_top1, verb_top5=verb_top5, 379 | noun_top1=noun_top1, noun_top5=noun_top5, 380 | top1=top1, top5=top5, 381 | name=name, 382 | verb_loss=verb_losses, 383 | noun_loss=noun_losses, 384 | loss=losses) 385 | print(message) 386 | 387 | 388 | return top1.avg 389 | 390 | 391 | if __name__ == '__main__': 392 | main() 393 | -------------------------------------------------------------------------------- /transformers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional, Any 3 | 4 | import torch 5 | from torch import Tensor 6 | import torch.nn.functional as F 7 | from torch.nn import Module 8 | from torch.nn import MultiheadAttention 9 | from torch.nn import ModuleList 10 | from torch.nn.init import xavier_uniform_ 11 | from torch.nn import Dropout 12 | from torch.nn import Linear 13 | from torch.nn import LayerNorm 14 | 15 | 16 | class Transformer(Module): 17 | r"""A transformer model. User is able to modify the attributes as needed. The architecture 18 | is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, 19 | Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and 20 | Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information 21 | Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805) 22 | model with corresponding parameters. 23 | 24 | Args: 25 | d_model: the number of expected features in the encoder/decoder inputs (default=512). 26 | nhead: the number of heads in the multiheadattention models (default=8). 27 | num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6). 28 | num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6). 29 | dim_feedforward: the dimension of the feedforward network model (default=2048). 30 | dropout: the dropout value (default=0.1). 31 | activation: the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu). 32 | custom_encoder: custom encoder (default=None). 33 | custom_decoder: custom decoder (default=None). 34 | 35 | Examples:: 36 | >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12) 37 | >>> src = torch.rand((10, 32, 512)) 38 | >>> tgt = torch.rand((20, 32, 512)) 39 | >>> out = transformer_model(src, tgt) 40 | 41 | Note: A full example to apply nn.Transformer module for the word language model is available in 42 | https://github.com/pytorch/examples/tree/master/word_language_model 43 | """ 44 | 45 | def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6, 46 | num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1, 47 | activation: str = "relu", custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None) -> None: 48 | super(Transformer, self).__init__() 49 | 50 | if custom_encoder is not None: 51 | self.encoder = custom_encoder 52 | else: 53 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation) 54 | encoder_norm = LayerNorm(d_model) 55 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 56 | 57 | if custom_decoder is not None: 58 | self.decoder = custom_decoder 59 | else: 60 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation) 61 | decoder_norm = LayerNorm(d_model) 62 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm) 63 | 64 | self._reset_parameters() 65 | 66 | self.d_model = d_model 67 | self.nhead = nhead 68 | 69 | def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None, 70 | memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, 71 | tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: 72 | r"""Take in and process masked source/target sequences. 73 | 74 | Args: 75 | src: the sequence to the encoder (required). 76 | tgt: the sequence to the decoder (required). 77 | src_mask: the additive mask for the src sequence (optional). 78 | tgt_mask: the additive mask for the tgt sequence (optional). 79 | memory_mask: the additive mask for the encoder output (optional). 80 | src_key_padding_mask: the ByteTensor mask for src keys per batch (optional). 81 | tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional). 82 | memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional). 83 | 84 | Shape: 85 | - src: :math:`(S, N, E)`. 86 | - tgt: :math:`(T, N, E)`. 87 | - src_mask: :math:`(S, S)`. 88 | - tgt_mask: :math:`(T, T)`. 89 | - memory_mask: :math:`(T, S)`. 90 | - src_key_padding_mask: :math:`(N, S)`. 91 | - tgt_key_padding_mask: :math:`(N, T)`. 92 | - memory_key_padding_mask: :math:`(N, S)`. 93 | 94 | Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked 95 | positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend 96 | while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` 97 | are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor 98 | is provided, it will be added to the attention weight. 99 | [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by 100 | the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero 101 | positions will be unchanged. If a BoolTensor is provided, the positions with the 102 | value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. 103 | 104 | - output: :math:`(T, N, E)`. 105 | 106 | Note: Due to the multi-head attention architecture in the transformer model, 107 | the output sequence length of a transformer is same as the input sequence 108 | (i.e. target) length of the decode. 109 | 110 | where S is the source sequence length, T is the target sequence length, N is the 111 | batch size, E is the feature number 112 | 113 | Examples: 114 | >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask) 115 | """ 116 | 117 | if src.size(1) != tgt.size(1): 118 | raise RuntimeError("the batch number of src and tgt must be equal") 119 | 120 | if src.size(2) != self.d_model or tgt.size(2) != self.d_model: 121 | raise RuntimeError("the feature number of src and tgt must be equal to d_model") 122 | 123 | memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask) 124 | output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, 125 | tgt_key_padding_mask=tgt_key_padding_mask, 126 | memory_key_padding_mask=memory_key_padding_mask) 127 | return output 128 | 129 | def generate_square_subsequent_mask(self, sz: int) -> Tensor: 130 | r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). 131 | Unmasked positions are filled with float(0.0). 132 | """ 133 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 134 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 135 | return mask 136 | 137 | def _reset_parameters(self): 138 | r"""Initiate parameters in the transformer model.""" 139 | 140 | for p in self.parameters(): 141 | if p.dim() > 1: 142 | xavier_uniform_(p) 143 | 144 | 145 | class TransformerEncoder(Module): 146 | r"""TransformerEncoder is a stack of N encoder layers 147 | 148 | Args: 149 | encoder_layer: an instance of the TransformerEncoderLayer() class (required). 150 | num_layers: the number of sub-encoder-layers in the encoder (required). 151 | norm: the layer normalization component (optional). 152 | 153 | Examples:: 154 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) 155 | >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) 156 | >>> src = torch.rand(10, 32, 512) 157 | >>> out = transformer_encoder(src) 158 | """ 159 | __constants__ = ['norm'] 160 | 161 | def __init__(self, encoder_layer, num_layers, norm=None): 162 | super(TransformerEncoder, self).__init__() 163 | # self.layers = _get_clones(encoder_layer, num_layers) 164 | self.layer = encoder_layer 165 | self.num_layers = num_layers 166 | self.norm = norm 167 | 168 | def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: 169 | r"""Pass the input through the encoder layers in turn. 170 | 171 | Args: 172 | src: the sequence to the encoder (required). 173 | mask: the mask for the src sequence (optional). 174 | src_key_padding_mask: the mask for the src keys per batch (optional). 175 | 176 | Shape: 177 | see the docs in Transformer class. 178 | """ 179 | output = src 180 | 181 | # for mod in self.layers: 182 | for _ in range(self.num_layers): 183 | # output, attn_weights = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask) 184 | output, attn_weights = self.layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask) 185 | 186 | if self.norm is not None: 187 | output = self.norm(output) 188 | 189 | return output, attn_weights 190 | 191 | 192 | class TransformerDecoder(Module): 193 | r"""TransformerDecoder is a stack of N decoder layers 194 | 195 | Args: 196 | decoder_layer: an instance of the TransformerDecoderLayer() class (required). 197 | num_layers: the number of sub-decoder-layers in the decoder (required). 198 | norm: the layer normalization component (optional). 199 | 200 | Examples:: 201 | >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) 202 | >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) 203 | >>> memory = torch.rand(10, 32, 512) 204 | >>> tgt = torch.rand(20, 32, 512) 205 | >>> out = transformer_decoder(tgt, memory) 206 | """ 207 | __constants__ = ['norm'] 208 | 209 | def __init__(self, decoder_layer, num_layers, norm=None): 210 | super(TransformerDecoder, self).__init__() 211 | # self.layers = _get_clones(decoder_layer, num_layers) 212 | self.layer = decoder_layer 213 | self.num_layers = num_layers 214 | self.norm = norm 215 | 216 | def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, 217 | memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, 218 | memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: 219 | r"""Pass the inputs (and mask) through the decoder layer in turn. 220 | 221 | Args: 222 | tgt: the sequence to the decoder (required). 223 | memory: the sequence from the last layer of the encoder (required). 224 | tgt_mask: the mask for the tgt sequence (optional). 225 | memory_mask: the mask for the memory sequence (optional). 226 | tgt_key_padding_mask: the mask for the tgt keys per batch (optional). 227 | memory_key_padding_mask: the mask for the memory keys per batch (optional). 228 | 229 | Shape: 230 | see the docs in Transformer class. 231 | """ 232 | output = tgt 233 | 234 | # for mod in self.layers: 235 | for _ in range(self.num_layers): 236 | # output = mod(output, memory, tgt_mask=tgt_mask, 237 | # memory_mask=memory_mask, 238 | # tgt_key_padding_mask=tgt_key_padding_mask, 239 | # memory_key_padding_mask=memory_key_padding_mask) 240 | output = self.layer(output, memory, tgt_mask=tgt_mask, 241 | memory_mask=memory_mask, 242 | tgt_key_padding_mask=tgt_key_padding_mask, 243 | memory_key_padding_mask=memory_key_padding_mask) 244 | 245 | if self.norm is not None: 246 | output = self.norm(output) 247 | 248 | return output 249 | 250 | 251 | class TransformerEncoderLayer(Module): 252 | r"""TransformerEncoderLayer is made up of self-attn and feedforward network. 253 | This standard encoder layer is based on the paper "Attention Is All You Need". 254 | Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, 255 | Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in 256 | Neural Information Processing Systems, pages 6000-6010. Users may modify or implement 257 | in a different way during application. 258 | 259 | Args: 260 | d_model: the number of expected features in the input (required). 261 | nhead: the number of heads in the multiheadattention models (required). 262 | dim_feedforward: the dimension of the feedforward network model (default=2048). 263 | dropout: the dropout value (default=0.1). 264 | activation: the activation function of intermediate layer, relu or gelu (default=relu). 265 | 266 | Examples:: 267 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) 268 | >>> src = torch.rand(10, 32, 512) 269 | >>> out = encoder_layer(src) 270 | """ 271 | 272 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"): 273 | super(TransformerEncoderLayer, self).__init__() 274 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 275 | # Implementation of Feedforward model 276 | self.linear1 = Linear(d_model, dim_feedforward) 277 | self.dropout = Dropout(dropout) 278 | self.linear2 = Linear(dim_feedforward, d_model) 279 | 280 | self.norm1 = LayerNorm(d_model) 281 | self.norm2 = LayerNorm(d_model) 282 | self.dropout1 = Dropout(dropout) 283 | self.dropout2 = Dropout(dropout) 284 | 285 | self.activation = _get_activation_fn(activation) 286 | 287 | def __setstate__(self, state): 288 | if 'activation' not in state: 289 | state['activation'] = F.relu 290 | super(TransformerEncoderLayer, self).__setstate__(state) 291 | 292 | def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: 293 | r"""Pass the input through the encoder layer. 294 | 295 | Args: 296 | src: the sequence to the encoder layer (required). 297 | src_mask: the mask for the src sequence (optional). 298 | src_key_padding_mask: the mask for the src keys per batch (optional). 299 | 300 | Shape: 301 | see the docs in Transformer class. 302 | """ 303 | src2, attn_weights = self.self_attn(src, src, src, attn_mask=src_mask, 304 | key_padding_mask=src_key_padding_mask) 305 | src = src + self.dropout1(src2) 306 | src = self.norm1(src) 307 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 308 | src = src + self.dropout2(src2) 309 | src = self.norm2(src) 310 | return src, attn_weights 311 | 312 | 313 | class TransformerDecoderLayer(Module): 314 | r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. 315 | This standard decoder layer is based on the paper "Attention Is All You Need". 316 | Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, 317 | Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in 318 | Neural Information Processing Systems, pages 6000-6010. Users may modify or implement 319 | in a different way during application. 320 | 321 | Args: 322 | d_model: the number of expected features in the input (required). 323 | nhead: the number of heads in the multiheadattention models (required). 324 | dim_feedforward: the dimension of the feedforward network model (default=2048). 325 | dropout: the dropout value (default=0.1). 326 | activation: the activation function of intermediate layer, relu or gelu (default=relu). 327 | 328 | Examples:: 329 | >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) 330 | >>> memory = torch.rand(10, 32, 512) 331 | >>> tgt = torch.rand(20, 32, 512) 332 | >>> out = decoder_layer(tgt, memory) 333 | """ 334 | 335 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"): 336 | super(TransformerDecoderLayer, self).__init__() 337 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 338 | self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 339 | # Implementation of Feedforward model 340 | self.linear1 = Linear(d_model, dim_feedforward) 341 | self.dropout = Dropout(dropout) 342 | self.linear2 = Linear(dim_feedforward, d_model) 343 | 344 | self.norm1 = LayerNorm(d_model) 345 | self.norm2 = LayerNorm(d_model) 346 | self.norm3 = LayerNorm(d_model) 347 | self.dropout1 = Dropout(dropout) 348 | self.dropout2 = Dropout(dropout) 349 | self.dropout3 = Dropout(dropout) 350 | 351 | self.activation = _get_activation_fn(activation) 352 | 353 | def __setstate__(self, state): 354 | if 'activation' not in state: 355 | state['activation'] = F.relu 356 | super(TransformerDecoderLayer, self).__setstate__(state) 357 | 358 | def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, 359 | tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: 360 | r"""Pass the inputs (and mask) through the decoder layer. 361 | 362 | Args: 363 | tgt: the sequence to the decoder layer (required). 364 | memory: the sequence from the last layer of the encoder (required). 365 | tgt_mask: the mask for the tgt sequence (optional). 366 | memory_mask: the mask for the memory sequence (optional). 367 | tgt_key_padding_mask: the mask for the tgt keys per batch (optional). 368 | memory_key_padding_mask: the mask for the memory keys per batch (optional). 369 | 370 | Shape: 371 | see the docs in Transformer class. 372 | """ 373 | tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, 374 | key_padding_mask=tgt_key_padding_mask)[0] 375 | tgt = tgt + self.dropout1(tgt2) 376 | tgt = self.norm1(tgt) 377 | tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, 378 | key_padding_mask=memory_key_padding_mask)[0] 379 | tgt = tgt + self.dropout2(tgt2) 380 | tgt = self.norm2(tgt) 381 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 382 | tgt = tgt + self.dropout3(tgt2) 383 | tgt = self.norm3(tgt) 384 | return tgt 385 | 386 | 387 | def _get_clones(module, N): 388 | return ModuleList([copy.deepcopy(module) for i in range(N)]) 389 | 390 | 391 | def _get_activation_fn(activation): 392 | if activation == "relu": 393 | return F.relu 394 | elif activation == "gelu": 395 | return F.gelu 396 | 397 | raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) 398 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MTCN (c) by Evangelos Kazakos and Jaesung Huh 2 | 3 | MTCN is licensed under a 4 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | 6 | 7 | Attribution-NonCommercial-ShareAlike 4.0 International 8 | 9 | ======================================================================= 10 | 11 | Creative Commons Corporation ("Creative Commons") is not a law firm and 12 | does not provide legal services or legal advice. Distribution of 13 | Creative Commons public licenses does not create a lawyer-client or 14 | other relationship. Creative Commons makes its licenses and related 15 | information available on an "as-is" basis. Creative Commons gives no 16 | warranties regarding its licenses, any material licensed under their 17 | terms and conditions, or any related information. Creative Commons 18 | disclaims all liability for damages resulting from their use to the 19 | fullest extent possible. 20 | 21 | Using Creative Commons Public Licenses 22 | 23 | Creative Commons public licenses provide a standard set of terms and 24 | conditions that creators and other rights holders may use to share 25 | original works of authorship and other material subject to copyright 26 | and certain other rights specified in the public license below. The 27 | following considerations are for informational purposes only, are not 28 | exhaustive, and do not form part of our licenses. 29 | 30 | Considerations for licensors: Our public licenses are 31 | intended for use by those authorized to give the public 32 | permission to use material in ways otherwise restricted by 33 | copyright and certain other rights. Our licenses are 34 | irrevocable. Licensors should read and understand the terms 35 | and conditions of the license they choose before applying it. 36 | Licensors should also secure all rights necessary before 37 | applying our licenses so that the public can reuse the 38 | material as expected. Licensors should clearly mark any 39 | material not subject to the license. This includes other CC- 40 | licensed material, or material used under an exception or 41 | limitation to copyright. More considerations for licensors: 42 | wiki.creativecommons.org/Considerations_for_licensors 43 | 44 | Considerations for the public: By using one of our public 45 | licenses, a licensor grants the public permission to use the 46 | licensed material under specified terms and conditions. If 47 | the licensor's permission is not necessary for any reason--for 48 | example, because of any applicable exception or limitation to 49 | copyright--then that use is not regulated by the license. Our 50 | licenses grant only permissions under copyright and certain 51 | other rights that a licensor has authority to grant. Use of 52 | the licensed material may still be restricted for other 53 | reasons, including because others have copyright or other 54 | rights in the material. A licensor may make special requests, 55 | such as asking that all changes be marked or described. 56 | Although not required by our licenses, you are encouraged to 57 | respect those requests where reasonable. More considerations 58 | for the public: 59 | wiki.creativecommons.org/Considerations_for_licensees 60 | 61 | ======================================================================= 62 | 63 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 64 | Public License 65 | 66 | By exercising the Licensed Rights (defined below), You accept and agree 67 | to be bound by the terms and conditions of this Creative Commons 68 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 69 | ("Public License"). To the extent this Public License may be 70 | interpreted as a contract, You are granted the Licensed Rights in 71 | consideration of Your acceptance of these terms and conditions, and the 72 | Licensor grants You such rights in consideration of benefits the 73 | Licensor receives from making the Licensed Material available under 74 | these terms and conditions. 75 | 76 | 77 | Section 1 -- Definitions. 78 | 79 | a. Adapted Material means material subject to Copyright and Similar 80 | Rights that is derived from or based upon the Licensed Material 81 | and in which the Licensed Material is translated, altered, 82 | arranged, transformed, or otherwise modified in a manner requiring 83 | permission under the Copyright and Similar Rights held by the 84 | Licensor. For purposes of this Public License, where the Licensed 85 | Material is a musical work, performance, or sound recording, 86 | Adapted Material is always produced where the Licensed Material is 87 | synched in timed relation with a moving image. 88 | 89 | b. Adapter's License means the license You apply to Your Copyright 90 | and Similar Rights in Your contributions to Adapted Material in 91 | accordance with the terms and conditions of this Public License. 92 | 93 | c. BY-NC-SA Compatible License means a license listed at 94 | creativecommons.org/compatiblelicenses, approved by Creative 95 | Commons as essentially the equivalent of this Public License. 96 | 97 | d. Copyright and Similar Rights means copyright and/or similar rights 98 | closely related to copyright including, without limitation, 99 | performance, broadcast, sound recording, and Sui Generis Database 100 | Rights, without regard to how the rights are labeled or 101 | categorized. For purposes of this Public License, the rights 102 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 103 | Rights. 104 | 105 | e. Effective Technological Measures means those measures that, in the 106 | absence of proper authority, may not be circumvented under laws 107 | fulfilling obligations under Article 11 of the WIPO Copyright 108 | Treaty adopted on December 20, 1996, and/or similar international 109 | agreements. 110 | 111 | f. Exceptions and Limitations means fair use, fair dealing, and/or 112 | any other exception or limitation to Copyright and Similar Rights 113 | that applies to Your use of the Licensed Material. 114 | 115 | g. License Elements means the license attributes listed in the name 116 | of a Creative Commons Public License. The License Elements of this 117 | Public License are Attribution, NonCommercial, and ShareAlike. 118 | 119 | h. Licensed Material means the artistic or literary work, database, 120 | or other material to which the Licensor applied this Public 121 | License. 122 | 123 | i. Licensed Rights means the rights granted to You subject to the 124 | terms and conditions of this Public License, which are limited to 125 | all Copyright and Similar Rights that apply to Your use of the 126 | Licensed Material and that the Licensor has authority to license. 127 | 128 | j. Licensor means the individual(s) or entity(ies) granting rights 129 | under this Public License. 130 | 131 | k. NonCommercial means not primarily intended for or directed towards 132 | commercial advantage or monetary compensation. For purposes of 133 | this Public License, the exchange of the Licensed Material for 134 | other material subject to Copyright and Similar Rights by digital 135 | file-sharing or similar means is NonCommercial provided there is 136 | no payment of monetary compensation in connection with the 137 | exchange. 138 | 139 | l. Share means to provide material to the public by any means or 140 | process that requires permission under the Licensed Rights, such 141 | as reproduction, public display, public performance, distribution, 142 | dissemination, communication, or importation, and to make material 143 | available to the public including in ways that members of the 144 | public may access the material from a place and at a time 145 | individually chosen by them. 146 | 147 | m. Sui Generis Database Rights means rights other than copyright 148 | resulting from Directive 96/9/EC of the European Parliament and of 149 | the Council of 11 March 1996 on the legal protection of databases, 150 | as amended and/or succeeded, as well as other essentially 151 | equivalent rights anywhere in the world. 152 | 153 | n. You means the individual or entity exercising the Licensed Rights 154 | under this Public License. Your has a corresponding meaning. 155 | 156 | 157 | Section 2 -- Scope. 158 | 159 | a. License grant. 160 | 161 | 1. Subject to the terms and conditions of this Public License, 162 | the Licensor hereby grants You a worldwide, royalty-free, 163 | non-sublicensable, non-exclusive, irrevocable license to 164 | exercise the Licensed Rights in the Licensed Material to: 165 | 166 | a. reproduce and Share the Licensed Material, in whole or 167 | in part, for NonCommercial purposes only; and 168 | 169 | b. produce, reproduce, and Share Adapted Material for 170 | NonCommercial purposes only. 171 | 172 | 2. Exceptions and Limitations. For the avoidance of doubt, where 173 | Exceptions and Limitations apply to Your use, this Public 174 | License does not apply, and You do not need to comply with 175 | its terms and conditions. 176 | 177 | 3. Term. The term of this Public License is specified in Section 178 | 6(a). 179 | 180 | 4. Media and formats; technical modifications allowed. The 181 | Licensor authorizes You to exercise the Licensed Rights in 182 | all media and formats whether now known or hereafter created, 183 | and to make technical modifications necessary to do so. The 184 | Licensor waives and/or agrees not to assert any right or 185 | authority to forbid You from making technical modifications 186 | necessary to exercise the Licensed Rights, including 187 | technical modifications necessary to circumvent Effective 188 | Technological Measures. For purposes of this Public License, 189 | simply making modifications authorized by this Section 2(a) 190 | (4) never produces Adapted Material. 191 | 192 | 5. Downstream recipients. 193 | 194 | a. Offer from the Licensor -- Licensed Material. Every 195 | recipient of the Licensed Material automatically 196 | receives an offer from the Licensor to exercise the 197 | Licensed Rights under the terms and conditions of this 198 | Public License. 199 | 200 | b. Additional offer from the Licensor -- Adapted Material. 201 | Every recipient of Adapted Material from You 202 | automatically receives an offer from the Licensor to 203 | exercise the Licensed Rights in the Adapted Material 204 | under the conditions of the Adapter's License You apply. 205 | 206 | c. No downstream restrictions. You may not offer or impose 207 | any additional or different terms or conditions on, or 208 | apply any Effective Technological Measures to, the 209 | Licensed Material if doing so restricts exercise of the 210 | Licensed Rights by any recipient of the Licensed 211 | Material. 212 | 213 | 6. No endorsement. Nothing in this Public License constitutes or 214 | may be construed as permission to assert or imply that You 215 | are, or that Your use of the Licensed Material is, connected 216 | with, or sponsored, endorsed, or granted official status by, 217 | the Licensor or others designated to receive attribution as 218 | provided in Section 3(a)(1)(A)(i). 219 | 220 | b. Other rights. 221 | 222 | 1. Moral rights, such as the right of integrity, are not 223 | licensed under this Public License, nor are publicity, 224 | privacy, and/or other similar personality rights; however, to 225 | the extent possible, the Licensor waives and/or agrees not to 226 | assert any such rights held by the Licensor to the limited 227 | extent necessary to allow You to exercise the Licensed 228 | Rights, but not otherwise. 229 | 230 | 2. Patent and trademark rights are not licensed under this 231 | Public License. 232 | 233 | 3. To the extent possible, the Licensor waives any right to 234 | collect royalties from You for the exercise of the Licensed 235 | Rights, whether directly or through a collecting society 236 | under any voluntary or waivable statutory or compulsory 237 | licensing scheme. In all other cases the Licensor expressly 238 | reserves any right to collect such royalties, including when 239 | the Licensed Material is used other than for NonCommercial 240 | purposes. 241 | 242 | 243 | Section 3 -- License Conditions. 244 | 245 | Your exercise of the Licensed Rights is expressly made subject to the 246 | following conditions. 247 | 248 | a. Attribution. 249 | 250 | 1. If You Share the Licensed Material (including in modified 251 | form), You must: 252 | 253 | a. retain the following if it is supplied by the Licensor 254 | with the Licensed Material: 255 | 256 | i. identification of the creator(s) of the Licensed 257 | Material and any others designated to receive 258 | attribution, in any reasonable manner requested by 259 | the Licensor (including by pseudonym if 260 | designated); 261 | 262 | ii. a copyright notice; 263 | 264 | iii. a notice that refers to this Public License; 265 | 266 | iv. a notice that refers to the disclaimer of 267 | warranties; 268 | 269 | v. a URI or hyperlink to the Licensed Material to the 270 | extent reasonably practicable; 271 | 272 | b. indicate if You modified the Licensed Material and 273 | retain an indication of any previous modifications; and 274 | 275 | c. indicate the Licensed Material is licensed under this 276 | Public License, and include the text of, or the URI or 277 | hyperlink to, this Public License. 278 | 279 | 2. You may satisfy the conditions in Section 3(a)(1) in any 280 | reasonable manner based on the medium, means, and context in 281 | which You Share the Licensed Material. For example, it may be 282 | reasonable to satisfy the conditions by providing a URI or 283 | hyperlink to a resource that includes the required 284 | information. 285 | 3. If requested by the Licensor, You must remove any of the 286 | information required by Section 3(a)(1)(A) to the extent 287 | reasonably practicable. 288 | 289 | b. ShareAlike. 290 | 291 | In addition to the conditions in Section 3(a), if You Share 292 | Adapted Material You produce, the following conditions also apply. 293 | 294 | 1. The Adapter's License You apply must be a Creative Commons 295 | license with the same License Elements, this version or 296 | later, or a BY-NC-SA Compatible License. 297 | 298 | 2. You must include the text of, or the URI or hyperlink to, the 299 | Adapter's License You apply. You may satisfy this condition 300 | in any reasonable manner based on the medium, means, and 301 | context in which You Share Adapted Material. 302 | 303 | 3. You may not offer or impose any additional or different terms 304 | or conditions on, or apply any Effective Technological 305 | Measures to, Adapted Material that restrict exercise of the 306 | rights granted under the Adapter's License You apply. 307 | 308 | 309 | Section 4 -- Sui Generis Database Rights. 310 | 311 | Where the Licensed Rights include Sui Generis Database Rights that 312 | apply to Your use of the Licensed Material: 313 | 314 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 315 | to extract, reuse, reproduce, and Share all or a substantial 316 | portion of the contents of the database for NonCommercial purposes 317 | only; 318 | 319 | b. if You include all or a substantial portion of the database 320 | contents in a database in which You have Sui Generis Database 321 | Rights, then the database in which You have Sui Generis Database 322 | Rights (but not its individual contents) is Adapted Material, 323 | including for purposes of Section 3(b); and 324 | 325 | c. You must comply with the conditions in Section 3(a) if You Share 326 | all or a substantial portion of the contents of the database. 327 | 328 | For the avoidance of doubt, this Section 4 supplements and does not 329 | replace Your obligations under this Public License where the Licensed 330 | Rights include other Copyright and Similar Rights. 331 | 332 | 333 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 334 | 335 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 336 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 337 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 338 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 339 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 340 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 341 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 342 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 343 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 344 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 345 | 346 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 347 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 348 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 349 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 350 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 351 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 352 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 353 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 354 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 355 | 356 | c. The disclaimer of warranties and limitation of liability provided 357 | above shall be interpreted in a manner that, to the extent 358 | possible, most closely approximates an absolute disclaimer and 359 | waiver of all liability. 360 | 361 | 362 | Section 6 -- Term and Termination. 363 | 364 | a. This Public License applies for the term of the Copyright and 365 | Similar Rights licensed here. However, if You fail to comply with 366 | this Public License, then Your rights under this Public License 367 | terminate automatically. 368 | 369 | b. Where Your right to use the Licensed Material has terminated under 370 | Section 6(a), it reinstates: 371 | 372 | 1. automatically as of the date the violation is cured, provided 373 | it is cured within 30 days of Your discovery of the 374 | violation; or 375 | 376 | 2. upon express reinstatement by the Licensor. 377 | 378 | For the avoidance of doubt, this Section 6(b) does not affect any 379 | right the Licensor may have to seek remedies for Your violations 380 | of this Public License. 381 | 382 | c. For the avoidance of doubt, the Licensor may also offer the 383 | Licensed Material under separate terms or conditions or stop 384 | distributing the Licensed Material at any time; however, doing so 385 | will not terminate this Public License. 386 | 387 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 388 | License. 389 | 390 | 391 | Section 7 -- Other Terms and Conditions. 392 | 393 | a. The Licensor shall not be bound by any additional or different 394 | terms or conditions communicated by You unless expressly agreed. 395 | 396 | b. Any arrangements, understandings, or agreements regarding the 397 | Licensed Material not stated herein are separate from and 398 | independent of the terms and conditions of this Public License. 399 | 400 | 401 | Section 8 -- Interpretation. 402 | 403 | a. For the avoidance of doubt, this Public License does not, and 404 | shall not be interpreted to, reduce, limit, restrict, or impose 405 | conditions on any use of the Licensed Material that could lawfully 406 | be made without permission under this Public License. 407 | 408 | b. To the extent possible, if any provision of this Public License is 409 | deemed unenforceable, it shall be automatically reformed to the 410 | minimum extent necessary to make it enforceable. If the provision 411 | cannot be reformed, it shall be severed from this Public License 412 | without affecting the enforceability of the remaining terms and 413 | conditions. 414 | 415 | c. No term or condition of this Public License will be waived and no 416 | failure to comply consented to unless expressly agreed to by the 417 | Licensor. 418 | 419 | d. Nothing in this Public License constitutes or may be interpreted 420 | as a limitation upon, or waiver of, any privileges and immunities 421 | that apply to the Licensor or You, including from the legal 422 | processes of any jurisdiction or authority. 423 | 424 | ======================================================================= 425 | 426 | Creative Commons is not a party to its public 427 | licenses. Notwithstanding, Creative Commons may elect to apply one of 428 | its public licenses to material it publishes and in those instances 429 | will be considered the “Licensor.” The text of the Creative Commons 430 | public licenses is dedicated to the public domain under the CC0 Public 431 | Domain Dedication. Except for the limited purpose of indicating that 432 | material is shared under a Creative Commons public license or as 433 | otherwise permitted by the Creative Commons policies published at 434 | creativecommons.org/policies, Creative Commons does not authorize the 435 | use of the trademark "Creative Commons" or any other trademark or logo 436 | of Creative Commons without its prior written consent including, 437 | without limitation, in connection with any unauthorized modifications 438 | to any of its public licenses or any other arrangements, 439 | understandings, or agreements concerning use of licensed material. For 440 | the avoidance of doubt, this paragraph does not form part of the 441 | public licenses. 442 | 443 | Creative Commons may be contacted at creativecommons.org. 444 | --------------------------------------------------------------------------------