├── .gitignore
├── README.md
├── configs
│   ├── msrvtt_default.yml
│   └── msvd_default.yml
├── dataloader.py
├── eval.py
├── models
│   ├── SemanticLSTM.py
│   └── __init__.py
├── optim.py
├── opts.py
├── requirements.txt
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | Notes.md
2 | test.py
3 | test_mem.py
4 | data
5 | .idea
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Semantics-AssistedVideoCaptioning.Pytorch
2 | This is an unofficial PyTorch implementation of [Semantics-Assisted Video Captioning Model Trained with Scheduled Sampling Strategy](https://arxiv.org/abs/1909.00121).
3 | You can find the official TensorFlow implementation [here](https://github.com/WingsBrokenAngel/Semantics-AssistedVideoCaptioning).
4 | 
5 | ## Dependency
6 | * python 3.6
7 | * pytorch
8 | * pycocoevalcap
9 | ```shell script
10 | pip install -r requirements.txt
11 | ```
12 | ## Steps
13 | ### Data
14 | Please download the data from the official [repo](https://github.com/WingsBrokenAngel/Semantics-AssistedVideoCaptioning) and put
15 | it all in ./data
16 | ### Training
17 | To train on MSVD with the default parameters:
18 | ```shell script
19 | python train.py --cfg configs/msvd_default.yml --savedir saved_results --exp_name anonymous_run
20 | ```
21 | Remember to create the save directory with `mkdir` first. `--exp_name` is the name you give to this run.
22 | 
23 | To train on MSR-VTT, just change `--cfg` to `configs/msrvtt_default.yml`. Training
24 | takes about 90 minutes on MSVD and about 5 hours on MSR-VTT (on a GTX 1080 Ti).
25 | 
26 | For more details about the configs, see `opts.py` and the YAML files in ./configs
27 | ### Babysitting
28 | You can monitor the training process with TensorBoard.
29 | ```shell script
30 | tensorboard --logdir saved_results --port my_port --host 0.0.0.0
31 | ```
32 | ### Evaluation
33 | ```shell script
34 | python eval.py --savedir saved_results --exp_name anonymous_run --max_sent_len 20 --model_path path_of_model_to_eval
35 | ```
36 | If you don't specify `--model_path`, the best checkpoint of the run will be evaluated.
37 | ## Results
38 | The results below are not cherry-picked; I ran each dataset only once.
39 | My results are comparable to the officially reported numbers.
40 | ### MSVD
41 | |Model|BLEU-4|ROUGE-L|METEOR|CIDEr|
42 | |---|---|---|---|---|
43 | |official|61.8|76.8|37.8|103.0|
44 | |mine|61.2|76.6|38.5|106.5|
45 | 
46 | ### MSR-VTT
47 | |Model|BLEU-4|ROUGE-L|METEOR|CIDEr|
48 | |---|---|---|---|---|
49 | |official|43.8|62.4|28.9|51.4|
50 | |mine|44.4|62.7|28.8|50.7|
51 | 
52 | ## Differences
53 | ### Adam optimizer
54 | TensorFlow and PyTorch implement Adam slightly differently, so I also provide a
55 | TensorFlow-style Adam in optim.py. In practice the two perform comparably,
56 | so PyTorch's Adam is used by default. See the illustration just before the
57 | References section, and the links there, for more details.
58 | ### model choice
59 | The official implementation picks the best model by a weighted sum of all metrics.
60 | I simply pick the model with the best CIDEr on the validation set.
61 | ### dropout position
62 | The official implementation applies dropout after scheduled sampling; I apply it before.
63 | 
64 | ## TODO (or never-do)
65 | * beam search
66 | * reinforcement learning
67 | 
68 | ## Acknowledgement
69 | Thanks to the authors of the original TensorFlow implementation.
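## Adam difference, illustrated
The snippet below is only an illustration of the one-line difference mentioned in the Differences section; the tensors and hyperparameter values are toy placeholders. Only the two `denom` formulas mirror real code: `torch.optim.Adam` uses the first, and the `Adam` in `optim.py` reproduces the second (TensorFlow-style) one.
```python
import math
import torch

# Toy Adam state for a single 4-element parameter at step 1.
torch.manual_seed(0)
grad = torch.randn(4)
exp_avg = torch.zeros(4)
exp_avg_sq = torch.zeros(4)
lr, beta1, beta2, eps, step = 4e-4, 0.9, 0.999, 1e-8, 1

exp_avg = beta1 * exp_avg + (1 - beta1) * grad
exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad
bc1 = 1 - beta1 ** step
bc2 = 1 - beta2 ** step

# PyTorch Adam: eps is added *after* the bias-corrected square root.
denom_pytorch = exp_avg_sq.sqrt() / math.sqrt(bc2) + eps

# TensorFlow-style Adam (what optim.py reproduces): eps is added *before*
# dividing by sqrt(bc2), so the effective eps is larger during early steps.
denom_tf = (exp_avg_sq.sqrt() + eps) / math.sqrt(bc2)

step_size = lr / bc1
print('pytorch    update:', -step_size * exp_avg / denom_pytorch)
print('tensorflow update:', -step_size * exp_avg / denom_tf)
```
With the default `eps=1e-8` the two updates are nearly identical after the first few steps, which is consistent with the two optimizers training comparably here.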
70 | 71 | ## References 72 | * adam problem 73 | * https://discuss.pytorch.org/t/pytorch-adam-vs-tensorflow-adam/74471 74 | * https://stackoverflow.com/questions/57824804/epsilon-parameter-in-adam-opitmizer 75 | * https://github.com/tensorflow/tensorflow/issues/35102 76 | * [official implementation](https://github.com/WingsBrokenAngel/Semantics-AssistedVideoCaptioning) 77 | -------------------------------------------------------------------------------- /configs/msrvtt_default.yml: -------------------------------------------------------------------------------- 1 | # data 2 | dataset: msrvtt 3 | corpus: ./data/msrvtt_corpus.pkl 4 | reseco: ./data/msrvtt_resnext_eco.npy 5 | tag: ./data/msrvtt_e800_tag_feats.npy 6 | ref: ./data/msrvtt_ref.pkl 7 | val_start_idx: 6513 8 | val_end_idx: 7010 9 | test_start_idx: 7010 10 | test_end_idx: 10000 11 | num_workers: 0 12 | 13 | # model 14 | model: standard 15 | embedding_dim: 300 16 | hidden_dim: 512 17 | 18 | # learning 19 | lr: 0.0004 20 | lr_decay: 0.316 21 | lr_decay_every: 10 22 | batch_size: 64 23 | n_epoch: 50 24 | max_sent_len: 20 25 | 26 | # checkpoint 27 | exp_name: anonymous_run 28 | savedir: saved_results2 29 | 30 | # sample 31 | schedule_sample_method: multinomial 32 | schedule_sample_prob: 0 33 | schedule_sample_ratio: 0.008 34 | -------------------------------------------------------------------------------- /configs/msvd_default.yml: -------------------------------------------------------------------------------- 1 | # data 2 | dataset: msvd 3 | corpus: ./data/msvd_corpus.pkl 4 | reseco: ./data/msvd_resnext_eco.npy 5 | tag: ./data/msvd_semantic_tag_e1000.npy 6 | ref: ./data/msvd_ref.pkl 7 | val_start_idx: 1200 8 | val_end_idx: 1300 9 | test_start_idx: 1300 10 | test_end_idx: 1970 11 | num_workers: 0 12 | 13 | # model 14 | model: standard 15 | embedding_dim: 300 16 | hidden_dim: 512 17 | 18 | # learning 19 | lr: 0.0004 20 | lr_decay: 1 21 | lr_decay_every: 999999 22 | batch_size: 64 23 | n_epoch: 50 24 | max_sent_len: 20 25 | 26 | # checkpoint 27 | exp_name: anonymous_run 28 | savedir: saved_results 29 | 30 | # sample 31 | schedule_sample_method: multinomial 32 | schedule_sample_prob: 0 33 | schedule_sample_ratio: 0.008 34 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import os 3 | import numpy as np 4 | import pickle 5 | import torch 6 | 7 | 8 | def collate_fn(batch, split, padding_idx): 9 | if split == 'train': 10 | max_len = max([len(caption) for _, _, caption in batch]) 11 | batched_captions = np.ones((len(batch), max_len), dtype='int') * padding_idx 12 | for i, (_, _, caption) in enumerate(batch): 13 | batched_captions[i, :len(caption)] = caption 14 | batched_captions = torch.from_numpy(batched_captions).long() 15 | else: 16 | batched_captions = [caption for _, _, caption in batch] 17 | 18 | feats = [feat for feat, _, _ in batch] 19 | tags = [tag for _, tag, _ in batch] 20 | batched_feats = torch.from_numpy(np.stack(feats)) 21 | batched_tags = torch.from_numpy(np.stack(tags)) 22 | 23 | return batched_feats, batched_tags, batched_captions 24 | 25 | 26 | class DataManager: 27 | def __init__(self, args): 28 | self.features = np.load(args.reseco) 29 | self.tags = np.load(args.tag) 30 | self.corpus = pickle.load(open(args.corpus,'rb')) 31 | self.refs = pickle.load(open(args.ref,'rb')) 32 | self.max_sent_len = args.max_sent_len 33 | 34 | 
self.val_start_idx = args.val_start_idx 35 | self.val_end_idx = args.val_end_idx 36 | self.test_start_idx = args.test_start_idx 37 | self.test_end_idx = args.test_end_idx 38 | 39 | self.idx2word = self.corpus[4] 40 | self.idx2word[len(self.idx2word)] = '' 41 | self.idx2word[len(self.idx2word)] = '' 42 | self.word2idx = {value: key for key, value in self.idx2word.items()} 43 | 44 | def get_train(self): 45 | train_data = self.corpus[0] 46 | train_feats = self.features 47 | train_tags = self.tags 48 | return TrainData(train_data, train_feats, train_tags, max_sent_len=self.max_sent_len, word2idx=self.word2idx) 49 | 50 | def get_val(self): 51 | refs = self.refs[1] 52 | return ValData(self.val_start_idx, self.val_end_idx, self.features, self.tags, refs) 53 | 54 | def get_test(self): 55 | refs = self.refs[2] 56 | return ValData(self.test_start_idx, self.test_end_idx, self.features, self.tags, refs) 57 | 58 | def split(self): 59 | return self.get_train(), self.get_val(), self.get_test() 60 | 61 | def decode(self, logsoftmax): 62 | outputs = [] 63 | batch_size = logsoftmax.shape[0] 64 | argmax = logsoftmax.argmax(-1) 65 | for i in range(batch_size): 66 | seq = argmax[i] 67 | sentence = [] 68 | for j in range(seq.shape[0]): 69 | if seq[j].item() == self.word2idx[''] or seq[j].item() == self.word2idx['']: 70 | break 71 | sentence.append(self.idx2word[seq[j].item()]) 72 | sentence = ' '.join(sentence) 73 | outputs.append(sentence) 74 | return outputs 75 | 76 | 77 | class TrainData(Dataset): 78 | def __init__(self, data, feats, tags, max_sent_len, word2idx): 79 | self.data = data 80 | self.feats = feats 81 | self.tags = tags 82 | self.max_sent_len = max_sent_len 83 | self.word2idx = word2idx 84 | 85 | def __len__(self): 86 | return len(self.data[0]) 87 | 88 | def __getitem__(self, idx): 89 | caption = self.data[0][idx] 90 | video_index = self.data[1][idx] 91 | feat = self.feats[video_index] 92 | tags = self.tags[video_index] 93 | return feat, tags, caption 94 | 95 | 96 | class ValData(Dataset): 97 | def __init__(self, start_idx, end_idx, feats, tags, refs): 98 | self.start_idx = start_idx 99 | self.end_idx = end_idx 100 | self.feats = feats 101 | self.tags = tags 102 | self.refs = refs 103 | 104 | def __len__(self): 105 | return self.end_idx - self.start_idx 106 | 107 | def __getitem__(self, idx): 108 | feat = self.feats[idx + self.start_idx] 109 | tags = self.tags[idx + self.start_idx] 110 | caption = self.refs[idx] 111 | return feat, tags, caption # , caption 112 | 113 | 114 | if __name__ == '__main__': 115 | dm = DataManager('data') 116 | train, val, test = dm.split() 117 | # print(train[0]) 118 | # print(val[0]) 119 | # print(test[0]) 120 | train_loader = DataLoader(train, batch_size=10, shuffle=True, collate_fn=collate_fn) 121 | val_loader = DataLoader(val, batch_size=5, collate_fn=collate_fn) 122 | for item in train_loader: 123 | break 124 | 125 | for item in val_loader: 126 | break 127 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from tqdm import tqdm 4 | import torch.nn as nn 5 | from torch.optim import Adam 6 | # from optim import Adam 7 | from dataloader import DataManager, collate_fn 8 | from torch.utils.data import DataLoader 9 | from functools import partial 10 | from pycocoevalcap.bleu.bleu import Bleu 11 | from pycocoevalcap.rouge.rouge import Rouge 12 | from pycocoevalcap.cider.cider import Cider 13 | from 
pycocoevalcap.meteor.meteor import Meteor 14 | from tensorboardX import SummaryWriter 15 | import os 16 | from models import init_model 17 | from utils import NLLLossWithLength 18 | from torch.optim.lr_scheduler import ExponentialLR 19 | from opts import get_eval_args 20 | from train import evaluate 21 | 22 | if __name__ == '__main__': 23 | # get configs 24 | args = get_eval_args() 25 | print(args) 26 | 27 | # load data 28 | dm = DataManager(args) 29 | 30 | # prepare model 31 | model = init_model(args, dm) 32 | model = model.cuda() 33 | model.load_state_dict(torch.load(args.model_path)) 34 | 35 | # split data 36 | _, _, test_data = dm.split() 37 | test_loader = DataLoader(test_data, batch_size=args.batch_size, num_workers=args.num_workers, 38 | collate_fn=partial(collate_fn, split='test', padding_idx=dm.word2idx[''])) 39 | print('Start Video-Captioning Evaluation') 40 | 41 | test_score = evaluate(0, model, test_loader, dm, maxlen=args.max_sent_len, split='Test', verbose=True) 42 | print('Test score:') 43 | print(test_score) 44 | print('Done') 45 | -------------------------------------------------------------------------------- /models/SemanticLSTM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class SemanticLSTM(nn.Module): 8 | def __init__(self, vocab_size, embedding_dim, feat_size, tag_size, hidden_size, bos_idx=0, eos_idx=1, padding_idx=2, 9 | embedding_array=None, freeze_embedding=True, schedule_sample_prob=0, schedule_sample_method='greedy'): 10 | super(SemanticLSTM, self).__init__() 11 | self.feat_size = feat_size 12 | self.tag_size = tag_size 13 | self.vocab_size = vocab_size 14 | self.embedding_dim = embedding_dim 15 | self.embedding_dropout = nn.Dropout(p=0.5) 16 | self.input_size = feat_size + tag_size + embedding_dim 17 | self.hidden_size = hidden_size 18 | self.schedule_sample_prob = schedule_sample_prob 19 | self.schedule_sample_method = schedule_sample_method 20 | 21 | self.bos_idx = bos_idx 22 | self.eos_idx = eos_idx 23 | self.padding_idx = padding_idx 24 | if embedding_array is not None: 25 | embedding_array = torch.FloatTensor(embedding_array) 26 | embedding_array = torch.cat([embedding_array, torch.zeros((2, embedding_dim))], dim=0) # for bos and pad 27 | self.word_embed = nn.Embedding.from_pretrained(embedding_array, freeze=freeze_embedding) 28 | else: 29 | self.word_embed = nn.Embedding(vocab_size, embedding_dim) 30 | self.feat2input = nn.Linear(self.feat_size, self.embedding_dim, bias=False) 31 | 32 | self.feat_dropout = nn.Dropout(p=0.5) 33 | self.tag_dropout = nn.Dropout(p=0.5) 34 | 35 | self.tag2lstm1 = nn.Linear(self.tag_size, 4 * self.hidden_size, bias=False) 36 | self.feat2lstm = nn.Linear(self.feat_size, 4 * self.hidden_size, bias=False) 37 | self.tag2lstm2 = nn.Linear(self.tag_size, 4 * self.hidden_size, bias=False) 38 | 39 | self.word2lstm = nn.Linear(self.embedding_dim, 4 * self.hidden_size, bias=False) 40 | 41 | self.fc_i = nn.Linear(2 * self.hidden_size, self.hidden_size) 42 | self.fc_f = nn.Linear(2 * self.hidden_size, self.hidden_size) 43 | self.fc_o = nn.Linear(2 * self.hidden_size, self.hidden_size) 44 | self.fc_c = nn.Linear(2 * self.hidden_size, self.hidden_size) 45 | 46 | self.fc_hidden_state_i = nn.Linear(self.hidden_size, self.hidden_size, bias=False) 47 | self.fc_hidden_state_f = nn.Linear(self.hidden_size, self.hidden_size, bias=False) 48 | self.fc_hidden_state_o = nn.Linear(self.hidden_size, 
self.hidden_size, bias=False) 49 | self.fc_hidden_state_c = nn.Linear(self.hidden_size, self.hidden_size, bias=False) 50 | self.fc_tag_i = nn.Linear(self.tag_size, self.hidden_size, bias=False) 51 | self.fc_tag_f = nn.Linear(self.tag_size, self.hidden_size, bias=False) 52 | self.fc_tag_o = nn.Linear(self.tag_size, self.hidden_size, bias=False) 53 | self.fc_tag_c = nn.Linear(self.tag_size, self.hidden_size, bias=False) 54 | self.fc_both_i = nn.Linear(self.hidden_size, self.hidden_size, bias=False) 55 | self.fc_both_f = nn.Linear(self.hidden_size, self.hidden_size, bias=False) 56 | self.fc_both_o = nn.Linear(self.hidden_size, self.hidden_size, bias=False) 57 | self.fc_both_c = nn.Linear(self.hidden_size, self.hidden_size, bias=False) 58 | 59 | self.h_dropout = nn.Dropout(0.5) 60 | self.word2logit = nn.Linear(self.embedding_dim, self.hidden_size) 61 | self.logit_bias = nn.Parameter(torch.zeros((vocab_size - 2,)), requires_grad=True) 62 | self._init() 63 | 64 | def _init(self): 65 | for m in self.modules(): 66 | if isinstance(m, nn.Linear): 67 | nn.init.xavier_normal_(m.weight) 68 | if m.bias is not None: 69 | nn.init.constant_(m.bias, 0) 70 | 71 | def init_state(self, x): 72 | bs = x.shape[0] 73 | return (torch.zeros((bs, self.hidden_size), device=x.device), 74 | torch.zeros((bs, self.hidden_size), device=x.device)) 75 | 76 | def core(self, word_embedding, feats, tags, tmps, state): 77 | tmp2_i, tmp2_f, tmp2_o, tmp2_c = torch.split(tmps['tmp2'], self.hidden_size, dim=-1) 78 | tmp3_i, tmp3_f, tmp3_o, tmp3_c = torch.split(tmps['tmp3'], self.hidden_size, dim=-1) 79 | tmp4_i, tmp4_f, tmp4_o, tmp4_c = torch.split(tmps['tmp4'], self.hidden_size, dim=-1) 80 | 81 | tmp1_i, tmp1_f, tmp1_o, tmp1_c = torch.split(self.word2lstm(word_embedding), self.hidden_size, dim=-1) 82 | 83 | tmp_i = torch.cat([tmp1_i * tmp2_i, tmp3_i * tmp4_i], dim=-1) 84 | tmp_f = torch.cat([tmp1_f * tmp2_f, tmp3_f * tmp4_f], dim=-1) 85 | tmp_o = torch.cat([tmp1_o * tmp2_o, tmp3_o * tmp4_o], dim=-1) 86 | tmp_c = torch.cat([tmp1_c * tmp2_c, tmp3_c * tmp4_c], dim=-1) 87 | input_i = self.fc_i(tmp_i) 88 | input_f = self.fc_f(tmp_f) 89 | input_o = self.fc_o(tmp_o) 90 | input_c = self.fc_c(tmp_c) 91 | 92 | preact_i = self.fc_both_i(self.fc_hidden_state_i(state[0]) * self.fc_tag_i(tags)) + input_i 93 | preact_f = self.fc_both_f(self.fc_hidden_state_f(state[0]) * self.fc_tag_f(tags)) + input_f 94 | preact_o = self.fc_both_o(self.fc_hidden_state_o(state[0]) * self.fc_tag_o(tags)) + input_o 95 | preact_c = self.fc_both_c(self.fc_hidden_state_c(state[0]) * self.fc_tag_c(tags)) + input_c 96 | 97 | i = torch.sigmoid(preact_i) 98 | f = torch.sigmoid(preact_f) 99 | o = torch.sigmoid(preact_o) 100 | c = torch.tanh(preact_c) 101 | 102 | c = f * state[1] + i * c 103 | h = o * torch.tanh(c) 104 | return (h, c) 105 | 106 | def prepare_feats(self, feats, tags, seq): 107 | feats = self.feat_dropout(feats) 108 | tags = self.tag_dropout(tags) 109 | 110 | tmps = {} 111 | tmps['tmp2'] = self.tag2lstm1(tags) 112 | tmps['tmp3'] = self.feat2lstm(feats) 113 | tmps['tmp4'] = self.tag2lstm2(tags) 114 | return feats, tags, seq, tmps 115 | 116 | def forward(self, feats, tags, seq): 117 | state = self.init_state(feats) 118 | bs = feats.shape[0] 119 | outputs = [] 120 | 121 | feats, tags, seq, tmps = self.prepare_feats(feats, tags, seq) 122 | # actually we do not use bos token, we use visual feats instead 123 | # this is just for put all code into a single loop 124 | bos = torch.ones(bs, device=feats.device).long() * self.bos_idx 125 | seq = 
torch.cat([bos.unsqueeze(1), seq], dim=1)
126 |         # use visual feats at first step
127 |         for i in range(seq.shape[1]):
128 |             rand = np.random.uniform(0, 1, (bs,))
129 |             if (seq[:, i] == self.eos_idx).sum() + (seq[:, i] == self.padding_idx).sum() == bs:
130 |                 break
131 |             if i == 0:  # start token
132 |                 word_embedding = self.feat2input(feats)
133 |             elif self.schedule_sample_prob != 0 and (
134 |                     rand < self.schedule_sample_prob).any():  # scheduled sampling
135 |                 xt = seq[:, i].data.clone()
136 |                 index = rand < self.schedule_sample_prob
137 |                 last_output = outputs[-1].detach()
138 |                 if self.schedule_sample_method == 'greedy':
139 |                     words = last_output.argmax(-1)
140 |                 elif self.schedule_sample_method == 'multinomial':
141 |                     distribution = torch.exp(last_output)
142 |                     words = torch.multinomial(distribution, 1).squeeze(-1)
143 |                 else:
144 |                     raise NotImplementedError
145 |                 xt[index] = words[index]
146 |                 word_embedding = self.embedding_dropout(self.word_embed(xt))
147 |             else:  # teacher forcing
148 |                 word_embedding = self.embedding_dropout(self.word_embed(seq[:, i]))
149 |             state = self.core(word_embedding, feats, tags, tmps, state)
150 |             logit_weight = torch.matmul(self.word2logit.weight, self.word_embed.weight[:-2, :].T)
151 |             logit = F.linear(self.h_dropout(state[0]), logit_weight.T, self.logit_bias)
152 |             outputs.append(logit)
153 |         res = torch.stack(outputs, dim=1)
154 |         res = F.log_softmax(res, dim=-1)
155 |         return res
156 | 
157 |     def sample(self, feats, tags, maxlen, mode='greedy'):
158 |         state = self.init_state(feats)
159 |         outputs = []
160 |         bs = feats.shape[0]
161 | 
162 |         feats, tags, seq, tmps = self.prepare_feats(feats, tags, None)
163 |         # bos = torch.ones(bs, device=feats.device).long() * self.bos_idx
164 |         is_finished = torch.zeros(bs)
165 |         for i in range(maxlen + 1):
166 |             if i == 0:  # use visual feats at first step
167 |                 word_embedding = self.feat2input(feats)
168 |             elif mode == 'greedy':
169 |                 last_output = outputs[-1]
170 |                 last_token = last_output.argmax(-1)
171 |                 is_finished[last_token == self.eos_idx] = 1
172 |                 if is_finished.sum() == bs:  # all samples finished
173 |                     break
174 |                 word_embedding = self.embedding_dropout(self.word_embed(last_token))
175 |             else:
176 |                 raise NotImplementedError
177 | 
178 |             state = self.core(word_embedding, feats, tags, tmps, state)
179 |             logit_weight = torch.matmul(self.word2logit.weight, self.word_embed.weight[:-2, :].T)
180 |             logit = F.linear(self.h_dropout(state[0]), logit_weight.T, self.logit_bias)
181 |             outputs.append(logit)
182 |         res = torch.stack(outputs, dim=1)
183 |         res = F.log_softmax(res, dim=-1)
184 |         return res
185 | 
186 | 
187 | if __name__ == '__main__':
188 |     model = SemanticLSTM(10, 20, 120, 100, 64)  # vocab size, embedding dim, feat size, tag size, hidden size
189 |     feats = torch.randn(16, 120)
190 |     tags = torch.randn(16, 100)
191 |     seq = torch.randint(0, 10, (16, 10))
192 |     out = model(feats, tags, seq)
193 |     print(out.shape)
194 |     sample_out = model.sample(feats, tags, 20)
195 |     print(sample_out.shape)
196 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .SemanticLSTM import SemanticLSTM
2 | 
3 | 
4 | def init_model(args, dm):
5 |     if args.model == 'standard':
6 |         model = SemanticLSTM(len(dm.idx2word), args.embedding_dim, 1536 + 2048, 300,
7 |                              args.hidden_dim,
8 |                              bos_idx=dm.word2idx[''],
9 |                              eos_idx=dm.word2idx[''],
10 |                              padding_idx=dm.word2idx[''],
11 |                              embedding_array=dm.corpus[5],
12 |                              schedule_sample_prob=args.schedule_sample_prob,
13 | 
schedule_sample_method=args.schedule_sample_method) 14 | else: 15 | raise NotImplementedError 16 | return model 17 | -------------------------------------------------------------------------------- /optim.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | 5 | 6 | class Adam(Optimizer): 7 | r"""Implements Adam algorithm. 8 | 9 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 10 | 11 | Arguments: 12 | params (iterable): iterable of parameters to optimize or dicts defining 13 | parameter groups 14 | lr (float, optional): learning rate (default: 1e-3) 15 | betas (Tuple[float, float], optional): coefficients used for computing 16 | running averages of gradient and its square (default: (0.9, 0.999)) 17 | eps (float, optional): term added to the denominator to improve 18 | numerical stability (default: 1e-8) 19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 20 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 21 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 22 | (default: False) 23 | 24 | .. _Adam\: A Method for Stochastic Optimization: 25 | https://arxiv.org/abs/1412.6980 26 | .. _On the Convergence of Adam and Beyond: 27 | https://openreview.net/forum?id=ryQu7f-RZ 28 | """ 29 | 30 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 31 | weight_decay=0, amsgrad=False): 32 | if not 0.0 <= lr: 33 | raise ValueError("Invalid learning rate: {}".format(lr)) 34 | if not 0.0 <= eps: 35 | raise ValueError("Invalid epsilon value: {}".format(eps)) 36 | if not 0.0 <= betas[0] < 1.0: 37 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 38 | if not 0.0 <= betas[1] < 1.0: 39 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 40 | if not 0.0 <= weight_decay: 41 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 42 | defaults = dict(lr=lr, betas=betas, eps=eps, 43 | weight_decay=weight_decay, amsgrad=amsgrad) 44 | super(Adam, self).__init__(params, defaults) 45 | 46 | def __setstate__(self, state): 47 | super(Adam, self).__setstate__(state) 48 | for group in self.param_groups: 49 | group.setdefault('amsgrad', False) 50 | 51 | @torch.no_grad() 52 | def step(self, closure=None): 53 | """Performs a single optimization step. 54 | 55 | Arguments: 56 | closure (callable, optional): A closure that reevaluates the model 57 | and returns the loss. 58 | """ 59 | loss = None 60 | if closure is not None: 61 | with torch.enable_grad(): 62 | loss = closure() 63 | 64 | for group in self.param_groups: 65 | for p in group['params']: 66 | if p.grad is None: 67 | continue 68 | grad = p.grad 69 | if grad.is_sparse: 70 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 71 | amsgrad = group['amsgrad'] 72 | 73 | state = self.state[p] 74 | 75 | # State initialization 76 | if len(state) == 0: 77 | state['step'] = 0 78 | # Exponential moving average of gradient values 79 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 80 | # Exponential moving average of squared gradient values 81 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 82 | if amsgrad: 83 | # Maintains max of all exp. moving avg. of sq. grad. 
values 84 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 85 | 86 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 87 | if amsgrad: 88 | max_exp_avg_sq = state['max_exp_avg_sq'] 89 | beta1, beta2 = group['betas'] 90 | 91 | state['step'] += 1 92 | bias_correction1 = 1 - beta1 ** state['step'] 93 | bias_correction2 = 1 - beta2 ** state['step'] 94 | 95 | if group['weight_decay'] != 0: 96 | grad = grad.add(p, alpha=group['weight_decay']) 97 | 98 | # Decay the first and second moment running average coefficient 99 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 100 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 101 | if amsgrad: 102 | # Maintains the maximum of all 2nd moment running avg. till now 103 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 104 | # Use the max. for normalizing running avg. of gradient 105 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 106 | else: 107 | denom = exp_avg_sq.sqrt().add_(group['eps']) / math.sqrt(bias_correction2) 108 | 109 | step_size = group['lr'] / bias_correction1 110 | 111 | p.addcdiv_(exp_avg, denom, value=-step_size) 112 | 113 | return loss 114 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import os 4 | 5 | data_keys = ['dataset', 'corpus', 'reseco', 'tag', 'ref', 'val_start_idx', 'val_end_idx', 6 | 'test_start_idx', 'test_end_idx', 'num_workers'] 7 | 8 | model_keys = ['model', 'embedding_dim', 'hidden_dim'] 9 | 10 | checkpoint_keys = ['exp_name', 'savedir'] 11 | 12 | 13 | def get_args(): 14 | parser = argparse.ArgumentParser(description='Video Captioning Arguments', argument_default=argparse.SUPPRESS) 15 | 16 | # data 17 | parser.add_argument('--dataset', type=str) 18 | parser.add_argument('--corpus', type=str) 19 | parser.add_argument('--reseco', type=str) 20 | parser.add_argument('--tag', type=str) 21 | parser.add_argument('--ref', type=str) 22 | parser.add_argument('--num_workers', type=str) 23 | 24 | # model 25 | parser.add_argument('--model', type=str) 26 | parser.add_argument('--embedding_dim', type=int) 27 | parser.add_argument('--hidden_dim', type=int) 28 | 29 | # learning 30 | parser.add_argument('--lr', type=float) 31 | parser.add_argument('--lr_decay', type=float) 32 | parser.add_argument('--lr_decay_every', type=int, help='epoch') 33 | parser.add_argument('--n_epoch', type=int) 34 | parser.add_argument('--batch_size', type=int) 35 | parser.add_argument('--max_sent_len', type=int) 36 | 37 | # checkpoints 38 | parser.add_argument('--exp_name', type=str) 39 | parser.add_argument('--savedir', type=str) 40 | 41 | # sample 42 | parser.add_argument('--schedule_sample_method', type=str) 43 | parser.add_argument('--schedule_sample_prob', type=float) 44 | parser.add_argument('--schedule_sample_ratio', type=float) 45 | 46 | # config 47 | parser.add_argument('--cfg', type=str, default=None) 48 | 49 | args = parser.parse_args() 50 | 51 | # load config in yaml 52 | if args.cfg is not None: 53 | args = load_args(args, args.cfg) 54 | 55 | # check args 56 | assert args.schedule_sample_method in ['greedy', 'multinomial'] 57 | 58 | return args 59 | 60 | 61 | def get_eval_args(): 62 | parser = argparse.ArgumentParser(description='Video Captioning Arguments', argument_default=argparse.SUPPRESS) 63 | 64 | # data 65 | parser.add_argument('--dataset', type=str) 66 | 
parser.add_argument('--corpus', type=str) 67 | parser.add_argument('--reseco', type=str) 68 | parser.add_argument('--tag', type=str) 69 | parser.add_argument('--ref', type=str) 70 | parser.add_argument('--num_workers', type=str) 71 | 72 | # model 73 | parser.add_argument('--model', type=str) 74 | parser.add_argument('--embedding_dim', type=int) 75 | parser.add_argument('--hidden_dim', type=int) 76 | 77 | # learning 78 | parser.add_argument('--lr', type=float) 79 | parser.add_argument('--lr_decay', type=float,help='lr decay ratio') 80 | parser.add_argument('--lr_decay_every', type=int, help='epoch') 81 | parser.add_argument('--n_epoch', type=int) 82 | parser.add_argument('--batch_size', type=int) 83 | parser.add_argument('--max_sent_len', type=int) 84 | 85 | # checkpoints 86 | parser.add_argument('--exp_name', type=str) 87 | parser.add_argument('--savedir', type=str) 88 | 89 | # sample 90 | parser.add_argument('--schedule_sample_method', type=str) 91 | parser.add_argument('--schedule_sample_prob', type=float) 92 | parser.add_argument('--schedule_sample_ratio', type=float) 93 | 94 | args = parser.parse_args() 95 | 96 | # load config in yaml 97 | args = load_args(args, os.path.join(args.savedir, args.exp_name, 'configs.yaml')) 98 | 99 | if not hasattr(args, 'model_path'): 100 | setattr(args, 'model_path', os.path.join(args.savedir, args.exp_name, 'best.pth')) 101 | 102 | return args 103 | 104 | 105 | def save_args(args, path): 106 | d = vars(args) 107 | with open(os.path.join(path, 'configs.yaml'), 'w') as f: 108 | yaml.dump(d, f) 109 | 110 | 111 | def load_args(args, path): 112 | configs = yaml.safe_load(open(path, 'r')) 113 | # needed_args = data_keys + model_keys + checkpoint_keys 114 | for k, v in configs.items(): 115 | if not hasattr(args, k): 116 | setattr(args, k, v) 117 | return args 118 | 119 | 120 | if __name__ == '__main__': 121 | args = get_args() 122 | print('configs') 123 | print(args) 124 | save_args(args, './') 125 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | astor==0.8.1 3 | certifi==2016.2.28 4 | future==0.18.2 5 | gast==0.3.3 6 | google-pasta==0.2.0 7 | grpcio==1.29.0 8 | h5py==2.10.0 9 | importlib-metadata==1.6.0 10 | joblib==0.15.1 11 | Keras-Applications==1.0.8 12 | Keras-Preprocessing==1.1.2 13 | Markdown==3.2.2 14 | mock==4.0.2 15 | numpy==1.18.4 16 | protobuf==3.12.1 17 | PyYAML==5.4 18 | scikit-learn==0.23.1 19 | scipy==1.4.1 20 | six==1.15.0 21 | sklearn==0.0 22 | tensorboard==1.12.2 23 | tensorboardX==2.0 24 | termcolor==1.1.0 25 | threadpoolctl==2.0.0 26 | tqdm==4.46.0 27 | Werkzeug==1.0.1 28 | wrapt==1.12.1 29 | zipp==3.1.0 30 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from tqdm import tqdm 4 | import torch.nn as nn 5 | from torch.optim import Adam 6 | # from optim import Adam 7 | from dataloader import DataManager, collate_fn 8 | from torch.utils.data import DataLoader 9 | from functools import partial 10 | from pycocoevalcap.bleu.bleu import Bleu 11 | from pycocoevalcap.rouge.rouge import Rouge 12 | from pycocoevalcap.cider.cider import Cider 13 | from pycocoevalcap.meteor.meteor import Meteor 14 | from tensorboardX import SummaryWriter 15 | import os 16 | from models import init_model 17 | from utils import NLLLossWithLength 18 | 
from torch.optim.lr_scheduler import ExponentialLR 19 | from opts import get_args,save_args 20 | 21 | def train_epoch(epoch, model, dataloader, optimizer, loss_fn): 22 | model.train() 23 | running_loss = 0.0 24 | with tqdm(desc='Epoch {} - train'.format(epoch), unit='it', total=len(dataloader)) as pbar: 25 | for it, (feats, tags, captions) in enumerate(dataloader): 26 | feats = feats.cuda() 27 | tags = tags.cuda() 28 | captions = captions.cuda() 29 | 30 | optimizer.zero_grad() 31 | out = model(feats, tags, captions) 32 | loss = loss_fn(out, captions) 33 | # loss = loss_fn(out.view(-1, out.shape[-1]), captions.view(-1)) 34 | 35 | loss.backward() 36 | optimizer.step() 37 | 38 | running_loss += loss.item() 39 | pbar.set_postfix(loss=running_loss / (it + 1)) 40 | pbar.update() 41 | loss = running_loss / len(dataloader) 42 | return loss 43 | 44 | 45 | def evaluate(epoch, model, dataloader, dm, maxlen, split, verbose=False): 46 | model.eval() 47 | gen = {} 48 | gts = {} 49 | with tqdm(desc='Epoch {} - {}'.format(epoch, split), unit='it', total=len(dataloader)) as pbar: 50 | for it, (feats, tags, captions) in enumerate(dataloader): 51 | feats = feats.cuda() 52 | tags = tags.cuda() 53 | with torch.no_grad(): 54 | out = model.sample(feats, tags, maxlen=maxlen) 55 | decoded_sentences = dm.decode(out) 56 | for i in range(len(decoded_sentences)): 57 | gen['{}_{}'.format(it, i)] = [decoded_sentences[i]] 58 | gts['{}_{}'.format(it, i)] = captions[i] 59 | pbar.update() 60 | 61 | if verbose: 62 | for key in gts.keys(): 63 | print('gen:', gen[key][0]) 64 | 65 | return score(gts, gen) 66 | 67 | 68 | def score(ref, hypo): 69 | """ 70 | ref, dictionary of reference sentences (id, sentence) 71 | hypo, dictionary of hypothesis sentences (id, sentence) 72 | score, dictionary of scores 73 | """ 74 | # print('ref') 75 | # print(ref) 76 | # print('hypo') 77 | # print(hypo) 78 | scorers = [ 79 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 80 | (Meteor(), "METEOR"), 81 | (Rouge(), "ROUGE_L"), 82 | (Cider(), "CIDEr") 83 | ] 84 | final_scores = {} 85 | for scorer, method in scorers: 86 | score, scores = scorer.compute_score(ref, hypo) 87 | if type(score) == list: 88 | for m, s in zip(method, score): 89 | final_scores[m] = s 90 | else: 91 | final_scores[method] = score 92 | return final_scores 93 | 94 | 95 | if __name__ == '__main__': 96 | # get configs 97 | args = get_args() 98 | print(args) 99 | 100 | writer = SummaryWriter(log_dir=os.path.join(args.savedir, args.exp_name)) 101 | save_args(args,os.path.join(args.savedir,args.exp_name)) 102 | 103 | # load data 104 | dm = DataManager(args) 105 | 106 | # prepare model 107 | model = init_model(args, dm) 108 | model = model.cuda() 109 | 110 | # prepare training 111 | optimizer = Adam(lr=args.lr, params=model.parameters()) 112 | schedule = ExponentialLR(optimizer, args.lr_decay) 113 | loss_fn = NLLLossWithLength(ignore_index=dm.word2idx['']) 114 | # loss_fn = nn.NLLLoss(ignore_index=dm.word2idx['']) 115 | 116 | # split data 117 | train_data, val_data, test_data = dm.split() 118 | train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, 119 | collate_fn=partial(collate_fn, split='train', padding_idx=dm.word2idx[''])) 120 | val_loader = DataLoader(val_data, batch_size=args.batch_size, num_workers=args.num_workers, 121 | collate_fn=partial(collate_fn, split='val', padding_idx=dm.word2idx[''])) 122 | test_loader = DataLoader(test_data, batch_size=args.batch_size, num_workers=args.num_workers, 123 | 
collate_fn=partial(collate_fn, split='test', padding_idx=dm.word2idx[''])) 124 | print('Start Video-Captioning Training') 125 | 126 | best_cider = 0 127 | best_epoch = -1 128 | best_score = None 129 | 130 | max_epoch = args.n_epoch 131 | for i in range(max_epoch): 132 | model.schedule_sample_prob = args.schedule_sample_prob + i * args.schedule_sample_ratio 133 | train_loss = train_epoch(i, model, train_loader, optimizer, loss_fn) 134 | writer.add_scalar('train/loss', train_loss, i) 135 | writer.add_scalar('train/schedule_sample_prob', model.schedule_sample_prob, i) 136 | writer.add_scalar('train/lr', optimizer.state_dict()['param_groups'][0]['lr'], i) 137 | print('Train loss:', train_loss) 138 | if (i + 1) % args.lr_decay_every == 0: 139 | schedule.step() 140 | 141 | # eval on val 142 | val_score = evaluate(i, model, val_loader, dm, maxlen=args.max_sent_len, split='Val') 143 | writer.add_scalar('val/Bleu1', val_score['Bleu_1'], i) 144 | writer.add_scalar('val/Bleu2', val_score['Bleu_2'], i) 145 | writer.add_scalar('val/Bleu3', val_score['Bleu_3'], i) 146 | writer.add_scalar('val/Bleu4', val_score['Bleu_4'], i) 147 | writer.add_scalar('val/RougeL', val_score['ROUGE_L'], i) 148 | writer.add_scalar('val/METEOR', val_score['METEOR'], i) 149 | writer.add_scalar('val/CIDEr', val_score['CIDEr'], i) 150 | print('Val score:') 151 | print(val_score) 152 | 153 | # eval on test 154 | test_score = evaluate(i, model, test_loader, dm, maxlen=args.max_sent_len, split='Test', verbose=True) 155 | writer.add_scalar('test/Bleu1', test_score['Bleu_1'], i) 156 | writer.add_scalar('test/Bleu2', test_score['Bleu_2'], i) 157 | writer.add_scalar('test/Bleu3', test_score['Bleu_3'], i) 158 | writer.add_scalar('test/Bleu4', test_score['Bleu_4'], i) 159 | writer.add_scalar('test/RougeL', test_score['ROUGE_L'], i) 160 | writer.add_scalar('test/METEOR', test_score['METEOR'], i) 161 | writer.add_scalar('test/CIDEr', test_score['CIDEr'], i) 162 | print('Test score:') 163 | print(test_score) 164 | 165 | if val_score['CIDEr'] > best_cider: 166 | best_cider = val_score['CIDEr'] 167 | best_score = test_score 168 | best_epoch = i 169 | torch.save(model.state_dict(), os.path.join(args.savedir, args.exp_name, 'best.pth')) 170 | 171 | torch.save(model.state_dict(), os.path.join(args.savedir, args.exp_name, 'last.pth')) 172 | 173 | writer.add_scalar('best/Bleu1', best_score['Bleu_1'], 1) 174 | writer.add_scalar('best/Bleu2', best_score['Bleu_2'], 1) 175 | writer.add_scalar('best/Bleu3', best_score['Bleu_3'], 1) 176 | writer.add_scalar('best/Bleu4', best_score['Bleu_4'], 1) 177 | writer.add_scalar('best/RougeL', best_score['ROUGE_L'], 1) 178 | writer.add_scalar('best/METEOR', best_score['METEOR'], 1) 179 | writer.add_scalar('best/CIDEr', best_score['CIDEr'], 1) 180 | writer.close() 181 | print('Best epoch', best_epoch) 182 | print('Best cider', best_cider) 183 | print('Best score', best_score) 184 | print('Done') 185 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class NLLLossWithLength(nn.Module): 6 | def __init__(self, ignore_index, beta=0.7): 7 | super(NLLLossWithLength, self).__init__() 8 | self.ignore_index = ignore_index 9 | self.beta = beta 10 | self.nllloss = nn.NLLLoss(reduction='none', ignore_index=self.ignore_index) 11 | 12 | def forward(self, out, gt): 13 | bs = out.shape[0] 14 | max_len = out.shape[1] 15 | sentence_len = (gt != 
self.ignore_index).sum(dim=-1).view(bs,1) 16 | sentence_len = sentence_len.repeat(1,gt.shape[-1]) 17 | # print('init sentence len') 18 | # print(sentence_len) 19 | out = out.view(-1, out.shape[-1]) 20 | gt = gt.view(-1) 21 | sentence_len = sentence_len.view(-1) 22 | weight = 1 / sentence_len.float() ** self.beta 23 | # print('sentence_len') 24 | # print(sentence_len) 25 | loss = self.nllloss(out, gt) 26 | # print('unweighted loss') 27 | # print(loss) 28 | # print(loss.shape) 29 | # print('origin loss') 30 | # print(loss) 31 | # print('sentence_len') 32 | # print(sentence_len) 33 | # print('weight') 34 | # print(weight) 35 | # print(weight.shape) 36 | loss = loss * weight 37 | # print('weighted loss') 38 | # print(loss) 39 | # print('res loss') 40 | # print(loss) 41 | # print('sentence loss') 42 | # print('avg loss') 43 | # print(torch.sum(loss) / bs) 44 | # assert False 45 | return torch.sum(loss) / bs 46 | --------------------------------------------------------------------------------
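
A closing usage sketch for `NLLLossWithLength` defined in utils.py above. The padding index, vocabulary size, and tensor shapes below are toy values for illustration only (in train.py the padding index comes from the corpus vocabulary built by `DataManager`); the point is that each caption's token losses are scaled by `1 / sentence_length ** beta` (beta=0.7), which sits between a per-sentence sum and a per-sentence mean, before averaging over the batch.
```python
import torch
import torch.nn.functional as F
from utils import NLLLossWithLength

PAD = 0                       # toy padding index; the real one comes from the corpus vocabulary
bs, max_len, vocab = 2, 5, 7  # toy sizes

# The loss expects log-probabilities shaped (batch, time, vocab),
# which is what SemanticLSTM.forward returns after its final log_softmax.
log_probs = F.log_softmax(torch.randn(bs, max_len, vocab), dim=-1)

# Ground-truth captions, padded with PAD after each sentence ends.
captions = torch.tensor([[3, 4, 5, PAD, PAD],
                         [6, 2, 1, 4, 3]])

loss_fn = NLLLossWithLength(ignore_index=PAD, beta=0.7)
loss = loss_fn(log_probs, captions)  # padded positions contribute zero; the rest are
print(loss)                          # weighted by 1 / sentence_length ** 0.7
```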