├── .gitignore ├── LICENSE ├── README.md ├── caption ├── __init__.py ├── decoders │ ├── __init__.py │ ├── attention.py │ └── vanilla.py ├── encoders │ ├── __init__.py │ └── vanilla.py ├── models │ ├── __init__.py │ ├── attention.py │ ├── captionbase.py │ └── vanilla.py ├── readers │ ├── __init__.py │ └── base.py └── utils │ ├── __init__.py │ └── inference.py ├── controlimcap ├── __init__.py ├── decoders │ ├── __init__.py │ ├── cfattention.py │ └── memory.py ├── driver │ ├── asg2caption.py │ ├── common.py │ └── configs │ │ ├── prepare_coco_imgsg_config.py │ │ └── prepare_vg_imgsg_config.py ├── encoders │ ├── flat.py │ └── gcn.py ├── models │ ├── flatattn.py │ ├── graphattn.py │ ├── graphflow.py │ └── graphmemory.py └── readers │ ├── __init__.py │ └── imgsgreader.py ├── figures ├── method_framework.png └── user_intention_examples.png └── framework ├── __init__.py ├── configbase.py ├── logbase.py ├── modelbase.py ├── modules ├── __init__.py ├── embeddings.py └── global_attention.py ├── ops.py └── run_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__/ 2 | eval_cap/ 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 cshizhe 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Say As You Wish: Fine-grained Control of Image Caption Generation with Abstract Scene Graphs 2 | 3 | This repository contains PyTorch implementation of our paper [Say As You Wish: Fine-grained Control of Image Caption Generation with Abstract Scene Graphs (CVPR 2020)](https://arxiv.org/abs/2003.00387). 4 | 5 | ![Overview of ASG2Caption Model](figures/method_framework.png) 6 | 7 | ## Prerequisites 8 | Python 3 and PyTorch 1.3. 
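For reference, one possible environment setup is sketched below; the use of conda and the exact versions of the supporting packages are assumptions on our part (the repository itself only states Python 3 and PyTorch 1.3):

```
# optional: create an isolated environment
conda create -n asg2cap python=3.7
conda activate asg2cap

# PyTorch 1.3 plus numpy/h5py, which are commonly needed for the provided .npy/.hdf5 data files
pip install torch==1.3.1 numpy h5py
```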
9 | 10 | ``` 11 | # clone the repository 12 | git clone https://github.com/cshizhe/asg2cap.git 13 | cd asg2cap 14 | # clone the caption evaluation code 15 | git clone https://github.com/cshizhe/eval_cap.git 16 | export PYTHONPATH=$(pwd):${PYTHONPATH} 17 | ``` 18 | 19 | ## Training & Inference 20 | ``` 21 | cd controlimcap/driver 22 | 23 | # supported caption models: [node, node.role, 24 | # rgcn, rgcn.flow, rgcn.memory, rgcn.flow.memory] 25 | # see our paper for details 26 | mtype=rgcn.flow.memory 27 | 28 | # set up the config files 29 | # you should modify the data paths in configs/prepare_*_imgsg_config.py 30 | python configs/prepare_coco_imgsg_config.py $mtype 31 | resdir='' # set to the output directory printed by the previous step 32 | 33 | # training 34 | python asg2caption.py $resdir/model.json $resdir/path.json $mtype --eval_loss --is_train --num_workers 8 35 | 36 | # inference 37 | python asg2caption.py $resdir/model.json $resdir/path.json $mtype --eval_set tst --num_workers 8 38 | ``` 39 | 40 | ## Datasets 41 | 42 | ### Annotations 43 | Annotations for the MSCOCO and VisualGenome datasets can be downloaded from [GoogleDrive](https://drive.google.com/open?id=1hzVhsxGQfA1ZILJ0RVkhcG57LepkjQEm). 44 | 45 | - (Image, ASG, Caption) annotations: regionfiles/image_id.json 46 | 47 | ``` 48 | JSON Format: 49 | { 50 | "region_id": { 51 | "objects":[ 52 | { 53 | "object_id": int, 54 | "name": str, 55 | "attributes": [str], 56 | "x": int, 57 | "y": int, 58 | "w": int, 59 | "h": int 60 | }], 61 | "relationships": [ 62 | { 63 | "relationship_id": int, 64 | "subject_id": int, 65 | "object_id": int, 66 | "name": str 67 | }], 68 | "phrase": str, 69 | } 70 | } 71 | ``` 72 | 73 | - vocabularies 74 | int2word.npy: [word] 75 | word2int.json: {word: int} 76 | 77 | - data splits: public_split directory 78 | trn_names.npy, val_names.npy, tst_names.npy 79 | 80 | ### Features 81 | Features for the MSCOCO and VisualGenome datasets are available at [BaiduNetdisk](https://pan.baidu.com/s/1A1YS_ztPdIDz0ALgUo0qZg) (code: 6q32). 82 | 83 | We also provide pretrained models and code to extract features for new images. 
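As a quick reference, the sketch below shows one way to read the annotation, vocabulary, and feature files described in this section (the two feature formats are detailed next). The file paths and the image id `123456` are placeholders, and reading the region-feature files with h5py is our assumption:

```
import json
import numpy as np
import h5py

# (Image, ASG, Caption) annotations: one JSON file per image (regionfiles/image_id.json)
regions = json.load(open('regionfiles/123456.json'))
for region_id, region in regions.items():
    objects = region['objects']               # ASG nodes: name, attributes, box (x, y, w, h)
    relationships = region['relationships']   # ASG edges between object_ids
    phrase = region['phrase']                 # ground-truth caption of the region

# vocabularies
int2word = np.load('int2word.npy')            # index -> word
word2int = json.load(open('word2int.json'))   # word -> index

# global image features: one row per image, in the order of the data split name files
mp_fts = np.load('trn_mp_fts.npy')            # shape (num_fts, dim_ft)

# region image features: one hdf5 file per image, keyed by 'image_id.jpg'
with h5py.File('123456.jpg.hdf5', 'r') as f:
    d = f['123456.jpg']
    region_fts = d[...]                       # fc7 features of the detected boxes
    image_w, image_h = d.attrs['image_w'], d.attrs['image_h']
    boxes = d.attrs['boxes']                  # (x1, y1, x2, y2) for each box
```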
84 | 85 | - Global Image Feature: the last mean pooling feature of [ResNet101 pretrained on ImageNet](https://pytorch.org/vision/stable/models.html#table-of-all-available-classification-weights) 86 | 87 | format: npy array, shape=(num_fts, dim_ft) 88 | corresponding to the order in data_split names 89 | 90 | - Region Image Feature: fc7 layer of [Faster-RCNN pretrained on VisualGenome](https://github.com/cshizhe/maskrcnn_benchmark) 91 | 92 | format: hdf5 files, "image_id".jpg.hdf5 93 | 94 | key: 'image_id'.jpg 95 | 96 | attrs: {"image_w": int, "image_h": int, "boxes": 4d array (x1, y1, x2, y2)} 97 | 98 | 99 | ## Result Visualization 100 | 101 | ![Examples](figures/user_intention_examples.png) 102 | 103 | 104 | ## Citations 105 | If you use this code as part of any published research, we'd really appreciate it if you could cite the following paper: 106 | ```text 107 | @article{chen2020say, 108 | title={Say As You Wish: Fine-grained Control of Image Caption Generation with Abstract Scene Graphs}, 109 | author={Chen, Shizhe and Jin, Qin and Wang, Peng and Wu, Qi}, 110 | journal={CVPR}, 111 | year={2020} 112 | } 113 | ``` 114 | 115 | 116 | ## License 117 | 118 | MIT License 119 | 120 | 121 | 122 | 123 | -------------------------------------------------------------------------------- /caption/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshizhe/asg2cap/9d3a8a6312935cb1b033835c1edc522dcf9f6061/caption/__init__.py -------------------------------------------------------------------------------- /caption/decoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshizhe/asg2cap/9d3a8a6312935cb1b033835c1edc522dcf9f6061/caption/decoders/__init__.py -------------------------------------------------------------------------------- /caption/decoders/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import caption.utils.inference 6 | 7 | import caption.decoders.vanilla 8 | from framework.modules.embeddings import Embedding 9 | from framework.modules.global_attention import GlobalAttention 10 | from framework.modules.global_attention import AdaptiveAttention 11 | 12 | class AttnDecoderConfig(caption.decoders.vanilla.DecoderConfig): 13 | def __init__(self): 14 | super().__init__() 15 | self.memory_same_key_value = True 16 | self.attn_input_size = 512 17 | self.attn_size = 512 18 | self.attn_type = 'mlp' # mlp, dot, general 19 | 20 | def _assert(self): 21 | assert self.attn_type in ['dot', 'general', 'mlp'], ('Please select a valid attention type.') 22 | 23 | class AttnDecoder(caption.decoders.vanilla.Decoder): 24 | def __init__(self, config): 25 | super().__init__(config) 26 | 27 | self.attn = GlobalAttention(self.config.hidden_size, self.config.attn_size, self.config.attn_type) 28 | if self.config.attn_type == 'mlp': 29 | self.attn_linear_context = nn.Linear(self.config.attn_input_size, 30 | self.config.attn_size, bias=False) 31 | 32 | if not self.config.memory_same_key_value: 33 | self.memory_value_layer = nn.Linear(self.config.attn_input_size, 34 | self.config.attn_size, bias=True) 35 | 36 | @property 37 | def rnn_input_size(self): 38 | if self.config.memory_same_key_value: 39 | return self.config.dim_word + self.config.attn_input_size 40 | else: 41 | return self.config.dim_word + self.config.attn_size 42 | 43 | def 
gen_memory_key_value(self, enc_memories): 44 | if self.config.memory_same_key_value: 45 | memory_values = enc_memories 46 | else: 47 | memory_values = F.relu(self.memory_value_layer(enc_memories)) 48 | 49 | if self.config.attn_type == 'mlp': 50 | memory_keys = self.attn_linear_context(enc_memories) 51 | else: 52 | memory_keys = enc_memories 53 | 54 | return memory_keys, memory_values 55 | 56 | def forward(self, inputs, enc_states, enc_memories, enc_masks, return_attn=False): 57 | ''' 58 | Args: 59 | inputs: (batch, dec_seq_len) 60 | enc_states: (batch, dim_embed) 61 | enc_memoris: (batch, enc_seq_len, dim_embed) 62 | enc_masks: (batch, enc_seq_len) 63 | Returns: 64 | logits: (batch*seq_len, num_words) 65 | ''' 66 | memory_keys, memory_values = self.gen_memory_key_value(enc_memories) 67 | states = self.init_dec_state(enc_states) 68 | outs = states[0][-1] if isinstance(states, tuple) else states[-1] 69 | 70 | step_outs, step_attns = [], [] 71 | for t in range(inputs.size(1)): 72 | wordids = inputs[:, t] 73 | if t > 0 and self.config.schedule_sampling: 74 | sample_rate = torch.rand(wordids.size(0)).to(wordids.device) 75 | sample_mask = sample_rate < self.config.ss_rate 76 | prob = self.softmax(step_outs[-1]).detach() 77 | sampled_wordids = torch.multinomial(prob, 1).view(-1) 78 | wordids.masked_scatter_(sample_mask, sampled_wordids) 79 | embed = self.embedding(wordids) 80 | attn_score, attn_memory = self.attn(outs, 81 | memory_keys, memory_values, enc_masks) 82 | step_attns.append(attn_score) 83 | rnn_input = torch.cat([embed, attn_memory], 1).unsqueeze(1) 84 | rnn_input = self.dropout(rnn_input) 85 | outs, states = self.rnn(rnn_input, states) 86 | outs = outs[:, 0] 87 | logit = self.calc_logits_with_rnn_outs(outs) 88 | step_outs.append(logit) 89 | 90 | logits = torch.stack(step_outs, 1) 91 | logits = logits.view(-1, self.config.num_words) 92 | 93 | if return_attn: 94 | return logits, step_attns 95 | return logits 96 | 97 | def step_fn(self, words, step, **kwargs): 98 | ''' 99 | Args: 100 | words: (batch, 1) 101 | kwargs: 102 | - states: decoder init states (num_layers, batch, hidden_size) 103 | - outs: last decoder layer hidden as attn query (batch, hidden_size) 104 | - memory_keys: (batch, enc_seq_len, key_size) 105 | - memory_values: (batch, enc_seq_len, value_size) 106 | - memory_masks: (batch, enc_seq_len) 107 | ''' 108 | states = kwargs['states'] 109 | outs = kwargs['outs'] 110 | memory_keys = kwargs['memory_keys'] 111 | memory_values = kwargs['memory_values'] 112 | memory_masks = kwargs['memory_masks'] 113 | 114 | embeds = self.embedding(words) 115 | 116 | attn_score, attn_memory = self.attn( 117 | outs, memory_keys, memory_values, memory_masks) 118 | 119 | attn_memory = attn_memory.unsqueeze(1) 120 | rnn_inputs = torch.cat([embeds, attn_memory], 2) 121 | outs, states = self.rnn(rnn_inputs, states) 122 | outs = outs[:, 0] 123 | logits = self.calc_logits_with_rnn_outs(outs) 124 | logprobs = self.log_softmax(logits) 125 | 126 | kwargs['states'] = states 127 | kwargs['outs'] = outs 128 | return logprobs, kwargs 129 | 130 | def expand_fn(self, beam_width, **kwargs): 131 | ''' 132 | Args: 133 | kwargs: 134 | - states 135 | - outs: (batch, hidden_size) 136 | - memory_keys, memory_values, memory_masks: (batch, ...) 
137 | ''' 138 | kwargs = super().expand_fn(beam_width, **kwargs) 139 | for key, value in kwargs.items(): 140 | if key != 'states': 141 | value_size = list(value.size()) 142 | expand_size = [value_size[0], beam_width] + value_size[1:] 143 | final_size = [value_size[0] * beam_width] + value_size[1:] 144 | kwargs[key] = value.unsqueeze(1).expand(*expand_size).contiguous() \ 145 | .view(*final_size) 146 | return kwargs 147 | 148 | def select_fn(self, idxs, **kwargs): 149 | '''Select examples according to idxs 150 | kwargs: 151 | - states: lstm tuple (num_layer, batch_size*beam_width, hidden_size) 152 | - outs, memory_keys, memory_values, memory_masks: (batch, ...) 153 | ''' 154 | kwargs = super().select_fn(idxs, **kwargs) 155 | for key, value in kwargs.items(): 156 | if key != 'states': 157 | kwargs[key] = torch.index_select(value, 0, idxs) 158 | return kwargs 159 | 160 | def sample_decode(self, words, enc_states, enc_memories, enc_masks, greedy=True, early_stop=True): 161 | ''' 162 | Args 163 | words: (batch, ) 164 | enc_states: (batch, hidden_size) 165 | enc_memories: (batch, enc_seq_len, attn_input_size) 166 | enc_masks: (batch, enc_seq_len) 167 | ''' 168 | states = self.init_dec_state(enc_states) 169 | outs = states[0][-1] if isinstance(states, tuple) else states[-1] 170 | memory_keys, memory_values = self.gen_memory_key_value(enc_memories) 171 | 172 | seq_words, seq_word_logprobs = caption.utils.inference.sample_decode( 173 | words, self.step_fn, self.config.max_words_in_sent, 174 | greedy=greedy, early_stop=early_stop, states=states, outs=outs, 175 | memory_keys=memory_keys, memory_values=memory_values, memory_masks=enc_masks) 176 | 177 | return seq_words, seq_word_logprobs 178 | 179 | def beam_search_decode(self, words, enc_states, enc_memories, enc_masks): 180 | states = self.init_dec_state(enc_states) 181 | outs = states[0][-1] if isinstance(states, tuple) else states[-1] 182 | memory_keys, memory_values = self.gen_memory_key_value(enc_memories) 183 | 184 | sent_pool = caption.utils.inference.beam_search_decode(words, self.step_fn, 185 | self.config.max_words_in_sent, beam_width=self.config.beam_width, 186 | sent_pool_size=self.config.sent_pool_size, 187 | expand_fn=self.expand_fn, select_fn=self.select_fn, 188 | states=states, outs=outs, memory_keys=memory_keys, 189 | memory_values=memory_values, memory_masks=enc_masks) 190 | return sent_pool 191 | 192 | 193 | class BUTDAttnDecoder(AttnDecoder): 194 | ''' 195 | Requires: dim input visual feature == lstm hidden size 196 | ''' 197 | def __init__(self, config): 198 | nn.Module.__init__(self) # need to rewrite RNN 199 | self.config = config 200 | # word embedding 201 | self.embedding = Embedding(self.config.num_words, 202 | self.config.dim_word, fix_word_embed=self.config.fix_word_embed) 203 | # rnn params (attn_lstm and lang_lstm) 204 | self.attn_lstm = nn.LSTMCell( 205 | self.config.hidden_size + self.config.attn_input_size + self.config.dim_word, # (h_lang, v_g, w) 206 | self.config.hidden_size, bias=True) 207 | memory_size = self.config.attn_input_size if self.config.memory_same_key_value else self.config.attn_size 208 | self.lang_lstm = nn.LSTMCell( 209 | self.config.hidden_size + memory_size, # (h_attn, v_a) 210 | self.config.hidden_size, bias=True) 211 | # attentions 212 | self.attn = GlobalAttention(self.config.hidden_size, self.config.attn_size, self.config.attn_type) 213 | if self.config.attn_type == 'mlp': 214 | self.attn_linear_context = nn.Linear(self.config.attn_input_size, self.config.attn_size, bias=False) 215 | if not 
self.config.memory_same_key_value: 216 | self.memory_value_layer = nn.Linear(self.config.attn_input_size, self.config.attn_size, bias=True) 217 | # outputs 218 | if self.config.hidden2word: 219 | self.hidden2word = nn.Linear(self.config.hidden_size, self.config.dim_word) 220 | output_size = self.config.dim_word 221 | else: 222 | output_size = self.config.hidden_size 223 | if not self.config.tie_embed: 224 | self.fc = nn.Linear(output_size, self.config.num_words) 225 | self.log_softmax = nn.LogSoftmax(dim=1) 226 | self.softmax = nn.Softmax(dim=1) 227 | 228 | self.dropout = nn.Dropout(self.config.dropout) 229 | self.init_rnn_weights(self.attn_lstm, 'lstm', num_layers=1) 230 | self.init_rnn_weights(self.lang_lstm, 'lstm', num_layers=1) 231 | 232 | def init_dec_state(self, batch_size): 233 | param = next(self.parameters()) 234 | states = [] 235 | for i in range(2): # (hidden, cell) 236 | states.append(torch.zeros((2, batch_size, self.config.hidden_size), 237 | dtype=torch.float32).to(param.device)) 238 | return states 239 | 240 | def forward(self, inputs, enc_globals, enc_memories, enc_masks, return_attn=False): 241 | ''' 242 | Args: 243 | inputs: (batch, dec_seq_len) 244 | enc_globals: (batch, hidden_size) 245 | enc_memories: (batch, enc_seq_len, attn_input_size) 246 | enc_masks: (batch, enc_seq_len) 247 | Returns: 248 | logits: (batch*seq_len, num_words) 249 | ''' 250 | memory_keys, memory_values = self.gen_memory_key_value(enc_memories) 251 | states = self.init_dec_state(inputs.size(0)) # zero init state 252 | 253 | step_outs, step_attns = [], [] 254 | for t in range(inputs.size(1)): 255 | wordids = inputs[:, t] 256 | if t > 0 and self.config.schedule_sampling: 257 | sample_rate = torch.rand(wordids.size(0)).to(wordids.device) 258 | sample_mask = sample_rate < self.config.ss_rate 259 | prob = self.softmax(step_outs[-1]).detach() # detach grad 260 | sampled_wordids = torch.multinomial(prob, 1).view(-1) 261 | wordids.masked_scatter_(sample_mask, sampled_wordids) 262 | embed = self.embedding(wordids) 263 | 264 | h_attn_lstm, c_attn_lstm = self.attn_lstm( 265 | torch.cat([states[0][1], enc_globals, embed], dim=1), 266 | (states[0][0], states[1][0])) 267 | 268 | attn_score, attn_memory = self.attn(h_attn_lstm, 269 | memory_keys, memory_values, enc_masks) 270 | step_attns.append(attn_score) 271 | 272 | h_lang_lstm, c_lang_lstm = self.lang_lstm( 273 | torch.cat([h_attn_lstm, attn_memory], dim=1), 274 | (states[0][1], states[1][1])) 275 | 276 | outs = h_lang_lstm 277 | logit = self.calc_logits_with_rnn_outs(outs) 278 | step_outs.append(logit) 279 | states = (torch.stack([h_attn_lstm, h_lang_lstm], dim=0), 280 | torch.stack([c_attn_lstm, c_lang_lstm], dim=0)) 281 | 282 | logits = torch.stack(step_outs, 1) 283 | logits = logits.view(-1, self.config.num_words) 284 | 285 | if return_attn: 286 | return logits, step_attns 287 | return logits 288 | 289 | def step_fn(self, words, step, **kwargs): 290 | states = kwargs['states'] 291 | enc_globals = kwargs['enc_globals'] 292 | memory_keys = kwargs['memory_keys'] 293 | memory_values = kwargs['memory_values'] 294 | memory_masks = kwargs['memory_masks'] 295 | 296 | embed = self.embedding(words.squeeze(1)) 297 | 298 | h_attn_lstm, c_attn_lstm = self.attn_lstm( 299 | torch.cat([states[0][1], enc_globals, embed], dim=1), 300 | (states[0][0], states[1][0])) 301 | 302 | attn_score, attn_memory = self.attn(h_attn_lstm, 303 | memory_keys, memory_values, memory_masks) 304 | 305 | h_lang_lstm, c_lang_lstm = self.lang_lstm( 306 | torch.cat([h_attn_lstm, attn_memory], 
dim=1), 307 | (states[0][1], states[1][1])) 308 | 309 | logits = self.calc_logits_with_rnn_outs(h_lang_lstm) 310 | logprobs = self.log_softmax(logits) 311 | states = (torch.stack([h_attn_lstm, h_lang_lstm], dim=0), 312 | torch.stack([c_attn_lstm, c_lang_lstm], dim=0)) 313 | 314 | kwargs['states'] = states 315 | return logprobs, kwargs 316 | 317 | def sample_decode(self, words, enc_globals, enc_memories, enc_masks, greedy=True): 318 | states = self.init_dec_state(words.size(0)) 319 | memory_keys, memory_values = self.gen_memory_key_value(enc_memories) 320 | 321 | seq_words, seq_word_logprobs = caption.utils.inference.sample_decode( 322 | words, self.step_fn, self.config.max_words_in_sent, 323 | greedy=greedy, states=states, enc_globals=enc_globals, memory_keys=memory_keys, 324 | memory_values=memory_values, memory_masks=enc_masks) 325 | 326 | return seq_words, seq_word_logprobs 327 | 328 | def beam_search_decode(self, words, enc_globals, enc_memories, enc_masks): 329 | states = self.init_dec_state(words.size(0)) 330 | memory_keys, memory_values = self.gen_memory_key_value(enc_memories) 331 | 332 | sent_pool = caption.utils.inference.beam_search_decode(words, self.step_fn, 333 | self.config.max_words_in_sent, beam_width=self.config.beam_width, 334 | sent_pool_size=self.config.sent_pool_size, expand_fn=self.expand_fn, 335 | select_fn=self.select_fn, states=states, enc_globals=enc_globals, 336 | memory_keys=memory_keys, memory_values=memory_values, memory_masks=enc_masks) 337 | 338 | return sent_pool 339 | -------------------------------------------------------------------------------- /caption/decoders/vanilla.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import framework.configbase 6 | from framework.modules.embeddings import Embedding 7 | import framework.ops 8 | import caption.utils.inference 9 | 10 | class DecoderConfig(framework.configbase.ModuleConfig): 11 | def __init__(self): 12 | super().__init__() 13 | self.rnn_type = 'lstm' 14 | self.num_words = 0 15 | self.dim_word = 512 16 | self.hidden_size = 512 17 | self.num_layers = 1 18 | self.hidden2word = False 19 | self.tie_embed = False 20 | self.fix_word_embed = False 21 | self.max_words_in_sent = 20 22 | self.dropout = 0.5 23 | self.schedule_sampling = False 24 | self.ss_rate = 0.05 25 | self.ss_max_rate = 0.25 26 | self.ss_increase_rate = 0.05 27 | self.ss_increase_epoch = 5 28 | 29 | self.greedy_or_beam = False # test method 30 | self.beam_width = 1 31 | self.sent_pool_size = 1 32 | 33 | def _assert(self): 34 | if self.tie_embed and not self.hidden2word: 35 | assert self.dim_word == self.hidden_size 36 | 37 | class Decoder(nn.Module): 38 | def __init__(self, config): 39 | super().__init__() 40 | self.config = config 41 | 42 | self.embedding = Embedding(self.config.num_words, 43 | self.config.dim_word, fix_word_embed=self.config.fix_word_embed) 44 | 45 | kwargs = {} 46 | self.rnn = framework.ops.rnn_factory(self.config.rnn_type, 47 | input_size=self.rnn_input_size, hidden_size=self.config.hidden_size, 48 | num_layers=self.config.num_layers, dropout=self.config.dropout, 49 | bias=True, batch_first=True, **kwargs) 50 | 51 | if self.config.hidden2word: 52 | self.hidden2word = nn.Linear(self.config.hidden_size, self.config.dim_word) 53 | output_size = self.config.dim_word 54 | else: 55 | output_size = self.config.hidden_size 56 | 57 | if not self.config.tie_embed: 58 | self.fc = nn.Linear(output_size, 
self.config.num_words) 59 | 60 | self.log_softmax = nn.LogSoftmax(dim=1) 61 | self.softmax = nn.Softmax(dim=1) 62 | 63 | self.dropout = nn.Dropout(self.config.dropout) 64 | 65 | self.init_rnn_weights(self.rnn, self.config.rnn_type) 66 | 67 | @property 68 | def rnn_input_size(self): 69 | return self.config.dim_word 70 | 71 | def init_rnn_weights(self, rnn, rnn_type, num_layers=None): 72 | if rnn_type == 'lstm': 73 | # the ordering of weights a biases is ingate, forgetgate, cellgate, outgate 74 | # init forgetgate as 1 to make rnn remember the past in the beginning 75 | if num_layers is None: 76 | num_layers = rnn.num_layers 77 | for layer in range(num_layers): 78 | for name in ['i', 'h']: 79 | try: 80 | weight = getattr(rnn, 'weight_%sh_l%d'%(name, layer)) 81 | except: 82 | weight = getattr(rnn, 'weight_%sh'%name) 83 | nn.init.orthogonal_(weight.data) 84 | try: 85 | bias = getattr(rnn, 'bias_%sh_l%d'%(name, layer)) 86 | except: 87 | bias = getattr(rnn, 'bias_%sh'%name) # BUTD: LSTM Cell 88 | nn.init.constant_(bias, 0) 89 | if name == 'i': 90 | bias.data.index_fill_(0, torch.arange( 91 | rnn.hidden_size, rnn.hidden_size*2).long(), 1) 92 | # bias.requires_grad = False 93 | 94 | def init_dec_state(self, encoder_state): 95 | ''' 96 | The encoder hidden is (batch, dim_embed) 97 | We need to convert it to (layers, batch, hidden_size) 98 | assert dim_embed == hidden_size 99 | ''' 100 | decoder_state = encoder_state.repeat(self.config.num_layers, 1, 1) 101 | if self.config.rnn_type == 'lstm' or self.config.rnn_type == 'ONLSTM': 102 | decoder_state = tuple([decoder_state, decoder_state]) 103 | return decoder_state 104 | 105 | def calc_logits_with_rnn_outs(self, outs): 106 | ''' 107 | Args: 108 | outs: (batch, hidden_size) 109 | Returns: 110 | logits: (batch, num_words) 111 | ''' 112 | if self.config.hidden2word: 113 | outs = torch.tanh(self.hidden2word(outs)) 114 | outs = self.dropout(outs) 115 | if self.config.tie_embed: 116 | logits = torch.mm(outs, self.embedding.we.weight.t()) 117 | else: 118 | logits = self.fc(outs) 119 | return logits 120 | 121 | def forward(self, inputs, encoder_state): 122 | ''' 123 | Args: 124 | inputs: (batch, seq_len) 125 | encoder_state: (batch, dim_embed) 126 | Returns: 127 | logits: (batch*seq_len, num_words) 128 | ''' 129 | states = self.init_dec_state(encoder_state) 130 | 131 | if self.config.schedule_sampling: 132 | step_outs = [] 133 | for t in range(inputs.size(1)): 134 | wordids = inputs[:, t] 135 | if t > 0: 136 | sample_rate = torch.rand(wordids.size(0)).to(wordids.device) 137 | sample_mask = sample_rate < self.config.ss_rate 138 | prob = self.softmax(step_outs[-1]).detach() 139 | sampled_wordids = torch.multinomial(prob, 1).squeeze(1) 140 | wordids.masked_scatter_(sample_mask, sampled_wordids) 141 | embed = self.embedding(wordids) 142 | embed = self.dropout(embed) 143 | outs, states = self.rnn(embed.unsqueeze(1), states) 144 | outs = outs[:, 0] 145 | logit = self.calc_logits_with_rnn_outs(outs) 146 | step_outs.append(logit) 147 | logits = torch.stack(step_outs, 1) 148 | logits = logits.view(-1, self.config.num_words) 149 | # pytorch rnn utilzes cudnn to speed up 150 | else: 151 | embeds = self.embedding(inputs) 152 | embeds = self.dropout(embeds) 153 | # outs.size(batch, seq_len, hidden_size) 154 | outs, states = self.rnn(embeds, states) 155 | outs = outs.contiguous().view(-1, self.config.hidden_size) 156 | logits = self.calc_logits_with_rnn_outs(outs) 157 | return logits 158 | 159 | def step_fn(self, words, step, **kwargs): 160 | ''' 161 | Args: 162 | words: 
(batch_size, 1) 163 | step: int (start from 0) 164 | kwargs: 165 | states: decoder rnn states (num_layers, batch, hidden_size) 166 | Returns: 167 | logprobs: (batch, num_words) 168 | kwargs: dict, {'states'} 169 | ''' 170 | embeds = self.embedding(words) 171 | outs, states = self.rnn(embeds, kwargs['states']) 172 | outs = outs[:, 0] 173 | logits = self.calc_logits_with_rnn_outs(outs) 174 | logprobs = self.log_softmax(logits) 175 | kwargs['states'] = states 176 | return logprobs, kwargs 177 | 178 | def expand_fn(self, beam_width, **kwargs): 179 | '''After the first step of beam search, expand the examples to beam_width times 180 | e.g. (1, 2, 3) -> (1, 1, 2, 2, 3, 3) 181 | beam_width: int 182 | kwargs: 183 | - states: lstm tuple (num_layer, batch_size, hidden_size) 184 | ''' 185 | states = kwargs['states'] 186 | is_tuple = isinstance(states, tuple) 187 | if not is_tuple: 188 | states = (states, ) 189 | 190 | expanded_states = [] 191 | for h in states: 192 | num_layer, batch_size, hidden_size = h.size() 193 | eh = h.unsqueeze(2).expand(-1, -1, beam_width, -1).contiguous() \ 194 | .view(num_layer, batch_size * beam_width, hidden_size) 195 | expanded_states.append(eh) 196 | 197 | if is_tuple: 198 | states = tuple(expanded_states) 199 | else: 200 | states = expanded_states[0] 201 | 202 | kwargs['states'] = states 203 | return kwargs 204 | 205 | def select_fn(self, idxs, **kwargs): 206 | '''Select examples according to idxs 207 | kwargs: 208 | states: lstm tuple (num_layer, batch_size*beam_width, hidden_size) 209 | ''' 210 | states = kwargs['states'] 211 | if isinstance(states, tuple): 212 | states = tuple([torch.index_select(h, 1, idxs) for h in states]) 213 | else: 214 | states = torch.index_select(h, 1, idxs) 215 | kwargs['states'] = states 216 | return kwargs 217 | 218 | def sample_decode(self, words, enc_states, greedy=True, early_stop=True): 219 | ''' 220 | Args 221 | words: (batch, ) 222 | enc_states: (batch, hidden_size) 223 | ''' 224 | states = self.init_dec_state(enc_states) 225 | seq_words, seq_word_logprobs = caption.utils.inference.sample_decode( 226 | words, self.step_fn, self.config.max_words_in_sent, 227 | greedy=greedy, early_stop=early_stop, states=states) 228 | 229 | return seq_words, seq_word_logprobs 230 | 231 | def beam_search_decode(self, words, enc_states): 232 | ''' 233 | Args: 234 | words: (batch, ) 235 | enc_states: (batch, hidden_size) 236 | Returns: 237 | sent_pool: list, len=batch 238 | item=list, len=beam_width, 239 | element=(sent_logprob, words, word_logprobs) 240 | ''' 241 | states = self.init_dec_state(enc_states) 242 | sent_pool = caption.utils.inference.beam_search_decode(words, self.step_fn, 243 | self.config.max_words_in_sent, beam_width=self.config.beam_width, 244 | sent_pool_size=self.config.sent_pool_size, 245 | expand_fn=self.expand_fn, select_fn=self.select_fn, 246 | states=states) 247 | return sent_pool 248 | 249 | -------------------------------------------------------------------------------- /caption/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshizhe/asg2cap/9d3a8a6312935cb1b033835c1edc522dcf9f6061/caption/encoders/__init__.py -------------------------------------------------------------------------------- /caption/encoders/vanilla.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import framework.configbase 6 | import framework.ops 7 | 
8 | 9 | ''' 10 | Vanilla Encoder: embed nd array (batch_size, ..., dim_ft) 11 | - EncoderConfig 12 | - Encoder 13 | 14 | Multilayer Perceptrons: feed forward networks + softmax 15 | - MLPConfig 16 | - MLP 17 | ''' 18 | 19 | class EncoderConfig(framework.configbase.ModuleConfig): 20 | def __init__(self): 21 | super().__init__() 22 | self.dim_fts = [2048] 23 | self.dim_embed = 512 24 | self.is_embed = True 25 | self.dropout = 0 26 | self.norm = False 27 | self.nonlinear = False 28 | 29 | def _assert(self): 30 | if not self.is_embed: 31 | assert self.dim_embed == sum(self.dim_fts) 32 | 33 | class Encoder(nn.Module): 34 | def __init__(self, config): 35 | super().__init__() 36 | self.config = config 37 | 38 | if self.config.is_embed: 39 | self.ft_embed = nn.Linear(sum(self.config.dim_fts), self.config.dim_embed) 40 | self.dropout = nn.Dropout(self.config.dropout) 41 | 42 | def forward(self, fts): 43 | ''' 44 | Args: 45 | fts: size=(batch, ..., sum(dim_fts)) 46 | Returns: 47 | embeds: size=(batch, dim_embed) 48 | ''' 49 | embeds = fts 50 | if self.config.is_embed: 51 | embeds = self.ft_embed(embeds) 52 | if self.config.nonlinear: 53 | embeds = F.relu(embeds) 54 | if self.config.norm: 55 | embeds = framework.ops.l2norm(embeds) 56 | embeds = self.dropout(embeds) 57 | return embeds 58 | 59 | -------------------------------------------------------------------------------- /caption/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshizhe/asg2cap/9d3a8a6312935cb1b033835c1edc522dcf9f6061/caption/models/__init__.py -------------------------------------------------------------------------------- /caption/models/attention.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import collections 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | import framework.configbase 10 | import caption.encoders.vanilla 11 | import caption.decoders.attention 12 | import caption.models.captionbase 13 | 14 | MPENCODER = 'mp_encoder' 15 | ATTNENCODER = 'attn_encoder' 16 | DECODER = 'decoder' 17 | 18 | class AttnModelConfig(framework.configbase.ModelConfig): 19 | def __init__(self): 20 | super().__init__() 21 | self.subcfgs[MPENCODER] = caption.encoders.vanilla.EncoderConfig() 22 | self.subcfgs[ATTNENCODER] = caption.encoders.vanilla.EncoderConfig() 23 | self.subcfgs[DECODER] = caption.decoders.attention.AttnDecoderConfig() 24 | 25 | def _assert(self): 26 | assert self.subcfgs[MPENCODER].dim_embed == self.subcfgs[DECODER].hidden_size 27 | assert self.subcfgs[ATTNENCODER].dim_embed == self.subcfgs[DECODER].attn_input_size 28 | 29 | 30 | class AttnModel(caption.models.captionbase.CaptionModelBase): 31 | def build_submods(self): 32 | submods = {} 33 | submods[MPENCODER] = caption.encoders.vanilla.Encoder(self.config.subcfgs[MPENCODER]) 34 | submods[ATTNENCODER] = caption.encoders.vanilla.Encoder(self.config.subcfgs[ATTNENCODER]) 35 | submods[DECODER] = caption.decoders.attention.AttnDecoder(self.config.subcfgs[DECODER]) 36 | return submods 37 | 38 | def prepare_input_batch(self, batch_data, is_train=False): 39 | outs = {} 40 | outs['mp_fts'] = torch.FloatTensor(batch_data['mp_fts']).to(self.device) 41 | outs['attn_fts'] = torch.FloatTensor(batch_data['attn_fts']).to(self.device) 42 | outs['attn_masks'] = torch.FloatTensor(batch_data['attn_masks'].astype(np.float32)).to(self.device) 43 | 44 | if is_train: 45 | outs['caption_ids'] = 
torch.LongTensor(batch_data['caption_ids']).to(self.device) 46 | outs['caption_masks'] = torch.FloatTensor(batch_data['caption_masks'].astype(np.float32)).to(self.device) 47 | return outs 48 | 49 | def forward_encoder(self, input_batch): 50 | encoder_state = self.submods[MPENCODER](input_batch['mp_fts']) 51 | encoder_outputs = self.submods[ATTNENCODER](input_batch['attn_fts']) 52 | return {'init_states': encoder_state, 'attn_fts': encoder_outputs} 53 | 54 | def forward_loss(self, batch_data, step=None): 55 | input_batch = self.prepare_input_batch(batch_data, is_train=True) 56 | 57 | enc_outs = self.forward_encoder(input_batch) 58 | # logits.shape=(batch*seq_len, num_words) 59 | logits = self.submods[DECODER](input_batch['caption_ids'][:, :-1], 60 | enc_outs['init_states'], enc_outs['attn_fts'], input_batch['attn_masks']) 61 | loss = self.criterion(logits, input_batch['caption_ids'], 62 | input_batch['caption_masks']) 63 | 64 | return loss 65 | 66 | def validate_batch(self, batch_data, addition_outs=None): 67 | input_batch = self.prepare_input_batch(batch_data, is_train=False) 68 | enc_outs = self.forward_encoder(input_batch) 69 | init_words = torch.zeros(input_batch['attn_masks'].size(0), dtype=torch.int64).to(self.device) 70 | 71 | pred_sent, _ = self.submods[DECODER].sample_decode(init_words, 72 | enc_outs['init_states'], enc_outs['attn_fts'], input_batch['attn_masks'], greedy=True) 73 | return pred_sent 74 | 75 | def test_batch(self, batch_data, greedy_or_beam): 76 | input_batch = self.prepare_input_batch(batch_data, is_train=False) 77 | enc_outs = self.forward_encoder(input_batch) 78 | init_words = torch.zeros(input_batch['attn_masks'].size(0), dtype=torch.int64).to(self.device) 79 | 80 | if greedy_or_beam: 81 | sent_pool = self.submods[DECODER].beam_search_decode( 82 | init_words, enc_outs['init_states'], enc_outs['attn_fts'], 83 | input_batch['attn_masks']) 84 | pred_sent = [pool[0][1] for pool in sent_pool] 85 | else: 86 | pred_sent, word_logprobs = self.submods[DECODER].sample_decode( 87 | init_words, enc_outs['init_states'], enc_outs['attn_fts'], 88 | input_batch['attn_masks'], greedy=True) 89 | sent_pool = [] 90 | for sent, word_logprob in zip(pred_sent, word_logprobs): 91 | sent_pool.append([(word_logprob.sum().item(), sent, word_logprob)]) 92 | 93 | return pred_sent, sent_pool 94 | 95 | class BUTDAttnModel(AttnModel): 96 | def build_submods(self): 97 | submods = {} 98 | submods[MPENCODER] = caption.encoders.vanilla.Encoder(self.config.subcfgs[MPENCODER]) 99 | submods[ATTNENCODER] = caption.encoders.vanilla.Encoder(self.config.subcfgs[ATTNENCODER]) 100 | submods[DECODER] = caption.decoders.attention.BUTDAttnDecoder(self.config.subcfgs[DECODER]) 101 | return submods 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /caption/models/captionbase.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import collections 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | from eval_cap.bleu.bleu import Bleu 9 | from eval_cap.cider.cider import Cider 10 | from eval_cap.meteor.meteor import Meteor 11 | from eval_cap.rouge.rouge import Rouge 12 | 13 | import caption.utils.inference 14 | import framework.modelbase 15 | 16 | DECODER = 'decoder' 17 | 18 | class CaptionLoss(nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | self.loss = nn.CrossEntropyLoss(reduction='none') 22 | 23 | def forward(self, logits, caption_ids, caption_masks, 
reduce_mean=True): 24 | ''' 25 | logits: shape=(batch*(seq_len-1), num_words) 26 | caption_ids: shape=(batch, seq_len) 27 | caption_masks: shape=(batch, seq_len) 28 | ''' 29 | batch_size, seq_len = caption_ids.size() 30 | losses = self.loss(logits, caption_ids[:, 1:].contiguous().view(-1)) 31 | onehot_caption_masks = caption_masks[:, 1:] > 0 32 | onehot_caption_masks = onehot_caption_masks.float() 33 | caption_masks = caption_masks[:, 1:].reshape(-1).float() 34 | if reduce_mean: 35 | loss = torch.sum(losses * caption_masks) / torch.sum(onehot_caption_masks) 36 | else: 37 | loss = torch.div( 38 | torch.sum((losses * caption_masks).view(batch_size, seq_len-1), 1), 39 | torch.sum(onehot_caption_masks, 1)) 40 | return loss 41 | 42 | class CaptionModelBase(framework.modelbase.ModelBase): 43 | def __init__(self, config, _logger=None, eval_loss=False, int2word_file=None, gpu_id=0): 44 | self.eval_loss = eval_loss 45 | self.scorers = { 46 | 'bleu4': Bleu(4), 47 | 'cider': Cider(), 48 | } 49 | if int2word_file is not None: 50 | self.int2sent = caption.utils.inference.IntToSentence(int2word_file) 51 | super().__init__(config, _logger=_logger, gpu_id=gpu_id) 52 | 53 | def build_loss(self): 54 | criterion = CaptionLoss() 55 | return criterion 56 | 57 | def validate(self, val_reader, step=None): 58 | self.eval_start() 59 | 60 | # current eval_loss only select one caption for each image/video 61 | if self.eval_loss: 62 | avg_loss, n_batches = 0, 0 63 | 64 | pred_sents, ref_sents = {}, {} 65 | # load dataset once 66 | for batch_data in val_reader: 67 | if self.eval_loss: 68 | loss = self.forward_loss(batch_data) 69 | avg_loss += loss.data.item() 70 | n_batches += 1 71 | pred_sent = self.validate_batch(batch_data) 72 | pred_sent = pred_sent.data.cpu().numpy() 73 | for i, name in enumerate(batch_data['names']): 74 | pred_sents[name] = [self.int2sent(pred_sent[i])] 75 | ref_sents[name] = batch_data['ref_sents'][name] 76 | 77 | if self.eval_loss: 78 | avg_loss /= n_batches 79 | 80 | # compute translation score (bleu, rouge) 81 | metrics = collections.OrderedDict() 82 | if self.eval_loss: 83 | metrics['loss'] = avg_loss 84 | for measure, scorer in self.scorers.items(): 85 | score, _ = scorer.compute_score(ref_sents, pred_sents) 86 | if measure == 'bleu4': 87 | score = score[-1] 88 | # bleu4 is the "mean" of 1-4 gram (np.exp(np.mean(np.log(actual_scores)))) 89 | # which is the same as nltk.translate.bleu_score.corpus_bleu() 90 | metrics[measure] = score * 100 91 | return metrics 92 | 93 | def test(self, tst_reader, tst_pred_file, tst_model_file=None, outcap_format=0): 94 | if tst_model_file is not None: 95 | self.load_checkpoint(tst_model_file) 96 | self.eval_start() 97 | 98 | pred_sents = {} 99 | for batch_data in tst_reader: 100 | greedy_or_beam = self.config.subcfgs[DECODER].greedy_or_beam 101 | pred_sent, sent_pool = self.test_batch(batch_data, greedy_or_beam) 102 | for i, name in enumerate(batch_data['names']): 103 | if isinstance(name, tuple): 104 | name = '_'.join([str(x) for x in name]) 105 | pred_sents[name] = self.gen_out_caption_format( 106 | sent_pool[i], self.int2sent, outcap_format) 107 | 108 | output_dir = os.path.dirname(tst_pred_file) 109 | if not os.path.exists(output_dir): 110 | os.makedirs(output_dir) 111 | with open(tst_pred_file, 'w') as f: 112 | json.dump(pred_sents, f, indent=2) 113 | 114 | def gen_out_caption_format(self, sent_pool, int2sent, format=0): 115 | ''' 116 | Args: 117 | sent_pool: list, [[topk_prob, topk_sent, word_probs], ...], sorted 118 | format: 119 | 0: [top1_sent] 120 
| 1: [top1_sent, prob, word_probs] 121 | 2: [[topk_sent, prob], ...] 122 | 3: [[topk_sent, prob, word_probs], ...] 123 | ''' 124 | if format == 0: 125 | return [int2sent(sent_pool[0][1])] 126 | elif format == 1: 127 | sent = int2sent(sent_pool[0][1]) 128 | return [sent, sent_pool[0][0].item(), [p.item() for p in sent_pool[0][2]]] 129 | elif format == 2: 130 | outs = [] 131 | for item in sent_pool: 132 | if len(item) == 3: 133 | sent_prob, sent_ids, word_probs = item 134 | outs.append([int2sent(sent_ids), sent_prob.item()]) 135 | return outs 136 | elif format == 3: 137 | outs = [] 138 | for item in sent_pool: 139 | if len(item) == 3: 140 | sent_prob, sent_ids, word_probs = item 141 | outs.append([int2sent(sent_ids), sent_prob.item(), [p.item() for p in word_probs]]) 142 | return outs 143 | 144 | def epoch_postprocess(self, epoch): 145 | super().epoch_postprocess(epoch) 146 | 147 | if DECODER in self.config.subcfgs: 148 | dec_cfg = self.config.subcfgs[DECODER] 149 | if dec_cfg.schedule_sampling and dec_cfg.ss_rate < dec_cfg.ss_max_rate: 150 | if (epoch+1) % dec_cfg.ss_increase_epoch == 0: 151 | dec_cfg.ss_rate = dec_cfg.ss_rate + dec_cfg.ss_increase_rate 152 | self.print_fn('schedule sampling rate %.4f'%(dec_cfg.ss_rate)) 153 | 154 | ################################ DIY ################################ 155 | def validate_batch(self, batch_data, addition_outs=None): 156 | ''' 157 | Returns: 158 | pred_sent: list of int_sent 159 | ''' 160 | raise NotImplementedError('implement validate_batch function') 161 | 162 | def test_batch(self, batch_data, greedy_or_beam): 163 | ''' 164 | Returns: 165 | pred_sent: list of int_sent 166 | sent_pool: 167 | ''' 168 | raise NotImplementedError 169 | 170 | def prepare_input_batch(self, batch_data, is_train=False): 171 | ''' 172 | Return: dict of tensors 173 | ''' 174 | raise NotImplementedError 175 | 176 | def forward_encoder(self, input_batch): 177 | ''' 178 | Return: dict of encoder outputs 179 | ''' 180 | raise NotImplementedError 181 | 182 | 183 | 184 | -------------------------------------------------------------------------------- /caption/models/vanilla.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import collections 5 | import time 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | import framework.configbase 11 | import framework.ops 12 | 13 | import caption.encoders.vanilla 14 | import caption.decoders.vanilla 15 | import caption.models.captionbase 16 | 17 | ENCODER = 'encoder' 18 | DECODER = 'decoder' 19 | 20 | class ModelConfig(framework.configbase.ModelConfig): 21 | def __init__(self): 22 | super().__init__() 23 | self.subcfgs[ENCODER] = caption.encoders.vanilla.EncoderConfig() 24 | self.subcfgs[DECODER] = caption.decoders.vanilla.DecoderConfig() 25 | 26 | def _assert(self): 27 | assert self.subcfgs[ENCODER].dim_embed == self.subcfgs[DECODER].hidden_size 28 | 29 | 30 | class VanillaModel(caption.models.captionbase.CaptionModelBase): 31 | def build_submods(self): 32 | submods = {} 33 | submods[ENCODER] = caption.encoders.vanilla.Encoder(self.config.subcfgs[ENCODER]) 34 | submods[DECODER] = caption.decoders.vanilla.Decoder(self.config.subcfgs[DECODER]) 35 | return submods 36 | 37 | def prepare_input_batch(self, batch_data, is_train=False): 38 | outs = { 39 | 'mp_fts': torch.FloatTensor(batch_data['mp_fts']).to(self.device), 40 | } 41 | if is_train: 42 | outs['caption_ids'] = torch.LongTensor(batch_data['caption_ids']).to(self.device) 43 | 
outs['caption_masks'] = torch.FloatTensor(batch_data['caption_masks'].astype(np.float32)).to(self.device) 44 | return outs 45 | 46 | def forward_encoder(self, input_batch): 47 | ft_embeds = self.submods[ENCODER](input_batch['mp_fts']) 48 | return {'init_states': ft_embeds} 49 | 50 | def forward_loss(self, batch_data, step=None): 51 | input_batch = self.prepare_input_batch(batch_data, is_train=True) 52 | enc_outs = self.forward_encoder(input_batch) 53 | # logits.shape=(batch*(seq_len-1), num_words) 54 | logits = self.submods[DECODER](input_batch['caption_ids'][:, :-1], enc_outs['init_states']) 55 | loss = self.criterion(logits, input_batch['caption_ids'], input_batch['caption_masks']) 56 | return loss 57 | 58 | def validate_batch(self, batch_data): 59 | input_batch = self.prepare_input_batch(batch_data, is_train=False) 60 | enc_outs = self.forward_encoder(input_batch) 61 | 62 | batch_size = len(batch_data['mp_fts']) 63 | init_words = torch.zeros(batch_size, dtype=torch.int64).to(self.device) 64 | pred_sent, _ = self.submods[DECODER].sample_decode( 65 | init_words, enc_outs['init_states'], greedy=True) 66 | return pred_sent 67 | 68 | def test_batch(self, batch_data, greedy_or_beam): 69 | input_batch = self.prepare_input_batch(batch_data, is_train=False) 70 | enc_outs = self.forward_encoder(input_batch) 71 | 72 | batch_size = len(batch_data['mp_fts']) 73 | init_words = torch.zeros(batch_size, dtype=torch.int64).to(self.device) 74 | if greedy_or_beam: 75 | sent_pool = self.submods[DECODER].beam_search_decode( 76 | init_words, enc_outs['init_states']) 77 | pred_sent = [pool[0][1] for pool in sent_pool] 78 | else: 79 | pred_sent, word_logprobs = self.submods[DECODER].sample_decode( 80 | init_words, enc_outs['init_states'], greedy=True) 81 | sent_pool = [] 82 | for sent, word_logprob in zip(pred_sent, word_logprobs): 83 | sent_pool.append([(word_logprob.sum().item(), sent, word_logprob)]) 84 | return pred_sent, sent_pool 85 | 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /caption/readers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshizhe/asg2cap/9d3a8a6312935cb1b033835c1edc522dcf9f6061/caption/readers/__init__.py -------------------------------------------------------------------------------- /caption/readers/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import codecs 5 | 6 | import torch.utils.data 7 | 8 | from caption.utils.inference import BOS, EOS, UNK 9 | 10 | class CaptionDatasetBase(torch.utils.data.Dataset): 11 | def __init__(self, word2int_file, ref_caption_file=None, 12 | max_words_in_sent=20, is_train=False, return_label=False, _logger=None): 13 | if _logger is None: 14 | self.print_fn = print 15 | else: 16 | self.print_fn = _logger.info 17 | 18 | if word2int_file.endswith('json'): 19 | self.word2int = json.load(open(word2int_file)) 20 | else: 21 | self.word2int = np.load(word2int_file) 22 | self.int2word = {i: w for w, i in self.word2int.items()} 23 | 24 | if ref_caption_file is not None: 25 | self.ref_captions = json.load(open(ref_caption_file)) 26 | 27 | self.max_words_in_sent = max_words_in_sent 28 | self.is_train = is_train 29 | self.return_label = return_label 30 | 31 | def sent2int(self, str_sent): 32 | int_sent = [self.word2int.get(w, UNK) for w in str_sent.split()] 33 | return int_sent 34 | 35 | def pad_sents(self, int_sent, 
add_bos_eos=True): 36 | if add_bos_eos: 37 | sent = [BOS] + int_sent + [EOS] 38 | else: 39 | sent = int_sent 40 | sent = sent[:self.max_words_in_sent] 41 | num_pad = self.max_words_in_sent - len(sent) 42 | mask = [True]*len(sent) + [False] * num_pad 43 | sent = sent + [EOS] * num_pad 44 | return sent, mask 45 | 46 | def pad_or_trim_feature(self, attn_ft, max_len, average=False): 47 | seq_len, dim_ft = attn_ft.shape 48 | mask = np.zeros((max_len, ), np.bool) 49 | 50 | # pad 51 | if seq_len < max_len: 52 | new_ft = np.zeros((max_len, dim_ft), np.float32) 53 | new_ft[:seq_len] = attn_ft 54 | mask[:seq_len] = True 55 | elif seq_len == max_len: 56 | new_ft = attn_ft 57 | mask[:] = True 58 | # trim 59 | else: 60 | if average: 61 | idxs = np.round(np.linspace(0, seq_len, max_len+1)).astype(np.int32) 62 | new_ft = np.array([np.mean(attn_ft[idxs[i]: idxs[i+1]], axis=0) for i in range(max_len)]) 63 | else: 64 | idxs = np.round(np.linspace(0, seq_len-1, max_len)).astype(np.int32) 65 | new_ft = attn_ft[idxs] 66 | mask[:] = True 67 | return new_ft, mask 68 | 69 | 70 | -------------------------------------------------------------------------------- /caption/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshizhe/asg2cap/9d3a8a6312935cb1b033835c1edc522dcf9f6061/caption/utils/__init__.py -------------------------------------------------------------------------------- /caption/utils/inference.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | BOS = 0 8 | EOS = 1 9 | UNK = 2 10 | 11 | class IntToSentence(object): 12 | def __init__(self, int2word_file): 13 | self.int2word = np.load(int2word_file) 14 | 15 | def __call__(self, int_sent): 16 | str_sent = [] 17 | for x in int_sent: 18 | if x == EOS: 19 | break 20 | str_sent.append(self.int2word[x]) 21 | return ' '.join(str_sent) 22 | 23 | 24 | def sample_decode(words, step_fn, max_words_in_sent, 25 | greedy=False, sample_topk=0, early_stop=True, **kwargs): 26 | ''' 27 | Args: 28 | words: init words, shape=(batch, ) 29 | step_fn: function return word logprobs 30 | max_words_in_sent: int, max decoded sentence length 31 | greedy: greedy or multinomial sampling 32 | sample_topk: each step sample from topk words instead of all words 33 | early_stop: stop if all examples are ended 34 | kwargs for RNN decoders: 35 | - states: init decoder states, shape=(num_layers, batch, hidden_size) 36 | - outs: the last hidden layer (query in attn), shape=(batch, hidden_size*bi) 37 | - memory_keys: (optional for attn) 38 | - memory_values: (optional for attn) 39 | - memory_masks: (optional for attn) 40 | kwargs for Transformer decoders: 41 | - 42 | 43 | Returns: 44 | seq_words: int sent, LongTensor, shape=(batch, dec_seq_len) 45 | seq_word_logprobs: logprobs of the selected word, shape=(batch, dec_seq_len) 46 | ''' 47 | seq_words, seq_word_logprobs = [], [] 48 | 49 | words = torch.unsqueeze(words, 1) 50 | unfinished = torch.ones_like(words).byte() 51 | 52 | for t in range(max_words_in_sent): 53 | logprobs, kwargs = step_fn(words, t, **kwargs) 54 | if greedy: 55 | _, words = torch.topk(logprobs, 1) 56 | else: 57 | probs = torch.exp(logprobs) 58 | if sample_topk > 0: 59 | topk_probs, topk_words = torch.topk(probs, sample_topk) 60 | idxs = torch.multinomial(topk_probs, 1) 61 | words = torch.gather(topk_words, 1, idxs) 62 | else: 63 | words = 
torch.multinomial(probs, 1) 64 | # words.shape=(batch, 1) 65 | seq_words.append(words) 66 | # logprobs.shape=(batch, num_words) 67 | logprobs = torch.gather(logprobs, 1, words) 68 | seq_word_logprobs.append(logprobs) 69 | unfinished = unfinished * (words != EOS) 70 | if early_stop and unfinished.sum().data.item() == 0: 71 | break 72 | seq_words = torch.cat(seq_words, 1).data 73 | seq_word_logprobs = torch.cat(seq_word_logprobs, 1) 74 | 75 | return seq_words, seq_word_logprobs 76 | 77 | 78 | def beam_search_decode(words, step_fn, max_words_in_sent, 79 | beam_width=5, sent_pool_size=5, expand_fn=None, select_fn=None, **kwargs): 80 | ''' 81 | Inputs are the same as sample_decode 82 | ''' 83 | k = beam_width 84 | batch_size = words.size(0) 85 | # store the best sentences 86 | sent_pool = [[] for i in range(batch_size)] 87 | # remained beams for each input 88 | batch_sent_pool_remain_cnt = np.zeros((batch_size, )) + sent_pool_size 89 | # store selected words in every step to recover path 90 | step_words = [] 91 | step_word_logprobs = [] 92 | # store previous indexs of selected words 93 | step_prevs = [] 94 | # sum of log probs of current sents for each beams 95 | cum_logprob = None # Tensor 96 | 97 | # row_idxs = [[0, ..., 0], [k, ..., k], ..., [(batch-1)*k, ..., (batch-1)*k] 98 | row_idxs = torch.arange(0, batch_size*k, k).unsqueeze(1).repeat(1, k) 99 | row_idxs = row_idxs.long().view(-1).to(words.device) 100 | 101 | for t in range(max_words_in_sent): 102 | words = words.unsqueeze(1) 103 | 104 | # logprobs.shape=(batch, num_words) 105 | logprobs, kwargs = step_fn(words, t, **kwargs) 106 | 107 | if t == 0: 108 | topk_logprobs, topk_words = torch.topk(logprobs, k) 109 | # update 110 | words = topk_words.view(-1) 111 | logprobs = topk_logprobs.view(-1) 112 | step_words.append(words) 113 | step_word_logprobs.append(logprobs) 114 | step_prevs.append([]) 115 | cum_logprob = logprobs 116 | if len(kwargs) > 0: 117 | kwargs = expand_fn(k, **kwargs) 118 | 119 | else: 120 | topk2_logprobs, topk2_words = torch.topk(logprobs, k) 121 | tmp_cum_logprob = topk2_logprobs + cum_logprob.unsqueeze(1) 122 | tmp_cum_logprob = tmp_cum_logprob.view(batch_size, k*k) 123 | topk2_words = topk2_words.view(batch_size, k*k) 124 | topk_cum_logprobs, topk_argwords = torch.topk(tmp_cum_logprob, k) 125 | topk_words = torch.gather(topk2_words, 1, topk_argwords) 126 | # update 127 | words = topk_words.view(-1) 128 | step_words.append(words) 129 | step_word_logprobs.append(torch.gather( 130 | topk2_logprobs.view(batch_size, k*k), 1, topk_argwords).view(-1)) 131 | cum_logprob = topk_cum_logprobs.view(-1) 132 | 133 | # select previous hidden 134 | # prev_idxs.size = (batch, k) 135 | prev_idxs = topk_argwords.div(k).long().view(-1) + row_idxs 136 | step_prevs.append(prev_idxs) 137 | kwargs = select_fn(prev_idxs, **kwargs) 138 | finished_idxs = (words == EOS) 139 | 140 | for i, finished in enumerate(finished_idxs): 141 | b = i // k 142 | if batch_sent_pool_remain_cnt[b] > 0: 143 | if finished or t == max_words_in_sent - 1: 144 | batch_sent_pool_remain_cnt[b] -= 1 145 | 146 | cmpl_sent, cmpl_word_logprobs = beam_search_recover_one_caption( 147 | step_words, step_prevs, step_word_logprobs, 148 | t, i, beam_width=beam_width) 149 | sent_logprob = cum_logprob[i]/(t+1) 150 | 151 | sent_pool[b].append((sent_logprob, cmpl_sent, cmpl_word_logprobs)) 152 | 153 | cum_logprob.masked_fill_(finished_idxs, -1000000) # stop select the beam 154 | if np.sum(batch_sent_pool_remain_cnt) <=0: 155 | break 156 | 157 | for i, sents in enumerate(sent_pool): 
158 | sents.sort(key=lambda x: -x[0]) 159 | return sent_pool 160 | 161 | def beam_search_recover_one_caption(step_words, step_prevs, 162 | step_word_logprobs, timestep, ith, beam_width=5): 163 | """ 164 | step_words: list, len=seq_len, item.shape=(batch*beam_width,) 165 | step_prevs: list, len=seq_len, item.shape=(batch*beam_width,) 166 | step_word_logprobs: list, len=seq_len, item.shape=(batch*beam_width,) 167 | timestep: the timestep item in step_* 168 | ith: the last idx of wordids 169 | """ 170 | caption, caption_logprob = [], [] 171 | for t in range(timestep, 0, -1): 172 | caption.append(step_words[t][ith]) 173 | caption_logprob.append(step_word_logprobs[t][ith]) 174 | ith = step_prevs[t][ith] 175 | 176 | caption.append(step_words[0][ith]) 177 | caption_logprob.append(step_word_logprobs[0][ith]) 178 | caption.reverse() 179 | caption_logprob.reverse() 180 | 181 | return caption, caption_logprob 182 | 183 | 184 | 185 | 186 | -------------------------------------------------------------------------------- /controlimcap/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshizhe/asg2cap/9d3a8a6312935cb1b033835c1edc522dcf9f6061/controlimcap/__init__.py -------------------------------------------------------------------------------- /controlimcap/decoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshizhe/asg2cap/9d3a8a6312935cb1b033835c1edc522dcf9f6061/controlimcap/decoders/__init__.py -------------------------------------------------------------------------------- /controlimcap/decoders/cfattention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import caption.utils.inference 6 | 7 | import caption.decoders.vanilla 8 | from framework.modules.embeddings import Embedding 9 | from framework.modules.global_attention import GlobalAttention 10 | 11 | 12 | class ContentFlowAttentionDecoder(caption.decoders.attention.BUTDAttnDecoder): 13 | def __init__(self, config): 14 | super().__init__(config) 15 | 16 | memory_size = self.config.attn_size if self.config.memory_same_key_value else self.config.attn_input_size 17 | self.address_layer = nn.Sequential( 18 | nn.Linear(self.config.hidden_size + memory_size, memory_size), 19 | nn.ReLU(), 20 | nn.Linear(memory_size, 1 + 3)) 21 | 22 | def forward(self, inputs, enc_globals, enc_memories, enc_masks, flow_edges, return_attn=False): 23 | ''' 24 | Args: 25 | inputs: (batch, dec_seq_len) 26 | enc_globals: (batch, hidden_size) 27 | enc_memories: (batch, enc_seq_len, attn_input_size) 28 | enc_masks: (batch, enc_seq_len) 29 | Returns: 30 | logits: (batch*seq_len, num_words) 31 | ''' 32 | batch_size, max_attn_len = enc_masks.size() 33 | device = inputs.device 34 | 35 | states = self.init_dec_state(batch_size) # zero init state 36 | 37 | # initialize content attention 38 | memory_keys, memory_values = self.gen_memory_key_value(enc_memories) 39 | 40 | # initialize location attention score: (batch, max_attn_len) 41 | prev_attn_score = torch.zeros((batch_size, max_attn_len)).to(device) 42 | prev_attn_score[:, 0] = 1 43 | 44 | step_outs, step_attns = [], [] 45 | for t in range(inputs.size(1)): 46 | wordids = inputs[:, t] 47 | if t > 0 and self.config.schedule_sampling: 48 | sample_rate = torch.rand(wordids.size(0)).to(wordids.device) 49 | sample_mask = sample_rate < self.config.ss_rate 50 
| prob = self.softmax(step_outs[-1]).detach() # detach grad 51 | sampled_wordids = torch.multinomial(prob, 1).view(-1) 52 | wordids.masked_scatter_(sample_mask, sampled_wordids) 53 | embed = self.embedding(wordids) 54 | 55 | h_attn_lstm, c_attn_lstm = self.attn_lstm( 56 | torch.cat([states[0][1], enc_globals, embed], dim=1), 57 | (states[0][0], states[1][0])) 58 | 59 | prev_memory = torch.sum(prev_attn_score.unsqueeze(2) * memory_values, 1) 60 | address_params = self.address_layer(torch.cat([h_attn_lstm, prev_memory], 1)) 61 | interpolate_gate = torch.sigmoid(address_params[:, :1]) 62 | flow_gate = torch.softmax(address_params[:, 1:], dim=1) 63 | 64 | # content_attn_score: (batch, max_attn_len) 65 | content_attn_score, content_attn_memory = self.attn(h_attn_lstm, 66 | memory_keys, memory_values, enc_masks) 67 | 68 | # location attention flow: (batch, max_attn_len) 69 | flow_attn_score_1 = torch.einsum('bts,bs->bt', flow_edges, prev_attn_score) 70 | flow_attn_score_2 = torch.einsum('bts,bs->bt',flow_edges, flow_attn_score_1) 71 | # (batch, max_attn_len, 3) 72 | flow_attn_score = torch.stack([x.view(batch_size, max_attn_len) \ 73 | for x in [prev_attn_score, flow_attn_score_1, flow_attn_score_2]], 2) 74 | flow_attn_score = torch.sum(flow_gate.unsqueeze(1) * flow_attn_score, 2) 75 | 76 | # content + location interpolation 77 | attn_score = interpolate_gate * content_attn_score + (1 - interpolate_gate) * flow_attn_score 78 | 79 | # final attention 80 | step_attns.append(attn_score) 81 | prev_attn_score = attn_score 82 | attn_memory = torch.sum(attn_score.unsqueeze(2) * memory_values, 1) 83 | 84 | # next layer with attended context 85 | h_lang_lstm, c_lang_lstm = self.lang_lstm( 86 | torch.cat([h_attn_lstm, attn_memory], dim=1), 87 | (states[0][1], states[1][1])) 88 | 89 | outs = h_lang_lstm 90 | logit = self.calc_logits_with_rnn_outs(outs) 91 | step_outs.append(logit) 92 | states = (torch.stack([h_attn_lstm, h_lang_lstm], dim=0), 93 | torch.stack([c_attn_lstm, c_lang_lstm], dim=0)) 94 | 95 | logits = torch.stack(step_outs, 1) 96 | logits = logits.view(-1, self.config.num_words) 97 | 98 | if return_attn: 99 | return logits, step_attns 100 | return logits 101 | 102 | def step_fn(self, words, step, **kwargs): 103 | states = kwargs['states'] 104 | enc_globals = kwargs['enc_globals'] 105 | memory_keys = kwargs['memory_keys'] 106 | memory_values = kwargs['memory_values'] 107 | memory_masks = kwargs['memory_masks'] 108 | prev_attn_score = kwargs['prev_attn_score'] 109 | flow_edges = kwargs['flow_edges'] 110 | 111 | batch_size, max_attn_len = memory_masks.size() 112 | 113 | embed = self.embedding(words.squeeze(1)) 114 | 115 | h_attn_lstm, c_attn_lstm = self.attn_lstm( 116 | torch.cat([states[0][1], enc_globals, embed], dim=1), 117 | (states[0][0], states[1][0])) 118 | 119 | prev_memory = torch.sum(prev_attn_score.unsqueeze(2) * memory_values, 1) 120 | address_params = self.address_layer(torch.cat([h_attn_lstm, prev_memory], 1)) 121 | interpolate_gate = torch.sigmoid(address_params[:, :1]) 122 | flow_gate = torch.softmax(address_params[:, 1:], dim=1) 123 | 124 | # content_attn_score: (batch, max_attn_len) 125 | content_attn_score, content_attn_memory = self.attn(h_attn_lstm, 126 | memory_keys, memory_values, memory_masks) 127 | 128 | # location attention flow: (batch, max_attn_len) 129 | flow_attn_score_1 = torch.einsum('bts,bs->bt', flow_edges, prev_attn_score) 130 | flow_attn_score_2 = torch.einsum('bts,bs->bt', flow_edges, flow_attn_score_1) 131 | flow_attn_score = torch.stack([x.view(batch_size, 
max_attn_len) \ 132 | for x in [prev_attn_score, flow_attn_score_1, flow_attn_score_2]], 2) 133 | flow_attn_score = torch.sum(flow_gate.unsqueeze(1) * flow_attn_score, 2) 134 | 135 | # content + location interpolation 136 | attn_score = interpolate_gate * content_attn_score + (1 - interpolate_gate) * flow_attn_score 137 | 138 | # final attention 139 | attn_memory = torch.sum(attn_score.unsqueeze(2) * memory_values, 1) 140 | 141 | h_lang_lstm, c_lang_lstm = self.lang_lstm( 142 | torch.cat([h_attn_lstm, attn_memory], dim=1), 143 | (states[0][1], states[1][1])) 144 | 145 | logits = self.calc_logits_with_rnn_outs(h_lang_lstm) 146 | logprobs = self.log_softmax(logits) 147 | states = (torch.stack([h_attn_lstm, h_lang_lstm], dim=0), 148 | torch.stack([c_attn_lstm, c_lang_lstm], dim=0)) 149 | 150 | kwargs['prev_attn_score'] = attn_score 151 | kwargs['states'] = states 152 | return logprobs, kwargs 153 | 154 | def sample_decode(self, words, enc_globals, enc_memories, enc_masks, flow_edges, greedy=True): 155 | '''Args: 156 | words: (batch, ) 157 | enc_globals: (batch, hidden_size) 158 | enc_memories: (batch, enc_seq_len, attn_input_size) 159 | enc_masks: (batch, enc_seq_len) 160 | flow_edges: sparse matrix, (batch*max_attn_len, batch*max_attn_len) 161 | ''' 162 | batch_size, max_attn_len = enc_masks.size() 163 | device = enc_masks.device 164 | 165 | states = self.init_dec_state(batch_size) 166 | memory_keys, memory_values = self.gen_memory_key_value(enc_memories) 167 | prev_attn_score = torch.zeros((batch_size, max_attn_len)).to(device) 168 | prev_attn_score[:, 0] = 1 169 | 170 | seq_words, seq_word_logprobs = caption.utils.inference.sample_decode( 171 | words, self.step_fn, self.config.max_words_in_sent, 172 | greedy=greedy, states=states, enc_globals=enc_globals, 173 | memory_keys=memory_keys, memory_values=memory_values, memory_masks=enc_masks, 174 | prev_attn_score=prev_attn_score, flow_edges=flow_edges) 175 | 176 | return seq_words, seq_word_logprobs 177 | 178 | def beam_search_decode(self, words, enc_globals, enc_memories, enc_masks, flow_edges): 179 | batch_size, max_attn_len = enc_masks.size() 180 | device = enc_masks.device 181 | 182 | states = self.init_dec_state(batch_size) 183 | memory_keys, memory_values = self.gen_memory_key_value(enc_memories) 184 | prev_attn_score = torch.zeros((batch_size, max_attn_len)).to(device) 185 | prev_attn_score[:, 0] = 1 186 | 187 | sent_pool = caption.utils.inference.beam_search_decode(words, self.step_fn, 188 | self.config.max_words_in_sent, beam_width=self.config.beam_width, 189 | sent_pool_size=self.config.sent_pool_size, 190 | expand_fn=self.expand_fn, select_fn=self.select_fn, 191 | memory_keys=memory_keys, memory_values=memory_values, memory_masks=enc_masks, 192 | states=states, enc_globals=enc_globals, 193 | prev_attn_score=prev_attn_score, flow_edges=flow_edges) 194 | 195 | return sent_pool 196 | -------------------------------------------------------------------------------- /controlimcap/decoders/memory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import caption.decoders.attention 6 | from framework.modules.embeddings import Embedding 7 | from framework.modules.global_attention import GlobalAttention 8 | 9 | from controlimcap.decoders.cfattention import ContentFlowAttentionDecoder 10 | 11 | 12 | class MemoryDecoder(caption.decoders.attention.BUTDAttnDecoder): 13 | def __init__(self, config): 14 | super().__init__(config) 
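# The two heads defined below implement this decoder's memory-write step (used in forward and
# step_fn): memory_update_layer predicts a per-node erase gate and add vector, while
# sentinal_layer gates how strongly the current word writes back into the node memories, so the
# update is roughly enc_memories <- enc_memories * (1 - attn * sentinal * erase) + attn * sentinal * add.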
15 | 16 | memory_size = self.config.attn_size if self.config.memory_same_key_value else self.config.attn_input_size 17 | self.memory_update_layer = nn.Sequential( 18 | nn.Linear(self.config.hidden_size + memory_size, memory_size), 19 | nn.ReLU(), 20 | nn.Linear(memory_size, memory_size * 2)) 21 | self.sentinal_layer = nn.Sequential( 22 | nn.Linear(self.config.hidden_size, self.config.hidden_size), 23 | nn.ReLU(), 24 | nn.Linear(self.config.hidden_size, 1)) 25 | 26 | def forward(self, inputs, enc_globals, enc_memories, enc_masks, return_attn=False): 27 | ''' 28 | Args: 29 | inputs: (batch, dec_seq_len) 30 | enc_globals: (batch, hidden_size) 31 | enc_memories: (batch, enc_seq_len, attn_input_size) 32 | enc_masks: (batch, enc_seq_len) 33 | Returns: 34 | logits: (batch*seq_len, num_words) 35 | ''' 36 | enc_seq_len = enc_memories.size(1) 37 | memory_keys, memory_values = self.gen_memory_key_value(enc_memories) 38 | states = self.init_dec_state(inputs.size(0)) # zero init state 39 | 40 | step_outs, step_attns = [], [] 41 | for t in range(inputs.size(1)): 42 | wordids = inputs[:, t] 43 | if t > 0 and self.config.schedule_sampling: 44 | sample_rate = torch.rand(wordids.size(0)).to(wordids.device) 45 | sample_mask = sample_rate < self.config.ss_rate 46 | prob = self.softmax(step_outs[-1]).detach() # detach grad 47 | sampled_wordids = torch.multinomial(prob, 1).view(-1) 48 | wordids.masked_scatter_(sample_mask, sampled_wordids) 49 | embed = self.embedding(wordids) 50 | 51 | h_attn_lstm, c_attn_lstm = self.attn_lstm( 52 | torch.cat([states[0][1], enc_globals, embed], dim=1), 53 | (states[0][0], states[1][0])) 54 | 55 | # attn_score: (batch, max_attn_len) 56 | attn_score, attn_memory = self.attn(h_attn_lstm, 57 | memory_keys, memory_values, enc_masks) 58 | step_attns.append(attn_score) 59 | 60 | h_lang_lstm, c_lang_lstm = self.lang_lstm( 61 | torch.cat([h_attn_lstm, attn_memory], dim=1), 62 | (states[0][1], states[1][1])) 63 | 64 | # write: update memory keys and values 65 | # (batch, enc_seq_len, hidden_size + attn_input_size) 66 | individual_vectors = torch.cat( 67 | [h_lang_lstm.unsqueeze(1).expand(-1, enc_seq_len, -1), enc_memories], 2) 68 | update_vectors = self.memory_update_layer(individual_vectors) 69 | memory_size = update_vectors.size(-1) // 2 70 | erase_gates = torch.sigmoid(update_vectors[:, :, :memory_size]) 71 | add_vectors = update_vectors[:, :, memory_size:] 72 | 73 | # some words do not need to attend on visual nodes 74 | sentinal_gates = torch.sigmoid(self.sentinal_layer(h_lang_lstm)) 75 | memory_attn_score = attn_score * sentinal_gates 76 | 77 | enc_memories = enc_memories * (1 - memory_attn_score.unsqueeze(2) * erase_gates) \ 78 | + memory_attn_score.unsqueeze(2) * add_vectors 79 | memory_keys, memory_values = self.gen_memory_key_value(enc_memories) 80 | 81 | outs = h_lang_lstm 82 | logit = self.calc_logits_with_rnn_outs(outs) 83 | step_outs.append(logit) 84 | states = (torch.stack([h_attn_lstm, h_lang_lstm], dim=0), 85 | torch.stack([c_attn_lstm, c_lang_lstm], dim=0)) 86 | 87 | logits = torch.stack(step_outs, 1) 88 | logits = logits.view(-1, self.config.num_words) 89 | 90 | if return_attn: 91 | return logits, step_attns 92 | return logits 93 | 94 | def step_fn(self, words, step, **kwargs): 95 | states = kwargs['states'] 96 | enc_globals = kwargs['enc_globals'] 97 | enc_memories = kwargs['enc_memories'] 98 | memory_masks = kwargs['memory_masks'] 99 | enc_seq_len = enc_memories.size(1) 100 | 101 | embed = self.embedding(words.squeeze(1)) 102 | 103 | h_attn_lstm, c_attn_lstm = 
self.attn_lstm( 104 | torch.cat([states[0][1], enc_globals, embed], dim=1), 105 | (states[0][0], states[1][0])) 106 | 107 | memory_keys, memory_values = self.gen_memory_key_value(enc_memories) 108 | attn_score, attn_memory = self.attn(h_attn_lstm, 109 | memory_keys, memory_values, memory_masks) 110 | 111 | h_lang_lstm, c_lang_lstm = self.lang_lstm( 112 | torch.cat([h_attn_lstm, attn_memory], dim=1), 113 | (states[0][1], states[1][1])) 114 | 115 | logits = self.calc_logits_with_rnn_outs(h_lang_lstm) 116 | logprobs = self.log_softmax(logits) 117 | states = (torch.stack([h_attn_lstm, h_lang_lstm], dim=0), 118 | torch.stack([c_attn_lstm, c_lang_lstm], dim=0)) 119 | 120 | # write: update memory keys and values 121 | individual_vectors = torch.cat([h_lang_lstm.unsqueeze(1).expand(-1, enc_seq_len, -1), enc_memories], 2) 122 | update_vectors = self.memory_update_layer(individual_vectors) 123 | memory_size = update_vectors.size(-1) // 2 124 | erase_gates = torch.sigmoid(update_vectors[:, :, :memory_size]) 125 | add_vectors = update_vectors[:, :, memory_size:] 126 | 127 | sentinal_gates = torch.sigmoid(self.sentinal_layer(h_lang_lstm)) 128 | memory_attn_score = attn_score * sentinal_gates 129 | enc_memories = enc_memories * (1 - memory_attn_score.unsqueeze(2) * erase_gates) \ 130 | + memory_attn_score.unsqueeze(2) * add_vectors 131 | kwargs['enc_memories'] = enc_memories 132 | kwargs['states'] = states 133 | return logprobs, kwargs 134 | 135 | def sample_decode(self, words, enc_globals, enc_memories, enc_masks, greedy=True, early_stop=True): 136 | ''' 137 | Args 138 | words: (batch, ) 139 | enc_states: (batch, hidden_size) 140 | enc_memories: (batch, enc_seq_len, attn_input_size) 141 | enc_masks: (batch, enc_seq_len) 142 | ''' 143 | states = self.init_dec_state(words.size(0)) 144 | 145 | seq_words, seq_word_logprobs = caption.utils.inference.sample_decode( 146 | words, self.step_fn, self.config.max_words_in_sent, 147 | greedy=greedy, early_stop=early_stop, states=states, 148 | enc_globals=enc_globals, enc_memories=enc_memories, memory_masks=enc_masks) 149 | 150 | return seq_words, seq_word_logprobs 151 | 152 | def beam_search_decode(self, words, enc_globals, enc_memories, enc_masks): 153 | states = self.init_dec_state(words.size(0)) 154 | 155 | sent_pool = caption.utils.inference.beam_search_decode(words, self.step_fn, 156 | self.config.max_words_in_sent, beam_width=self.config.beam_width, 157 | sent_pool_size=self.config.sent_pool_size, 158 | expand_fn=self.expand_fn, select_fn=self.select_fn, 159 | states=states, enc_globals=enc_globals, 160 | enc_memories=enc_memories, memory_masks=enc_masks) 161 | return sent_pool 162 | 163 | 164 | class MemoryFlowDecoder(ContentFlowAttentionDecoder): 165 | def __init__(self, config): 166 | super().__init__(config) 167 | 168 | memory_size = self.config.attn_size if self.config.memory_same_key_value else self.config.attn_input_size 169 | self.memory_update_layer = nn.Sequential( 170 | nn.Linear(self.config.hidden_size + memory_size, memory_size), 171 | nn.ReLU(), 172 | nn.Linear(memory_size, memory_size * 2)) 173 | self.sentinal_layer = nn.Sequential( 174 | nn.Linear(self.config.hidden_size, self.config.hidden_size), 175 | nn.ReLU(), 176 | nn.Linear(self.config.hidden_size, 1)) 177 | 178 | def forward(self, inputs, enc_globals, enc_memories, enc_masks, flow_edges, return_attn=False): 179 | ''' 180 | Args: 181 | inputs: (batch, dec_seq_len) 182 | enc_globals: (batch, hidden_size) 183 | enc_memories: (batch, enc_seq_len, attn_input_size) 184 | enc_masks: (batch, 
enc_seq_len) 185 | flow_edges: sparse matrix (num_nodes, num_nodes), num_nodes=batch*enc_seq_len 186 | Returns: 187 | logits: (batch*seq_len, num_words) 188 | ''' 189 | batch_size, max_attn_len = enc_masks.size() 190 | device = inputs.device 191 | 192 | # initialize states 193 | states = self.init_dec_state(batch_size) # zero init state 194 | 195 | # location attention: (batch, max_attn_len) 196 | prev_attn_score = torch.zeros((batch_size, max_attn_len)).to(device) 197 | prev_attn_score[:, 0] = 1 198 | 199 | step_outs, step_attns = [], [] 200 | for t in range(inputs.size(1)): 201 | wordids = inputs[:, t] 202 | if t > 0 and self.config.schedule_sampling: 203 | sample_rate = torch.rand(wordids.size(0)).to(wordids.device) 204 | sample_mask = sample_rate < self.config.ss_rate 205 | prob = self.softmax(step_outs[-1]).detach() # detach grad 206 | sampled_wordids = torch.multinomial(prob, 1).view(-1) 207 | wordids.masked_scatter_(sample_mask, sampled_wordids) 208 | embed = self.embedding(wordids) 209 | 210 | h_attn_lstm, c_attn_lstm = self.attn_lstm( 211 | torch.cat([states[0][1], enc_globals, embed], dim=1), 212 | (states[0][0], states[1][0])) 213 | 214 | memory_keys, memory_values = self.gen_memory_key_value(enc_memories) 215 | 216 | prev_attn_memory = torch.sum(prev_attn_score.unsqueeze(2) * memory_values, 1) 217 | address_params = self.address_layer(torch.cat([h_attn_lstm, prev_attn_memory], 1)) 218 | interpolate_gate = torch.sigmoid(address_params[:, :1]) 219 | flow_gate = torch.softmax(address_params[:, 1:], dim=1) 220 | 221 | # content_attn_score: (batch, max_attn_len) 222 | content_attn_score, content_attn_memory = self.attn(h_attn_lstm, 223 | memory_keys, memory_values, enc_masks) 224 | 225 | # location attention flow: (batch, max_attn_len) 226 | flow_attn_score_1 = torch.einsum('bts,bs->bt', flow_edges, prev_attn_score) 227 | flow_attn_score_2 = torch.einsum('bts,bs->bt', flow_edges, flow_attn_score_1) 228 | # (batch, max_attn_len, 3) 229 | flow_attn_score = torch.stack([x.view(batch_size, max_attn_len) \ 230 | for x in [prev_attn_score, flow_attn_score_1, flow_attn_score_2]], 2) 231 | flow_attn_score = torch.sum(flow_gate.unsqueeze(1) * flow_attn_score, 2) 232 | 233 | # content + location interpolation 234 | attn_score = interpolate_gate * content_attn_score + (1 - interpolate_gate) * flow_attn_score 235 | 236 | # final attention 237 | step_attns.append(attn_score) 238 | prev_attn_score = attn_score 239 | attn_memory = torch.sum(attn_score.unsqueeze(2) * memory_values, 1) 240 | 241 | # next layer with attended context 242 | h_lang_lstm, c_lang_lstm = self.lang_lstm( 243 | torch.cat([h_attn_lstm, attn_memory], dim=1), 244 | (states[0][1], states[1][1])) 245 | 246 | # write: update memory keys and values 247 | individual_vectors = torch.cat([h_lang_lstm.unsqueeze(1).expand(-1, max_attn_len, -1), enc_memories], 2) 248 | update_vectors = self.memory_update_layer(individual_vectors) 249 | memory_size = update_vectors.size(-1) // 2 250 | erase_gates = torch.sigmoid(update_vectors[:, :, :memory_size]) 251 | add_vectors = update_vectors[:, :, memory_size:] 252 | 253 | # some words do not need to attend on visual nodes 254 | sentinal_gates = torch.sigmoid(self.sentinal_layer(h_lang_lstm)) 255 | memory_attn_score = attn_score * sentinal_gates 256 | 257 | enc_memories = enc_memories * (1 - memory_attn_score.unsqueeze(2) * erase_gates) \ 258 | + memory_attn_score.unsqueeze(2) * add_vectors 259 | 260 | outs = h_lang_lstm 261 | logit = self.calc_logits_with_rnn_outs(outs) 262 | 
step_outs.append(logit) 263 | states = (torch.stack([h_attn_lstm, h_lang_lstm], dim=0), 264 | torch.stack([c_attn_lstm, c_lang_lstm], dim=0)) 265 | 266 | logits = torch.stack(step_outs, 1) 267 | logits = logits.view(-1, self.config.num_words) 268 | 269 | if return_attn: 270 | return logits, step_attns 271 | return logits 272 | 273 | def step_fn(self, words, step, **kwargs): 274 | states = kwargs['states'] 275 | enc_globals = kwargs['enc_globals'] 276 | enc_memories = kwargs['enc_memories'] 277 | memory_masks = kwargs['memory_masks'] 278 | prev_attn_score = kwargs['prev_attn_score'] 279 | flow_edges = kwargs['flow_edges'] 280 | 281 | batch_size, max_attn_len = memory_masks.size() 282 | memory_keys, memory_values = self.gen_memory_key_value(enc_memories) 283 | embed = self.embedding(words.squeeze(1)) 284 | 285 | h_attn_lstm, c_attn_lstm = self.attn_lstm( 286 | torch.cat([states[0][1], enc_globals, embed], dim=1), 287 | (states[0][0], states[1][0])) 288 | 289 | prev_attn_memory = torch.sum(prev_attn_score.unsqueeze(2) * memory_values, 1) 290 | address_params = self.address_layer(torch.cat([h_attn_lstm, prev_attn_memory], 1)) 291 | interpolate_gate = torch.sigmoid(address_params[:, :1]) 292 | flow_gate = torch.softmax(address_params[:, 1:], dim=1) 293 | 294 | # content_attn_score: (batch, max_attn_len) 295 | content_attn_score, content_attn_memory = self.attn(h_attn_lstm, 296 | memory_keys, memory_values, memory_masks) 297 | 298 | # location attention flow: (batch, max_attn_len) 299 | flow_attn_score_1 = torch.einsum('bts,bs->bt', flow_edges, prev_attn_score) 300 | flow_attn_score_2 = torch.einsum('bts,bs->bt', flow_edges, flow_attn_score_1) 301 | flow_attn_score = torch.stack([x.view(batch_size, max_attn_len) \ 302 | for x in [prev_attn_score, flow_attn_score_1, flow_attn_score_2]], 2) 303 | flow_attn_score = torch.sum(flow_gate.unsqueeze(1) * flow_attn_score, 2) 304 | 305 | # content + location interpolation 306 | attn_score = interpolate_gate * content_attn_score + (1 - interpolate_gate) * flow_attn_score 307 | 308 | # final attention 309 | attn_memory = torch.sum(attn_score.unsqueeze(2) * memory_values, 1) 310 | 311 | h_lang_lstm, c_lang_lstm = self.lang_lstm( 312 | torch.cat([h_attn_lstm, attn_memory], dim=1), 313 | (states[0][1], states[1][1])) 314 | 315 | logits = self.calc_logits_with_rnn_outs(h_lang_lstm) 316 | logprobs = self.log_softmax(logits) 317 | states = (torch.stack([h_attn_lstm, h_lang_lstm], dim=0), 318 | torch.stack([c_attn_lstm, c_lang_lstm], dim=0)) 319 | 320 | # write: update memory keys and values 321 | individual_vectors = torch.cat([h_lang_lstm.unsqueeze(1).expand(-1, max_attn_len, -1), enc_memories], 2) 322 | update_vectors = self.memory_update_layer(individual_vectors) 323 | memory_size = update_vectors.size(-1) // 2 324 | erase_gates = torch.sigmoid(update_vectors[:, :, :memory_size]) 325 | add_vectors = update_vectors[:, :, memory_size:] 326 | 327 | sentinal_gates = torch.sigmoid(self.sentinal_layer(h_lang_lstm)) 328 | memory_attn_score = attn_score * sentinal_gates 329 | enc_memories = enc_memories * (1 - memory_attn_score.unsqueeze(2) * erase_gates) \ 330 | + memory_attn_score.unsqueeze(2) * add_vectors 331 | 332 | kwargs['states'] = states 333 | kwargs['enc_memories'] = enc_memories 334 | kwargs['prev_attn_score'] = attn_score 335 | return logprobs, kwargs 336 | 337 | def sample_decode(self, words, enc_globals, enc_memories, enc_masks, flow_edges, greedy=True): 338 | batch_size, max_attn_len = enc_masks.size() 339 | device = enc_masks.device 340 | 341 | states 
= self.init_dec_state(batch_size) 342 | prev_attn_score = torch.zeros((batch_size, max_attn_len)).to(device) 343 | prev_attn_score[:, 0] = 1 344 | 345 | seq_words, seq_word_logprobs = caption.utils.inference.sample_decode( 346 | words, self.step_fn, self.config.max_words_in_sent, 347 | greedy=greedy, states=states, enc_globals=enc_globals, 348 | enc_memories=enc_memories, memory_masks=enc_masks, 349 | prev_attn_score=prev_attn_score, flow_edges=flow_edges) 350 | 351 | return seq_words, seq_word_logprobs 352 | 353 | def beam_search_decode(self, words, enc_globals, enc_memories, enc_masks, flow_edges): 354 | batch_size, max_attn_len = enc_masks.size() 355 | device = enc_masks.device 356 | 357 | states = self.init_dec_state(batch_size) 358 | prev_attn_score = torch.zeros((batch_size, max_attn_len)).to(device) 359 | prev_attn_score[:, 0] = 1 360 | 361 | sent_pool = caption.utils.inference.beam_search_decode(words, self.step_fn, 362 | self.config.max_words_in_sent, beam_width=self.config.beam_width, 363 | sent_pool_size=self.config.sent_pool_size, 364 | expand_fn=self.expand_fn, select_fn=self.select_fn, 365 | enc_memories=enc_memories, memory_masks=enc_masks, 366 | states=states, enc_globals=enc_globals, 367 | prev_attn_score=prev_attn_score, flow_edges=flow_edges) 368 | 369 | return sent_pool 370 | -------------------------------------------------------------------------------- /controlimcap/driver/asg2caption.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import time 5 | import numpy as np 6 | 7 | import torch.utils.data.dataloader as dataloader 8 | import framework.logbase 9 | import framework.run_utils 10 | 11 | import caption.models.attention 12 | 13 | import controlimcap.readers.imgsgreader as imgsgreader 14 | import controlimcap.models.graphattn 15 | import controlimcap.models.graphflow 16 | import controlimcap.models.graphmemory 17 | import controlimcap.models.flatattn 18 | 19 | from controlimcap.models.graphattn import ATTNENCODER 20 | 21 | from controlimcap.driver.common import build_parser, evaluate_caption 22 | 23 | def main(): 24 | parser = build_parser() 25 | parser.add_argument('--max_attn_len', type=int, default=10) 26 | parser.add_argument('--num_workers', type=int, default=0) 27 | opts = parser.parse_args() 28 | 29 | if opts.mtype == 'node': 30 | model_cfg = caption.models.attention.AttnModelConfig() 31 | elif opts.mtype == 'node.role': 32 | model_cfg = controlimcap.models.flatattn.NodeRoleBUTDAttnModelConfig() 33 | elif opts.mtype in ['rgcn', 'rgcn.flow', 'rgcn.memory', 'rgcn.flow.memory']: 34 | model_cfg = controlimcap.models.graphattn.GraphModelConfig() 35 | model_cfg.load(opts.model_cfg_file) 36 | max_words_in_sent = model_cfg.subcfgs['decoder'].max_words_in_sent 37 | 38 | path_cfg = framework.run_utils.gen_common_pathcfg(opts.path_cfg_file, is_train=opts.is_train) 39 | 40 | if path_cfg.log_file is not None: 41 | _logger = framework.logbase.set_logger(path_cfg.log_file, 'trn_%f'%time.time()) 42 | else: 43 | _logger = None 44 | 45 | if opts.mtype == 'node': 46 | model_fn = controlimcap.models.flatattn.NodeBUTDAttnModel 47 | elif opts.mtype == 'node.role': 48 | model_fn = controlimcap.models.flatattn.NodeRoleBUTDAttnModel 49 | elif opts.mtype == 'rgcn': 50 | model_fn = controlimcap.models.graphattn.RoleGraphBUTDAttnModel 51 | model_cfg.subcfgs[ATTNENCODER].max_attn_len = opts.max_attn_len 52 | elif opts.mtype == 'rgcn.flow': 53 | model_fn = 
controlimcap.models.graphflow.RoleGraphBUTDCFlowAttnModel 54 | model_cfg.subcfgs[ATTNENCODER].max_attn_len = opts.max_attn_len 55 | elif opts.mtype == 'rgcn.memory': 56 | model_fn = controlimcap.models.graphmemory.RoleGraphBUTDMemoryModel 57 | model_cfg.subcfgs[ATTNENCODER].max_attn_len = opts.max_attn_len 58 | elif opts.mtype == 'rgcn.flow.memory': 59 | model_fn = controlimcap.models.graphmemory.RoleGraphBUTDMemoryFlowModel 60 | model_cfg.subcfgs[ATTNENCODER].max_attn_len = opts.max_attn_len 61 | 62 | _model = model_fn(model_cfg, _logger=_logger, 63 | int2word_file=path_cfg.int2word_file, eval_loss=opts.eval_loss) 64 | 65 | if opts.mtype in ['node', 'node.role']: 66 | reader_fn = imgsgreader.ImageSceneGraphFlatReader 67 | collate_fn = imgsgreader.flat_collate_fn 68 | elif opts.mtype in ['rgcn', 'rgcn.memory']: 69 | reader_fn = imgsgreader.ImageSceneGraphReader 70 | collate_fn = imgsgreader.sg_sparse_collate_fn 71 | elif opts.mtype in ['rgcn.flow', 'rgcn.flow.memory']: 72 | reader_fn = imgsgreader.ImageSceneGraphFlowReader 73 | collate_fn = imgsgreader.sg_sparse_flow_collate_fn 74 | 75 | if opts.is_train: 76 | model_cfg.save(os.path.join(path_cfg.log_dir, 'model.cfg')) 77 | path_cfg.save(os.path.join(path_cfg.log_dir, 'path.cfg')) 78 | json.dump(vars(opts), open(os.path.join(path_cfg.log_dir, 'opts.cfg'), 'w'), indent=2) 79 | 80 | trn_dataset = reader_fn(path_cfg.name_file['trn'], path_cfg.mp_ft_file['trn'], 81 | path_cfg.obj_ft_dir['trn'], path_cfg.region_anno_dir['trn'], path_cfg.word2int_file, 82 | max_attn_len=opts.max_attn_len, max_words_in_sent=max_words_in_sent, 83 | is_train=True, return_label=True, _logger=_logger) 84 | trn_reader = dataloader.DataLoader(trn_dataset, batch_size=model_cfg.trn_batch_size, 85 | shuffle=True, collate_fn=collate_fn, num_workers=opts.num_workers) 86 | val_dataset = reader_fn(path_cfg.name_file['val'], path_cfg.mp_ft_file['val'], 87 | path_cfg.obj_ft_dir['val'], path_cfg.region_anno_dir['trn'], path_cfg.word2int_file, 88 | max_attn_len=opts.max_attn_len, max_words_in_sent=max_words_in_sent, 89 | is_train=False, return_label=True, _logger=_logger) 90 | val_reader = dataloader.DataLoader(val_dataset, batch_size=model_cfg.tst_batch_size, 91 | shuffle=True, collate_fn=collate_fn, num_workers=opts.num_workers) 92 | 93 | _model.train(trn_reader, val_reader, path_cfg.model_dir, path_cfg.log_dir, 94 | resume_file=opts.resume_file) 95 | 96 | else: 97 | tst_dataset = reader_fn(path_cfg.name_file[opts.eval_set], path_cfg.mp_ft_file[opts.eval_set], 98 | path_cfg.obj_ft_dir[opts.eval_set], path_cfg.region_anno_dir[opts.eval_set], 99 | path_cfg.word2int_file, max_attn_len=opts.max_attn_len, max_words_in_sent=max_words_in_sent, 100 | is_train=False, return_label=False, _logger=None) 101 | tst_reader = dataloader.DataLoader(tst_dataset, batch_size=model_cfg.tst_batch_size, 102 | shuffle=False, collate_fn=collate_fn, num_workers=opts.num_workers) 103 | 104 | tmp_ref_captions = json.load(open(os.path.join(os.path.dirname(path_cfg.region_anno_dir[opts.eval_set]), 'cleaned_%s_region_descriptions.json'%opts.eval_set))) 105 | ref_captions = {} 106 | for k, v in tmp_ref_captions.items(): 107 | for rk, rv in v.items(): 108 | ref_captions['%s_%s'%(k, rk)] = [rv['phrase']] 109 | 110 | model_str_scores = [] 111 | if opts.resume_file is None: 112 | model_files = framework.run_utils.find_best_val_models(path_cfg.log_dir, path_cfg.model_dir) 113 | else: 114 | model_files = {'predefined': opts.resume_file} 115 | 116 | for measure_name, model_file in model_files.items(): 117 | 
set_pred_dir = os.path.join(path_cfg.pred_dir, opts.eval_set) 118 | if not os.path.exists(set_pred_dir): 119 | os.makedirs(set_pred_dir) 120 | tst_pred_file = os.path.join(set_pred_dir, 121 | os.path.splitext(os.path.basename(model_file))[0]+'.json') 122 | 123 | if not os.path.exists(tst_pred_file): 124 | _model.test(tst_reader, tst_pred_file, tst_model_file=model_file, 125 | outcap_format=opts.outcap_format) 126 | if not opts.no_evaluate: 127 | scores = evaluate_caption( 128 | None, tst_pred_file, ref_caps=ref_captions, 129 | outcap_format=opts.outcap_format) 130 | str_scores = [measure_name, os.path.basename(model_file)] 131 | for score_name in ['num_words', 'bleu4', 'meteor', 'rouge', 'cider', 'spice', 'avg_lens']: 132 | str_scores.append('%.2f'%(scores[score_name])) 133 | str_scores = ','.join(str_scores) 134 | print(str_scores) 135 | model_str_scores.append(str_scores) 136 | 137 | if not opts.no_evaluate: 138 | score_log_file = os.path.join(path_cfg.pred_dir, opts.eval_set, 'scores.csv') 139 | with open(score_log_file, 'a') as f: 140 | for str_scores in model_str_scores: 141 | print(str_scores, file=f) 142 | 143 | 144 | if __name__ == '__main__': 145 | main() 146 | -------------------------------------------------------------------------------- /controlimcap/driver/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | 6 | from eval_cap.bleu.bleu import Bleu 7 | from eval_cap.meteor.meteor import Meteor 8 | from eval_cap.cider.cider import Cider 9 | from eval_cap.rouge.rouge import Rouge 10 | from eval_cap.spice.spice import Spice 11 | 12 | 13 | def build_parser(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('model_cfg_file') 16 | parser.add_argument('path_cfg_file') 17 | parser.add_argument('mtype') 18 | parser.add_argument('--resume_file', default=None) 19 | parser.add_argument('--selfcritic', action='store_true', default=False) 20 | parser.add_argument('--eval_loss', action='store_true', default=False) 21 | parser.add_argument('--is_train', action='store_true', default=False) 22 | 23 | parser.add_argument('--eval_set', default='tst') 24 | parser.add_argument('--no_evaluate', action='store_true', default=False) 25 | parser.add_argument('--outcap_format', type=int, default=0) 26 | 27 | return parser 28 | 29 | 30 | def evaluate_caption(ref_cap_file, pred_cap_file, ref_caps=None, 31 | preds=None, scorer_names=None, outcap_format=0): 32 | if ref_caps is None: 33 | ref_caps = json.load(open(ref_cap_file)) 34 | if preds is None: 35 | preds = json.load(open(pred_cap_file)) 36 | 37 | if outcap_format == 1: 38 | outs = {} 39 | for key, value in preds.items(): 40 | outs[key] = [value[0]] 41 | preds = outs 42 | elif outcap_format in [2, 3, 4]: 43 | outs = {} 44 | for key, value in preds.items(): 45 | outs[key] = [value[0][0]] 46 | preds = outs 47 | 48 | refs = {} 49 | for key in preds.keys(): 50 | refs[key] = ref_caps[key] 51 | 52 | scorers = { 53 | 'bleu4': Bleu(4), 54 | 'meteor': Meteor(), 55 | 'rouge': Rouge(), 56 | 'cider': Cider(), 57 | 'spice': Spice(), 58 | } 59 | if scorer_names is None: 60 | scorer_names = list(scorers.keys()) 61 | 62 | scores = {} 63 | for measure_name in scorer_names: 64 | scorer = scorers[measure_name] 65 | s, _ = scorer.compute_score(refs, preds) 66 | if measure_name == 'bleu4': 67 | scores[measure_name] = s[-1] * 100 68 | else: 69 | scores[measure_name] = s * 100 70 | 71 | scorers['meteor'].meteor_p.kill() 72 | unique_words = 
set() 73 | sent_lens = [] 74 | for key, value in preds.items(): 75 | for sent in value: 76 | unique_words.update(sent.split()) 77 | sent_lens.append(len(sent.split())) 78 | scores['num_words'] = len(unique_words) 79 | scores['avg_lens'] = np.mean(sent_lens) 80 | return scores 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /controlimcap/driver/configs/prepare_coco_imgsg_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import argparse 5 | import json 6 | import numpy as np 7 | 8 | import caption.encoders.vanilla 9 | import caption.decoders.vanilla 10 | import caption.models.vanilla 11 | import caption.models.attention 12 | 13 | import controlimcap.encoders.gcn 14 | import controlimcap.models.flatattn 15 | import controlimcap.models.graphattn 16 | import controlimcap.models.graphmemory 17 | 18 | ENCODER = 'encoder' 19 | DECODER = 'decoder' 20 | MPENCODER = 'mp_encoder' 21 | ATTNENCODER = 'attn_encoder' 22 | 23 | ROOT_DIR = '/data1/csz/MSCOCO' 24 | 25 | def gen_vanilla_encoder_cfg(enc_cfg, dim_fts, dim_embed): 26 | enc_cfg.dim_fts = dim_fts 27 | enc_cfg.dim_embed = dim_embed 28 | enc_cfg.is_embed = True 29 | enc_cfg.dropout = 0 30 | enc_cfg.norm = False 31 | enc_cfg.nonlinear = False 32 | return enc_cfg 33 | 34 | def gen_gcn_encoder_cfg(enc_cfg, dim_input, dim_hidden): 35 | enc_cfg.dim_input = dim_input 36 | enc_cfg.dim_hidden = dim_hidden 37 | enc_cfg.num_rels = 6 38 | enc_cfg.num_bases = -1 39 | enc_cfg.num_hidden_layers = 2 40 | enc_cfg.max_attn_len = 10 41 | enc_cfg.self_loop = True 42 | enc_cfg.num_node_types = 3 43 | enc_cfg.embed_first = True 44 | return enc_cfg 45 | 46 | def gen_vanilla_decoder_cfg(dec_cfg, num_words, hidden_size): 47 | dec_cfg.rnn_type = 'lstm' 48 | dec_cfg.num_words = num_words 49 | dec_cfg.dim_word = 512 50 | dec_cfg.hidden_size = hidden_size 51 | dec_cfg.num_layers = 1 52 | dec_cfg.hidden2word = False 53 | dec_cfg.tie_embed = True 54 | dec_cfg.fix_word_embed = False 55 | dec_cfg.max_words_in_sent = 25 56 | dec_cfg.dropout = 0.5 57 | dec_cfg.schedule_sampling = False 58 | dec_cfg.ss_rate = 0. 
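# Scheduled-sampling schedule (only consumed when schedule_sampling is True, which it is not here):
# ss_rate is the initial probability of feeding back a sampled word instead of the ground-truth word,
# and the trainer is expected to raise it by ss_increase_rate every ss_increase_epoch epochs, capped at ss_max_rate.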
59 | dec_cfg.ss_increase_rate = 0.05 60 | dec_cfg.ss_max_rate = 0.25 61 | dec_cfg.ss_increase_epoch = 5 62 | 63 | dec_cfg.greedy_or_beam = False # test method 64 | dec_cfg.beam_width = 1 65 | dec_cfg.sent_pool_size = 1 66 | return dec_cfg 67 | 68 | def gen_attn_decoder_cfg(dec_cfg, num_words, hidden_size): 69 | dec_cfg = gen_vanilla_decoder_cfg(dec_cfg, num_words, hidden_size) 70 | dec_cfg.memory_same_key_value = True 71 | dec_cfg.attn_input_size = 512 72 | dec_cfg.attn_size = 512 73 | dec_cfg.attn_type = 'mlp' 74 | return dec_cfg 75 | 76 | def gen_common_model_cfg(model_cfg): 77 | model_cfg.trn_batch_size = 128 78 | model_cfg.tst_batch_size = 128 79 | model_cfg.num_epoch = 100 80 | model_cfg.base_lr = 2e-4 81 | model_cfg.monitor_iter = 1000 82 | model_cfg.summary_iter = 1000 83 | model_cfg.save_iter = -1 84 | model_cfg.val_iter = -1 85 | model_cfg.val_per_epoch = True 86 | model_cfg.save_per_epoch = True 87 | return model_cfg 88 | 89 | 90 | def prepare_attention(mtype): 91 | anno_dir = os.path.join(ROOT_DIR, 'annotation', 'controllable') 92 | mp_ft_dir = os.path.join(ROOT_DIR, 'ordered_feature', 'MP') 93 | attn_ft_dir = os.path.join(ROOT_DIR, 'ordered_feature', 'SA') 94 | split_dir = os.path.join(anno_dir, 'public_split') 95 | res_dir = os.path.join(ROOT_DIR, 'results', 'ControlCAP') 96 | 97 | mp_ft_name = 'resnet101.ctrl' 98 | attn_ft_name = 'X_101_32x8d' 99 | 100 | hidden_size = 512 101 | dim_attn_ft = 2048 102 | dim_mp_ft = np.load(os.path.join(mp_ft_dir, mp_ft_name, 'val_ft.npy')).shape[-1] 103 | num_words = len(np.load(os.path.join(anno_dir, 'int2word.npy'))) 104 | 105 | if mtype == 'node': 106 | model_cfg = caption.models.attention.AttnModelConfig() 107 | elif mtype == 'node.role': 108 | model_cfg = controlimcap.models.flatattn.NodeRoleBUTDAttnModelConfig() 109 | elif mtype in ['rgcn', 'rgcn.memory', 'rgcn.flow', 'rgcn.flow.memory']: 110 | model_cfg = controlimcap.models.graphattn.GraphModelConfig() 111 | 112 | model_cfg = gen_common_model_cfg(model_cfg) 113 | 114 | mp_enc_cfg = gen_vanilla_encoder_cfg(model_cfg.subcfgs[MPENCODER], [dim_mp_ft, hidden_size], hidden_size) 115 | 116 | if mtype in ['node', 'node.role']: 117 | attn_enc_cfg = gen_vanilla_encoder_cfg(model_cfg.subcfgs[ATTNENCODER], [dim_attn_ft], hidden_size) 118 | if mtype == 'node.role': 119 | attn_enc_cfg.num_node_types = 3 120 | elif mtype in ['rgcn', 'rgcn.memory', 'rgcn.flow', 'rgcn.flow.memory']: 121 | attn_enc_cfg = gen_gcn_encoder_cfg(model_cfg.subcfgs[ATTNENCODER], dim_attn_ft, hidden_size) 122 | attn_enc_cfg.num_node_types = 3 123 | 124 | dec_cfg = gen_attn_decoder_cfg(model_cfg.subcfgs[DECODER], num_words, hidden_size) 125 | 126 | output_dir = os.path.join(res_dir, mtype, 127 | 'mp.%s.attn.%s.%s%s.layer.%d.hidden.%d%s%s%s'% 128 | (mp_ft_name, attn_ft_name, 129 | 'rgcn.%d.'%attn_enc_cfg.num_hidden_layers if 'rgcn' in opts.mtype else '', 130 | dec_cfg.rnn_type, 131 | dec_cfg.num_layers, dec_cfg.hidden_size, 132 | '.hidden2word' if dec_cfg.hidden2word else '', 133 | '.tie_embed' if dec_cfg.tie_embed else '', 134 | '.schedule_sampling' if dec_cfg.schedule_sampling else '') 135 | ) 136 | 137 | if 'rgcn' in mtype: 138 | output_dir = '%s%s'%(output_dir, '.embed_first' if attn_enc_cfg.embed_first else '') 139 | 140 | print(output_dir) 141 | 142 | if not os.path.exists(output_dir): 143 | os.makedirs(output_dir) 144 | model_cfg.save(os.path.join(output_dir, 'model.json')) 145 | 146 | path_cfg = { 147 | 'output_dir': output_dir, 148 | 'mp_ft_file': {}, 149 | 'name_file': {}, 150 | 'word2int_file': os.path.join(anno_dir, 
'word2int.json'), 151 | 'int2word_file': os.path.join(anno_dir, 'int2word.npy'), 152 | 'region_anno_dir': {}, 153 | 'obj_ft_dir': {}, 154 | } 155 | for setname in ['trn', 'val', 'tst']: 156 | path_cfg['mp_ft_file'][setname] = os.path.join(mp_ft_dir, mp_ft_name, '%s_ft.npy'%setname) 157 | path_cfg['obj_ft_dir'][setname] = os.path.join(attn_ft_dir, attn_ft_name, 'objrels') 158 | path_cfg['name_file'][setname] = os.path.join(split_dir, '%s_names.npy'%setname) 159 | path_cfg['region_anno_dir'][setname] = os.path.join(anno_dir, 'regionfiles') 160 | 161 | with open(os.path.join(output_dir, 'path.json'), 'w') as f: 162 | json.dump(path_cfg, f, indent=2) 163 | 164 | 165 | if __name__ == '__main__': 166 | parser = argparse.ArgumentParser() 167 | parser.add_argument('mtype', 168 | choices=['node', 'node.role', 'rgcn', 'rgcn.flow', 'rgcn.memory', 'rgcn.flow.memory']) 169 | opts = parser.parse_args() 170 | 171 | prepare_attention(opts.mtype) 172 | 173 | -------------------------------------------------------------------------------- /controlimcap/driver/configs/prepare_vg_imgsg_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import argparse 5 | import json 6 | import numpy as np 7 | 8 | import caption.encoders.vanilla 9 | import caption.decoders.vanilla 10 | import caption.models.vanilla 11 | import caption.models.attention 12 | import caption.models.selfcritic 13 | 14 | import controlimcap.encoders.gcn 15 | import controlimcap.models.flatattn 16 | import controlimcap.models.graphattn 17 | import controlimcap.models.graphmemory 18 | 19 | ENCODER = 'encoder' 20 | DECODER = 'decoder' 21 | MPENCODER = 'mp_encoder' 22 | ATTNENCODER = 'attn_encoder' 23 | 24 | ROOT_DIR = '/data4/VisualGenome' 25 | 26 | def gen_vanilla_encoder_cfg(enc_cfg, dim_fts, dim_embed): 27 | enc_cfg.dim_fts = dim_fts 28 | enc_cfg.dim_embed = dim_embed 29 | enc_cfg.is_embed = True 30 | enc_cfg.dropout = 0 31 | enc_cfg.norm = False 32 | enc_cfg.nonlinear = False 33 | return enc_cfg 34 | 35 | def gen_gcn_encoder_cfg(enc_cfg, dim_input, dim_hidden): 36 | enc_cfg.dim_input = dim_input 37 | enc_cfg.dim_hidden = dim_hidden 38 | enc_cfg.num_rels = 6 39 | enc_cfg.num_bases = -1 40 | enc_cfg.num_hidden_layers = 2 41 | enc_cfg.max_attn_len = 10 42 | enc_cfg.self_loop = True 43 | enc_cfg.num_node_types = 3 44 | enc_cfg.embed_first = False #True 45 | return enc_cfg 46 | 47 | def gen_vanilla_decoder_cfg(dec_cfg, num_words, hidden_size): 48 | dec_cfg.rnn_type = 'lstm' 49 | dec_cfg.num_words = num_words 50 | dec_cfg.dim_word = 512 51 | dec_cfg.hidden_size = hidden_size 52 | dec_cfg.num_layers = 1 53 | dec_cfg.hidden2word = False 54 | dec_cfg.tie_embed = True 55 | dec_cfg.fix_word_embed = False 56 | dec_cfg.max_words_in_sent = 15 57 | dec_cfg.dropout = 0.5 58 | dec_cfg.schedule_sampling = False 59 | dec_cfg.ss_rate = 0. 
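# Same scheduled-sampling knobs as in prepare_coco_imgsg_config.py; they stay inactive while
# schedule_sampling is False. The VG-specific differences in this config are shorter captions
# (max_words_in_sent = 15 vs. 25), embed_first = False for the RGCN encoder, and fewer epochs (20 vs. 100).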
60 | dec_cfg.ss_increase_rate = 0.05 61 | dec_cfg.ss_max_rate = 0.25 62 | dec_cfg.ss_increase_epoch = 5 63 | 64 | dec_cfg.greedy_or_beam = False # test method 65 | dec_cfg.beam_width = 1 66 | dec_cfg.sent_pool_size = 1 67 | return dec_cfg 68 | 69 | def gen_attn_decoder_cfg(dec_cfg, num_words, hidden_size): 70 | dec_cfg = gen_vanilla_decoder_cfg(dec_cfg, num_words, hidden_size) 71 | dec_cfg.memory_same_key_value = True 72 | dec_cfg.attn_size = 512 73 | dec_cfg.attn_type = 'mlp' 74 | return dec_cfg 75 | 76 | def gen_common_model_cfg(model_cfg): 77 | model_cfg.trn_batch_size = 128 78 | model_cfg.tst_batch_size = 128 79 | model_cfg.num_epoch = 20 80 | model_cfg.base_lr = 2e-4 81 | model_cfg.monitor_iter = 1000 82 | model_cfg.summary_iter = 1000 83 | model_cfg.save_iter = -1 #5000 84 | model_cfg.val_iter = -1 #5000 85 | model_cfg.val_per_epoch = True 86 | model_cfg.save_per_epoch = True 87 | return model_cfg 88 | 89 | 90 | def prepare_attention(mtype): 91 | anno_dir = os.path.join(ROOT_DIR, 'annotation') 92 | mp_ft_dir = os.path.join(ROOT_DIR, 'ordered_feature', 'MP') 93 | attn_ft_dir = os.path.join(ROOT_DIR, 'ordered_feature', 'SA') 94 | split_dir = os.path.join(ROOT_DIR, 'public_split') 95 | res_dir = os.path.join(ROOT_DIR, 'results', 'CAP.ctrl') 96 | 97 | mp_ft_name = 'resnet101.ctrl' 98 | attn_ft_name = 'X_101_32x8d' 99 | 100 | hidden_size = 512 101 | dim_attn_ft = 2048 102 | dim_mp_ft = np.load(os.path.join(mp_ft_dir, mp_ft_name, 'val_ft.npy')).shape[-1] 103 | num_words = len(np.load(os.path.join(anno_dir, 'int2word.npy'))) 104 | 105 | if mtype == 'node': 106 | model_cfg = caption.models.attention.AttnModelConfig() 107 | elif mtype in ['node.role']: 108 | model_cfg = controlimcap.models.flatattn.NodeRoleBUTDAttnModelConfig() 109 | model_cfg.semantic_loss_weight = 10 110 | elif mtype in ['rgcn', 'rgcn.memory', 'rgcn.flow', 'rgcn.flow.memory']: 111 | model_cfg = controlimcap.models.graphattn.GraphModelConfig() 112 | 113 | model_cfg = gen_common_model_cfg(model_cfg) 114 | 115 | mp_enc_cfg = gen_vanilla_encoder_cfg(model_cfg.subcfgs[MPENCODER], [dim_mp_ft, hidden_size], hidden_size) 116 | if mtype in ['node', 'node.role']: 117 | attn_enc_cfg = gen_vanilla_encoder_cfg(model_cfg.subcfgs[ATTNENCODER], [dim_attn_ft], hidden_size) 118 | if mtype == 'node.role': 119 | attn_enc_cfg.num_node_types = 3 120 | elif mtype in ['rgcn', 'rgcn.memory', 'rgcn.flow', 'rgcn.flow.memory']: 121 | attn_enc_cfg = gen_gcn_encoder_cfg(model_cfg.subcfgs[ATTNENCODER], dim_attn_ft, hidden_size) 122 | attn_enc_cfg.num_node_types = 3 123 | 124 | dec_cfg = gen_attn_decoder_cfg(model_cfg.subcfgs[DECODER], num_words, hidden_size) 125 | 126 | output_dir = os.path.join(res_dir, mtype, 127 | 'mp.%s.attn.%s.%s%s.layer.%d.hidden.%d%s%s%s'% 128 | ( mp_ft_name, attn_ft_name, 129 | 'rgcn.%d.'%attn_enc_cfg.num_hidden_layers if 'rgcn' in opts.mtype else '', 130 | dec_cfg.rnn_type, 131 | dec_cfg.num_layers, dec_cfg.hidden_size, 132 | '.hidden2word' if dec_cfg.hidden2word else '', 133 | '.tie_embed' if dec_cfg.tie_embed else '', 134 | '.schedule_sampling' if dec_cfg.schedule_sampling else '') 135 | ) 136 | 137 | if 'rgcn' in mtype: 138 | output_dir = '%s%s'%(output_dir, 139 | '.embed_first' if attn_enc_cfg.embed_first else '') 140 | 141 | print(output_dir) 142 | 143 | if not os.path.exists(output_dir): 144 | os.makedirs(output_dir) 145 | model_cfg.save(os.path.join(output_dir, 'model.json')) 146 | 147 | path_cfg = { 148 | 'output_dir': output_dir, 149 | 'mp_ft_file': {}, 150 | 'name_file': {}, 151 | 'word2int_file': 
os.path.join(anno_dir, 'word2int.json'), 152 | 'int2word_file': os.path.join(anno_dir, 'int2word.npy'), 153 | 'region_anno_dir': {}, 154 | 'obj_ft_dir': {}, 155 | } 156 | for setname in ['trn', 'val', 'tst']: 157 | path_cfg['mp_ft_file'][setname] = os.path.join(mp_ft_dir, mp_ft_name, '%s_ft.npy'%setname) 158 | path_cfg['obj_ft_dir'][setname] = os.path.join(attn_ft_dir, attn_ft_name, 'objrels') 159 | path_cfg['name_file'][setname] = os.path.join(split_dir, '%s_names.npy'%setname) 160 | path_cfg['region_anno_dir'][setname] = os.path.join(anno_dir, 'regionfiles') 161 | 162 | with open(os.path.join(output_dir, 'path.json'), 'w') as f: 163 | json.dump(path_cfg, f, indent=2) 164 | 165 | 166 | if __name__ == '__main__': 167 | parser = argparse.ArgumentParser() 168 | parser.add_argument('mtype', 169 | choices=['node', 'node.role', 'rgcn', 'rgcn.flow', 'rgcn.memory', 'rgcn.flow.memory']) 170 | opts = parser.parse_args() 171 | 172 | prepare_attention(opts.mtype) 173 | 174 | -------------------------------------------------------------------------------- /controlimcap/encoders/flat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import caption.encoders.vanilla 9 | 10 | ''' 11 | EncoderConfig is the same as encoder.vanilla.EncoderConfig 12 | ''' 13 | 14 | def gen_order_embeds(max_len, dim_ft): 15 | order_embeds = np.zeros((max_len, dim_ft)) 16 | position = np.expand_dims(np.arange(0, max_len - 1).astype(np.float32), 1) 17 | div_term = np.exp(np.arange(0, dim_ft, 2) * -(math.log(10000.0) / dim_ft)) 18 | order_embeds[1:, 0::2] = np.sin(position * div_term) 19 | order_embeds[1:, 1::2] = np.cos(position * div_term) 20 | return order_embeds 21 | 22 | class EncoderConfig(caption.encoders.vanilla.EncoderConfig): 23 | def __init__(self): 24 | super().__init__() 25 | self.dim_fts = [2048] 26 | self.dim_embed = 512 27 | self.is_embed = True 28 | self.dropout = 0 29 | self.norm = False 30 | self.nonlinear = False 31 | self.num_node_types = 3 32 | 33 | class Encoder(caption.encoders.vanilla.Encoder): 34 | def __init__(self, config): 35 | super().__init__(config) 36 | 37 | dim_fts = sum(self.config.dim_fts) 38 | self.node_embedding = nn.Embedding(self.config.num_node_types, dim_fts) 39 | 40 | self.register_buffer('attr_order_embeds', 41 | torch.FloatTensor(gen_order_embeds(20, dim_fts))) 42 | 43 | def forward(self, fts, node_types, attr_order_idxs): 44 | ''' 45 | Args: 46 | fts: size=(batch, seq_len, dim_ft) 47 | node_types: size=(batch, seq_len) 48 | attr_order_idxs: size=(batch, seq_len) 49 | Returns: 50 | embeds: size=(batch, seq_len, dim_embed) 51 | ''' 52 | node_embeds = self.node_embedding(node_types) 53 | node_embeds = node_embeds + self.attr_order_embeds[attr_order_idxs] 54 | 55 | inputs = fts * node_embeds 56 | embeds = super().forward(inputs) 57 | 58 | return embeds 59 | -------------------------------------------------------------------------------- /controlimcap/encoders/gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import framework.configbase 6 | import framework.ops 7 | 8 | from controlimcap.encoders.flat import gen_order_embeds 9 | 10 | 11 | class RGCNLayer(nn.Module): 12 | def __init__(self, in_feat, out_feat, num_rels, 13 | bias=None, activation=None, dropout=0.0): 14 | super().__init__() 15 | self.in_feat 
= in_feat 16 | self.out_feat = out_feat 17 | self.num_rels = num_rels 18 | self.bias = bias 19 | self.activation = activation 20 | 21 | self.loop_weight = nn.Parameter(torch.Tensor(self.in_feat, self.out_feat)) 22 | nn.init.xavier_uniform_(self.loop_weight, gain=nn.init.calculate_gain('relu')) 23 | 24 | self.weight = nn.Parameter(torch.Tensor(self.num_rels, self.in_feat, self.out_feat)) 25 | nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu')) 26 | 27 | if self.bias: 28 | self.bias = nn.Parameter(torch.Tensor(self.out_feat)) 29 | nn.init.xavier_uniform_(self.bias, gain=nn.init.calculate_gain('relu')) 30 | 31 | self.dropout = nn.Dropout(dropout) 32 | 33 | def forward(self, attn_fts, rel_edges): 34 | '''Args: 35 | attn_fts: (batch_size, max_src_nodes, in_feat) 36 | rel_edges: (batch_size, num_rels, max_tgt_nodes, max_srt_nodes) 37 | Retunrs: 38 | node_repr: (batch_size, max_tgt_nodes, out_feat) 39 | ''' 40 | loop_message = torch.einsum('bsi,ij->bsj', attn_fts, self.loop_weight) 41 | loop_message = self.dropout(loop_message) 42 | 43 | neighbor_message = torch.einsum('brts,bsi,rij->btj', rel_edges, attn_fts, self.weight) 44 | 45 | node_repr = loop_message + neighbor_message 46 | if self.bias: 47 | node_repr = node_repr + self.bias 48 | if self.activation: 49 | node_repr = self.activation(node_repr) 50 | 51 | return node_repr 52 | 53 | 54 | class RGCNEncoderConfig(framework.configbase.ModuleConfig): 55 | def __init__(self): 56 | super().__init__() 57 | self.dim_input = 2048 58 | self.dim_hidden = 512 59 | self.num_rels = 6 60 | self.num_hidden_layers = 1 61 | self.max_attn_len = 10 62 | self.self_loop = True 63 | self.dropout = 0. 64 | self.num_node_types = 3 65 | self.embed_first = False 66 | 67 | class RGCNEncoder(nn.Module): 68 | def __init__(self, config): 69 | super().__init__() 70 | self.config = config 71 | 72 | if self.config.embed_first: 73 | self.first_embedding = nn.Sequential( 74 | nn.Linear(self.config.dim_input, self.config.dim_hidden), 75 | nn.ReLU()) 76 | 77 | self.layers = nn.ModuleList() 78 | dim_input = self.config.dim_hidden if self.config.embed_first else self.config.dim_input 79 | for _ in range(self.config.num_hidden_layers): 80 | h2h = RGCNLayer(dim_input, self.config.dim_hidden, self.config.num_rels, 81 | activation=F.relu, dropout=self.config.dropout) 82 | dim_input = self.config.dim_hidden 83 | self.layers.append(h2h) 84 | 85 | def forward(self, attn_fts, rel_edges): 86 | if self.config.embed_first: 87 | attn_fts = self.first_embedding(attn_fts) 88 | 89 | for layer in self.layers: 90 | attn_fts = layer(attn_fts, rel_edges) 91 | 92 | return attn_fts 93 | 94 | 95 | class RoleRGCNEncoder(RGCNEncoder): 96 | def __init__(self, config): 97 | super().__init__(config) 98 | 99 | self.node_embedding = nn.Embedding(self.config.num_node_types, 100 | self.config.dim_input) 101 | 102 | self.register_buffer('attr_order_embeds', 103 | torch.FloatTensor(gen_order_embeds(20, self.config.dim_input))) 104 | 105 | def forward(self, attn_fts, node_types, attr_order_idxs, rel_edges): 106 | '''Args: 107 | (num_src_nodes = num_tgt_nodes) 108 | - attn_fts: (batch_size, num_src_nodes, in_feat) 109 | - rel_edges: (num_rels, num_tgt_nodes, num_src_nodes) 110 | - node_types: (batch_size, num_src_nodes) 111 | - attr_order_idxs: (batch_size, num_src_nodes) 112 | ''' 113 | node_embeds = self.node_embedding(node_types) 114 | node_embeds = node_embeds + self.attr_order_embeds[attr_order_idxs] 115 | 116 | input_fts = attn_fts * node_embeds 117 | 118 | return 
super().forward(input_fts, rel_edges) 119 | 120 | 121 | 122 | -------------------------------------------------------------------------------- /controlimcap/models/flatattn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import framework.configbase 5 | import caption.encoders.vanilla 6 | import caption.decoders.attention 7 | import caption.models.attention 8 | 9 | import controlimcap.encoders.flat 10 | 11 | from caption.models.attention import MPENCODER, ATTNENCODER, DECODER 12 | 13 | 14 | class NodeBUTDAttnModel(caption.models.attention.BUTDAttnModel): 15 | def forward_encoder(self, input_batch): 16 | attn_embeds = self.submods[ATTNENCODER](input_batch['attn_fts']) 17 | graph_embeds = torch.sum(attn_embeds * input_batch['attn_masks'].unsqueeze(2), 1) 18 | graph_embeds = graph_embeds / torch.sum(input_batch['attn_masks'], 1, keepdim=True) 19 | enc_states = self.submods[MPENCODER]( 20 | torch.cat([input_batch['mp_fts'], graph_embeds], 1)) 21 | return {'init_states': enc_states, 'attn_fts': attn_embeds} 22 | 23 | 24 | class NodeRoleBUTDAttnModelConfig(caption.models.attention.AttnModelConfig): 25 | def __init__(self): 26 | super().__init__() 27 | self.subcfgs[ATTNENCODER] = controlimcap.encoders.flat.EncoderConfig() 28 | 29 | class NodeRoleBUTDAttnModel(caption.models.attention.BUTDAttnModel): 30 | def build_submods(self): 31 | submods = {} 32 | submods[MPENCODER] = caption.encoders.vanilla.Encoder(self.config.subcfgs[MPENCODER]) 33 | submods[ATTNENCODER] = controlimcap.encoders.flat.Encoder(self.config.subcfgs[ATTNENCODER]) 34 | submods[DECODER] = caption.decoders.attention.BUTDAttnDecoder(self.config.subcfgs[DECODER]) 35 | return submods 36 | 37 | def prepare_input_batch(self, batch_data, is_train=False): 38 | outs = super().prepare_input_batch(batch_data, is_train=is_train) 39 | outs['node_types'] = torch.LongTensor(batch_data['node_types']).to(self.device) 40 | outs['attr_order_idxs'] = torch.LongTensor(batch_data['attr_order_idxs']).to(self.device) 41 | return outs 42 | 43 | def forward_encoder(self, input_batch): 44 | attn_embeds = self.submods[ATTNENCODER](input_batch['attn_fts'], 45 | input_batch['node_types'], input_batch['attr_order_idxs']) 46 | graph_embeds = torch.sum(attn_embeds * input_batch['attn_masks'].unsqueeze(2), 1) 47 | graph_embeds = graph_embeds / torch.sum(input_batch['attn_masks'], 1, keepdim=True) 48 | enc_states = self.submods[MPENCODER]( 49 | torch.cat([input_batch['mp_fts'], graph_embeds], 1)) 50 | return {'init_states': enc_states, 'attn_fts': attn_embeds} 51 | 52 | -------------------------------------------------------------------------------- /controlimcap/models/graphattn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | import framework.configbase 6 | import caption.encoders.vanilla 7 | import caption.decoders.attention 8 | import caption.models.attention 9 | 10 | import controlimcap.encoders.gcn 11 | import controlimcap.decoders.cfattention 12 | 13 | MPENCODER = 'mp_encoder' 14 | ATTNENCODER = 'attn_encoder' 15 | DECODER = 'decoder' 16 | 17 | class GraphModelConfig(framework.configbase.ModelConfig): 18 | def __init__(self): 19 | super().__init__() 20 | self.subcfgs[MPENCODER] = caption.encoders.vanilla.EncoderConfig() 21 | self.subcfgs[ATTNENCODER] = controlimcap.encoders.gcn.RGCNEncoderConfig() 22 | self.subcfgs[DECODER] = 
caption.decoders.attention.AttnDecoderConfig() 23 | 24 | def _assert(self): 25 | assert self.subcfgs[MPENCODER].dim_embed == self.subcfgs[DECODER].hidden_size 26 | assert self.subcfgs[ATTNENCODER].dim_hidden == self.subcfgs[DECODER].attn_input_size 27 | 28 | 29 | class GraphBUTDAttnModel(caption.models.attention.BUTDAttnModel): 30 | def build_submods(self): 31 | submods = {} 32 | submods[MPENCODER] = caption.encoders.vanilla.Encoder(self.config.subcfgs[MPENCODER]) 33 | submods[ATTNENCODER] = controlimcap.encoders.gcn.RGCNEncoder(self.config.subcfgs[ATTNENCODER]) 34 | submods[DECODER] = caption.decoders.attention.BUTDAttnDecoder(self.config.subcfgs[DECODER]) 35 | return submods 36 | 37 | def prepare_input_batch(self, batch_data, is_train=False): 38 | outs = {} 39 | outs['mp_fts'] = torch.FloatTensor(batch_data['mp_fts']).to(self.device) 40 | outs['attn_fts'] = torch.FloatTensor(batch_data['attn_fts']).to(self.device) 41 | outs['attn_masks'] = torch.FloatTensor(batch_data['attn_masks'].astype(np.float32)).to(self.device) 42 | # build rel_edges tensor 43 | batch_size, max_nodes, _ = outs['attn_fts'].size() 44 | num_rels = len(batch_data['edge_sparse_matrices'][0]) 45 | rel_edges = np.zeros((batch_size, num_rels, max_nodes, max_nodes), dtype=np.float32) 46 | for i, edge_sparse_matrices in enumerate(batch_data['edge_sparse_matrices']): 47 | for j, edge_sparse_matrix in enumerate(edge_sparse_matrices): 48 | rel_edges[i, j] = edge_sparse_matrix.todense() 49 | outs['rel_edges'] = torch.FloatTensor(rel_edges).to(self.device) 50 | if is_train: 51 | outs['caption_ids'] = torch.LongTensor(batch_data['caption_ids']).to(self.device) 52 | outs['caption_masks'] = torch.FloatTensor(batch_data['caption_masks'].astype(np.float32)).to(self.device) 53 | if 'gt_attns' in batch_data: 54 | outs['gt_attns'] = torch.FloatTensor(batch_data['gt_attns'].astype(np.float32)).to(self.device) 55 | return outs 56 | 57 | def forward_encoder(self, input_batch): 58 | attn_embeds = self.submods[ATTNENCODER](input_batch['attn_fts'], input_batch['rel_edges']) 59 | graph_embeds = torch.sum(attn_embeds * input_batch['attn_masks'].unsqueeze(2), 1) 60 | graph_embeds = graph_embeds / torch.sum(input_batch['attn_masks'], 1, keepdim=True) 61 | enc_states = self.submods[MPENCODER]( 62 | torch.cat([input_batch['mp_fts'], graph_embeds], 1)) 63 | return {'init_states': enc_states, 'attn_fts': attn_embeds} 64 | 65 | 66 | 67 | 68 | def forward_loss(self, batch_data, step=None): 69 | input_batch = self.prepare_input_batch(batch_data, is_train=True) 70 | 71 | enc_outs = self.forward_encoder(input_batch) 72 | # logits.shape=(batch*seq_len, num_words) 73 | logits = self.submods[DECODER](input_batch['caption_ids'][:, :-1], 74 | enc_outs['init_states'], enc_outs['attn_fts'], input_batch['attn_masks']) 75 | cap_loss = self.criterion(logits, input_batch['caption_ids'], 76 | input_batch['caption_masks']) 77 | 78 | return cap_loss 79 | 80 | 81 | class RoleGraphBUTDAttnModel(GraphBUTDAttnModel): 82 | def build_submods(self): 83 | submods = {} 84 | submods[MPENCODER] = caption.encoders.vanilla.Encoder(self.config.subcfgs[MPENCODER]) 85 | submods[ATTNENCODER] = controlimcap.encoders.gcn.RoleRGCNEncoder(self.config.subcfgs[ATTNENCODER]) 86 | submods[DECODER] = caption.decoders.attention.BUTDAttnDecoder(self.config.subcfgs[DECODER]) 87 | return submods 88 | 89 | def prepare_input_batch(self, batch_data, is_train=False): 90 | outs = super().prepare_input_batch(batch_data, is_train=is_train) 91
| outs['node_types'] = torch.LongTensor(batch_data['node_types']).to(self.device) 92 | outs['attr_order_idxs'] = torch.LongTensor(batch_data['attr_order_idxs']).to(self.device) 93 | return outs 94 | 95 | def forward_encoder(self, input_batch): 96 | attn_embeds = self.submods[ATTNENCODER](input_batch['attn_fts'], 97 | input_batch['node_types'], input_batch['attr_order_idxs'], input_batch['rel_edges']) 98 | graph_embeds = torch.sum(attn_embeds * input_batch['attn_masks'].unsqueeze(2), 1) 99 | graph_embeds = graph_embeds / torch.sum(input_batch['attn_masks'], 1, keepdim=True) 100 | enc_states = self.submods[MPENCODER]( 101 | torch.cat([input_batch['mp_fts'], graph_embeds], 1)) 102 | return {'init_states': enc_states, 'attn_fts': attn_embeds} 103 | 104 | -------------------------------------------------------------------------------- /controlimcap/models/graphflow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import caption.encoders.vanilla 5 | import controlimcap.encoders.gcn 6 | import controlimcap.decoders.cfattention 7 | import controlimcap.models.graphattn 8 | 9 | MPENCODER = 'mp_encoder' 10 | ATTNENCODER = 'attn_encoder' 11 | DECODER = 'decoder' 12 | 13 | class GraphBUTDCFlowAttnModel(controlimcap.models.graphattn.GraphBUTDAttnModel): 14 | def build_submods(self): 15 | submods = {} 16 | submods[MPENCODER] = caption.encoders.vanilla.Encoder(self.config.subcfgs[MPENCODER]) 17 | submods[ATTNENCODER] = controlimcap.encoders.gcn.RGCNEncoder(self.config.subcfgs[ATTNENCODER]) 18 | submods[DECODER] = controlimcap.decoders.cfattention.ContentFlowAttentionDecoder( 19 | self.config.subcfgs[DECODER]) 20 | return submods 21 | 22 | def prepare_input_batch(self, batch_data, is_train=False): 23 | outs = super().prepare_input_batch(batch_data, is_train=is_train) 24 | flow_edges = [x.toarray() for x in batch_data['flow_sparse_matrix']] 25 | flow_edges = np.stack(flow_edges, 0) 26 | outs['flow_edges'] = torch.FloatTensor(flow_edges).to(self.device) 27 | return outs 28 | 29 | def forward_loss(self, batch_data, step=None): 30 | input_batch = self.prepare_input_batch(batch_data, is_train=True) 31 | 32 | enc_outs = self.forward_encoder(input_batch) 33 | # logits.shape=(batch*seq_len, num_words) 34 | logits = self.submods[DECODER](input_batch['caption_ids'][:, :-1], 35 | enc_outs['init_states'], enc_outs['attn_fts'], input_batch['attn_masks'], 36 | input_batch['flow_edges']) 37 | cap_loss = self.criterion(logits, input_batch['caption_ids'], 38 | input_batch['caption_masks']) 39 | 40 | return cap_loss 41 | 42 | def validate_batch(self, batch_data, addition_outs=None): 43 | input_batch = self.prepare_input_batch(batch_data, is_train=False) 44 | enc_outs = self.forward_encoder(input_batch) 45 | 46 | batch_size = input_batch['attn_masks'].size(0) 47 | init_words = torch.zeros(batch_size, dtype=torch.int64).to(self.device) 48 | 49 | pred_sent, _ = self.submods[DECODER].sample_decode(init_words, 50 | enc_outs['init_states'], enc_outs['attn_fts'], input_batch['attn_masks'], 51 | input_batch['flow_edges'], greedy=True) 52 | 53 | return pred_sent 54 | 55 | def test_batch(self, batch_data, greedy_or_beam): 56 | input_batch = self.prepare_input_batch(batch_data, is_train=False) 57 | enc_outs = self.forward_encoder(input_batch) 58 | 59 | batch_size = input_batch['attn_masks'].size(0) 60 | init_words = torch.zeros(batch_size, dtype=torch.int64).to(self.device) 61 | 62 | if greedy_or_beam: 63 | sent_pool = 
self.submods[DECODER].beam_search_decode( 64 | init_words, enc_outs['init_states'], enc_outs['attn_fts'], 65 | input_batch['attn_masks'], input_batch['flow_edges']) 66 | pred_sent = [pool[0][1] for pool in sent_pool] 67 | else: 68 | pred_sent, word_logprobs = self.submods[DECODER].sample_decode( 69 | init_words, enc_outs['init_states'], enc_outs['attn_fts'], 70 | input_batch['attn_masks'], input_batch['flow_edges'], greedy=True) 71 | sent_pool = [] 72 | for sent, word_logprob in zip(pred_sent, word_logprobs): 73 | sent_pool.append([(word_logprob.sum().item(), sent, word_logprob)]) 74 | 75 | return pred_sent, sent_pool 76 | 77 | 78 | class RoleGraphBUTDCFlowAttnModel(GraphBUTDCFlowAttnModel): 79 | def build_submods(self): 80 | submods = {} 81 | submods[MPENCODER] = caption.encoders.vanilla.Encoder(self.config.subcfgs[MPENCODER]) 82 | submods[ATTNENCODER] = controlimcap.encoders.gcn.RoleRGCNEncoder(self.config.subcfgs[ATTNENCODER]) 83 | submods[DECODER] = controlimcap.decoders.cfattention.ContentFlowAttentionDecoder( 84 | self.config.subcfgs[DECODER]) 85 | return submods 86 | 87 | def prepare_input_batch(self, batch_data, is_train=False): 88 | outs = super().prepare_input_batch(batch_data, is_train=is_train) 89 | outs['node_types'] = torch.LongTensor(batch_data['node_types']).to(self.device) 90 | outs['attr_order_idxs'] = torch.LongTensor(batch_data['attr_order_idxs']).to(self.device) 91 | return outs 92 | 93 | def forward_encoder(self, input_batch): 94 | attn_embeds = self.submods[ATTNENCODER](input_batch['attn_fts'], 95 | input_batch['node_types'], input_batch['attr_order_idxs'], 96 | input_batch['rel_edges']) 97 | graph_embeds = torch.sum(attn_embeds * input_batch['attn_masks'].unsqueeze(2), 1) 98 | graph_embeds = graph_embeds / torch.sum(input_batch['attn_masks'], 1, keepdim=True) 99 | enc_states = self.submods[MPENCODER]( 100 | torch.cat([input_batch['mp_fts'], graph_embeds], 1)) 101 | return {'init_states': enc_states, 'attn_fts': attn_embeds} 102 | -------------------------------------------------------------------------------- /controlimcap/models/graphmemory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import caption.encoders.vanilla 5 | import caption.models.captionbase 6 | 7 | import controlimcap.encoders.gcn 8 | import controlimcap.decoders.memory 9 | import controlimcap.models.graphattn 10 | import controlimcap.models.graphflow 11 | 12 | MPENCODER = 'mp_encoder' 13 | ATTNENCODER = 'attn_encoder' 14 | DECODER = 'decoder' 15 | 16 | 17 | class GraphBUTDMemoryModel(controlimcap.models.graphattn.GraphBUTDAttnModel): 18 | def build_submods(self): 19 | submods = {} 20 | submods[MPENCODER] = caption.encoders.vanilla.Encoder(self.config.subcfgs[MPENCODER]) 21 | submods[ATTNENCODER] =controlimcap.encoders.gcn.RGCNEncoder(self.config.subcfgs[ATTNENCODER]) 22 | submods[DECODER] = controlimcap.decoders.memory.MemoryDecoder(self.config.subcfgs[DECODER]) 23 | return submods 24 | 25 | class RoleGraphBUTDMemoryModel(controlimcap.models.graphattn.RoleGraphBUTDAttnModel): 26 | def build_submods(self): 27 | submods = {} 28 | submods[MPENCODER] = caption.encoders.vanilla.Encoder(self.config.subcfgs[MPENCODER]) 29 | submods[ATTNENCODER] =controlimcap.encoders.gcn.RoleRGCNEncoder(self.config.subcfgs[ATTNENCODER]) 30 | submods[DECODER] = controlimcap.decoders.memory.MemoryDecoder(self.config.subcfgs[DECODER]) 31 | return submods 32 | 33 | 34 | class 
GraphBUTDMemoryFlowModel(controlimcap.models.graphflow.GraphBUTDCFlowAttnModel): 35 | def build_submods(self): 36 | submods = {} 37 | submods[MPENCODER] = caption.encoders.vanilla.Encoder(self.config.subcfgs[MPENCODER]) 38 | submods[ATTNENCODER] = controlimcap.encoders.gcn.RGCNEncoder(self.config.subcfgs[ATTNENCODER]) 39 | submods[DECODER] = controlimcap.decoders.memory.MemoryFlowDecoder(self.config.subcfgs[DECODER]) 40 | return submods 41 | 42 | class RoleGraphBUTDMemoryFlowModel(controlimcap.models.graphflow.RoleGraphBUTDCFlowAttnModel): 43 | def build_submods(self): 44 | submods = {} 45 | submods[MPENCODER] = caption.encoders.vanilla.Encoder(self.config.subcfgs[MPENCODER]) 46 | submods[ATTNENCODER] = controlimcap.encoders.gcn.RoleRGCNEncoder(self.config.subcfgs[ATTNENCODER]) 47 | submods[DECODER] = controlimcap.decoders.memory.MemoryFlowDecoder(self.config.subcfgs[DECODER]) 48 | return submods 49 | 50 | 51 | -------------------------------------------------------------------------------- /controlimcap/readers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshizhe/asg2cap/9d3a8a6312935cb1b033835c1edc522dcf9f6061/controlimcap/readers/__init__.py -------------------------------------------------------------------------------- /controlimcap/readers/imgsgreader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import h5py 4 | import numpy as np 5 | from scipy import sparse 6 | import collections 7 | import math 8 | import torch 9 | 10 | import caption.readers.base 11 | 12 | NUM_RELS = 6 13 | UNK_WORDEMBED = np.zeros((300, ), dtype=np.float32) 14 | PIXEL_REDUCE = 1 15 | 16 | class ImageSceneGraphFlatReader(caption.readers.base.CaptionDatasetBase): 17 | def __init__(self, name_file, mp_ft_file, obj_ft_dir, region_anno_dir, 18 | word2int_file, max_attn_len=10, max_words_in_sent=15, 19 | is_train=False, return_label=False, _logger=None, 20 | pred_caption_file=None): 21 | 22 | super().__init__(word2int_file, max_words_in_sent=max_words_in_sent, 23 | is_train=is_train, return_label=return_label, _logger=_logger) 24 | 25 | if 'VisualGenome' in name_file: 26 | global PIXEL_REDUCE 27 | PIXEL_REDUCE = 0 28 | 29 | self.obj_ft_dir = obj_ft_dir 30 | self.max_attn_len = max_attn_len 31 | self.region_anno_dir = region_anno_dir 32 | 33 | img_names = np.load(name_file) 34 | self.img_id_to_ftidx_name = {x.split('.')[0]: (i, x) \ 35 | for i, x in enumerate(img_names)} 36 | 37 | self.mp_fts = np.load(mp_ft_file) 38 | self.print_fn('mp_fts %s'%(str(self.mp_fts.shape))) 39 | 40 | self.names = np.load(os.path.join(region_anno_dir, os.path.basename(name_file))) 41 | self.num_data = len(self.names) 42 | self.print_fn('num_data %d' % (self.num_data)) 43 | 44 | if pred_caption_file is None: 45 | self.pred_captions = None 46 | else: 47 | self.pred_captions = json.load(open(pred_caption_file)) 48 | 49 | def __getitem__(self, idx): 50 | image_id, region_id = self.names[idx] 51 | name = '%s_%s'%(image_id, region_id) 52 | 53 | anno = json.load(open(os.path.join(self.region_anno_dir, '%s.json'%image_id))) 54 | region_graph = anno[region_id] 55 | region_caption = anno[region_id]['phrase'] 56 | 57 | with h5py.File(os.path.join(self.obj_ft_dir, '%s.jpg.hdf5'%image_id.replace('/', '_')), 'r') as f: 58 | key = '%s.jpg'%image_id.replace('/', '_') 59 | obj_fts = f[key][...] 
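# ---------------------------------------------------------------------------
# Illustrative sketch (not part of imgsgreader.py): the readers below expect one
# HDF5 file per image, containing a dataset named '<image_id>.jpg' whose rows are
# region features and whose 'boxes' attribute stores the matching bounding boxes.
# The file name, feature dimension and box values here are made-up assumptions.
import h5py
import numpy as np

key = '123456.jpg'
with h5py.File('/tmp/123456.jpg.hdf5', 'w') as f:
    d = f.create_dataset(key, data=np.random.rand(3, 2048).astype(np.float32))
    d.attrs['boxes'] = np.array([[0, 0, 99, 99], [10, 10, 49, 59], [5, 20, 89, 79]], np.int64)

with h5py.File('/tmp/123456.jpg.hdf5', 'r') as f:
    obj_fts = f[key][...]                    # (num_regions, dim_ft)
    obj_bboxes = f[key].attrs['boxes']       # (num_regions, 4) as (x1, y1, x2, y2)
    obj_box_to_ft = {tuple(box): ft for box, ft in zip(obj_bboxes, obj_fts)}
    # an ASG object node with box (x, y, x+w-PIXEL_REDUCE, y+h-PIXEL_REDUCE)
    # is then mapped back to its region feature
    ft = obj_box_to_ft[(0, 0, 99, 99)]
# ---------------------------------------------------------------------------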
60 | obj_bboxes = f[key].attrs['boxes'] 61 | obj_box_to_ft = {tuple(box): ft for box, ft in zip(obj_bboxes, obj_fts)} 62 | 63 | attn_ft, node_types, attr_order_idxs = [], [], [] 64 | obj_id_to_box = {} 65 | for x in region_graph['objects']: 66 | box = (x['x'], x['y'], x['x']+x['w']-PIXEL_REDUCE, x['y']+x['h']-PIXEL_REDUCE) 67 | obj_id_to_box[x['object_id']] = box 68 | attn_ft.append(obj_box_to_ft[box]) 69 | attr_order_idxs.append(0) 70 | node_types.append(0) 71 | for ia, attr in enumerate(x['attributes']): 72 | attn_ft.append(obj_box_to_ft[box]) 73 | attr_order_idxs.append(ia + 1) 74 | node_types.append(1) 75 | 76 | for x in region_graph['relationships']: 77 | obj_box = obj_id_to_box[x['object_id']] 78 | subj_box = obj_id_to_box[x['subject_id']] 79 | box = (min(obj_box[0], subj_box[0]), min(obj_box[1], subj_box[1]), 80 | max(obj_box[2], subj_box[2]), max(obj_box[3], subj_box[3])) 81 | attn_ft.append(obj_box_to_ft[box]) 82 | node_types.append(2) 83 | attr_order_idxs.append(0) 84 | 85 | num_nodes = len(node_types) 86 | attn_ft, attn_mask = self.pad_or_trim_feature( 87 | np.array(attn_ft[:self.max_attn_len], np.float32), 88 | self.max_attn_len) 89 | node_types = node_types[:self.max_attn_len] + [0] * max(0, self.max_attn_len - num_nodes) 90 | node_types = np.array(node_types, np.int32) 91 | attr_order_idxs = attr_order_idxs[:self.max_attn_len] + [0] * max(0, self.max_attn_len - num_nodes) 92 | attr_order_idxs = np.array(attr_order_idxs, np.int32) 93 | 94 | out = { 95 | 'names': name, 96 | 'mp_fts': self.mp_fts[self.img_id_to_ftidx_name[image_id][0]], 97 | 'attn_fts': attn_ft, 98 | 'attn_masks': attn_mask, 99 | 'node_types': node_types, 100 | 'attr_order_idxs': attr_order_idxs, 101 | } 102 | if self.is_train or self.return_label: 103 | sent = region_caption 104 | caption_ids, caption_masks = self.pad_sents(self.sent2int(sent)) 105 | out.update({ 106 | 'caption_ids': caption_ids, 107 | 'caption_masks': caption_masks, 108 | 'ref_sents': [sent], 109 | }) 110 | return out 111 | 112 | def __len__(self): 113 | return self.num_data 114 | 115 | 116 | def flat_collate_fn(data): 117 | outs = {} 118 | for key in ['names', 'mp_fts', 'attn_fts', 'attn_masks', 119 | 'caption_ids', 'caption_masks', 'node_types', 'attr_order_idxs']: 120 | if key in data[0]: 121 | outs[key] = [x[key] for x in data] 122 | 123 | outs['mp_fts'] = np.array(outs['mp_fts']) 124 | max_attn_len = np.max(np.sum(outs['attn_masks'], 1)) 125 | outs['attn_fts'] = np.array(outs['attn_fts'])[:, :max_attn_len] 126 | outs['attn_masks'] = np.array(outs['attn_masks'])[:, :max_attn_len] 127 | 128 | for key in ['node_types', 'attr_order_idxs']: 129 | if key in data[0]: 130 | outs[key] = np.array(outs[key])[:, :max_attn_len] 131 | 132 | if 'caption_ids' in data[0]: 133 | outs['caption_ids'] = np.array(outs['caption_ids'], np.int32) 134 | outs['caption_masks'] = np.array(outs['caption_masks'], np.bool) 135 | max_sent_len = np.max(np.sum(outs['caption_masks'])) 136 | outs['caption_ids'] = outs['caption_ids'][:, :max_sent_len] 137 | outs['caption_masks'] = outs['caption_masks'][:, :max_sent_len] 138 | outs['ref_sents'] = {} 139 | for x in data: 140 | outs['ref_sents'][x['names']] = x['ref_sents'] 141 | 142 | return outs 143 | 144 | class ImageSceneGraphReader(ImageSceneGraphFlatReader): 145 | def add_obj_attr_edge(self, edges, obj_node_id, attr_node_id): 146 | edges.append([obj_node_id, attr_node_id, 0]) 147 | edges.append([attr_node_id, obj_node_id, 1]) 148 | 149 | def add_rel_subj_edge(self, edges, rel_node_id, subj_node_id): 150 | 
edges.append([subj_node_id, rel_node_id, 2]) 151 | edges.append([rel_node_id, subj_node_id, 3]) 152 | 153 | def add_rel_obj_edge(self, edges, rel_node_id, obj_node_i): 154 | edges.append([rel_node_id, obj_node_i, 4]) 155 | edges.append([obj_node_i, rel_node_id, 5]) 156 | 157 | def __getitem__(self, idx): 158 | image_id, region_id = self.names[idx] 159 | name = '%s_%s'%(image_id, region_id) 160 | anno = json.load(open(os.path.join(self.region_anno_dir, '%s.json'%image_id))) 161 | region_graph = anno[region_id] 162 | region_caption = anno[region_id]['phrase'] 163 | 164 | with h5py.File(os.path.join(self.obj_ft_dir, '%s.jpg.hdf5'%image_id.replace('/', '_')), 'r') as f: 165 | key = '%s.jpg'%image_id.replace('/', '_') 166 | obj_fts = f[key][...] 167 | obj_bboxes = f[key].attrs['boxes'] 168 | obj_box_to_ft = {tuple(box): ft for box, ft in zip(obj_bboxes, obj_fts)} 169 | 170 | attn_fts, node_types, attr_order_idxs, edges = [], [], [], [] 171 | obj_id_to_box = {} 172 | obj_id_to_graph_id = {} 173 | n = 0 174 | for x in region_graph['objects']: 175 | box = (x['x'], x['y'], x['x']+x['w']-PIXEL_REDUCE, x['y']+x['h']-PIXEL_REDUCE) 176 | obj_id_to_box[x['object_id']] = box 177 | attn_fts.append(obj_box_to_ft[box]) 178 | attr_order_idxs.append(0) 179 | node_types.append(0) 180 | obj_id_to_graph_id[x['object_id']] = n 181 | n += 1 182 | if n >= self.max_attn_len: 183 | break 184 | for ia, attr in enumerate(x['attributes']): 185 | attn_fts.append(obj_box_to_ft[box]) 186 | attr_order_idxs.append(ia + 1) 187 | node_types.append(1) 188 | self.add_obj_attr_edge(edges, obj_id_to_graph_id[x['object_id']], n) 189 | n += 1 190 | if n >= self.max_attn_len: 191 | break 192 | if n >= self.max_attn_len: 193 | break 194 | 195 | if n < self.max_attn_len: 196 | for x in region_graph['relationships']: 197 | obj_box = obj_id_to_box[x['object_id']] 198 | subj_box = obj_id_to_box[x['subject_id']] 199 | box = (min(obj_box[0], subj_box[0]), min(obj_box[1], subj_box[1]), 200 | max(obj_box[2], subj_box[2]), max(obj_box[3], subj_box[3])) 201 | attn_fts.append(obj_box_to_ft[box]) 202 | attr_order_idxs.append(0) 203 | node_types.append(2) 204 | self.add_rel_subj_edge(edges, n, obj_id_to_graph_id[x['subject_id']]) 205 | self.add_rel_obj_edge(edges, n, obj_id_to_graph_id[x['object_id']]) 206 | n += 1 207 | if n >= self.max_attn_len: 208 | break 209 | 210 | num_nodes = len(node_types) 211 | attn_fts = np.array(attn_fts, np.float32) 212 | attn_fts, attn_masks = self.pad_or_trim_feature(attn_fts, self.max_attn_len) 213 | node_types = node_types[:self.max_attn_len] + [0] * max(0, self.max_attn_len - num_nodes) 214 | node_types = np.array(node_types, np.int32) 215 | attr_order_idxs = attr_order_idxs[:self.max_attn_len] + [0] * max(0, self.max_attn_len - num_nodes) 216 | attr_order_idxs = np.array(attr_order_idxs, np.int32) 217 | 218 | if len(edges) > 0: 219 | src_nodes, tgt_nodes, edge_types = tuple(zip(*edges)) 220 | src_nodes = np.array(src_nodes, np.int32) 221 | tgt_nodes = np.array(tgt_nodes, np.int32) 222 | edge_types = np.array(edge_types, np.int32) 223 | edge_counter = collections.Counter([(tgt_node, edge_type) for tgt_node, edge_type in zip(tgt_nodes, edge_types)]) 224 | edge_norms = np.array( 225 | [1 / edge_counter[(tgt_node, edge_type)] for tgt_node, edge_type in zip(tgt_nodes, edge_types)], 226 | np.float32) 227 | else: 228 | tgt_nodes = src_nodes = edge_types = edge_norms = np.array([]) 229 | 230 | edge_sparse_matrices = [] 231 | for i in range(NUM_RELS): 232 | idxs = (edge_types == i) 233 | edge_sparse_matrices.append( 234 
| sparse.coo_matrix((edge_norms[idxs], (tgt_nodes[idxs], src_nodes[idxs])), 235 | shape=(self.max_attn_len, self.max_attn_len))) 236 | 237 | out = { 238 | 'names': name, 239 | 'mp_fts': self.mp_fts[self.img_id_to_ftidx_name[image_id][0]], 240 | 'attn_fts': attn_fts, 241 | 'attn_masks': attn_masks, 242 | 'node_types': node_types, 243 | 'attr_order_idxs': attr_order_idxs, 244 | 'edge_sparse_matrices': edge_sparse_matrices, 245 | } 246 | if self.is_train or self.return_label: 247 | sent = region_caption 248 | caption_ids, caption_masks = self.pad_sents(self.sent2int(sent)) 249 | out.update({ 250 | 'caption_ids': caption_ids, 251 | 'caption_masks': caption_masks, 252 | 'ref_sents': [sent], 253 | }) 254 | return out 255 | 256 | 257 | def sg_sparse_collate_fn(data): 258 | outs = {} 259 | for key in ['names', 'mp_fts', 'attn_fts', 'attn_masks', 'node_types', 'attr_order_idxs', \ 260 | 'edge_sparse_matrices', 'caption_ids', 'caption_masks']: 261 | if key in data[0]: 262 | outs[key] = [x[key] for x in data] 263 | 264 | outs['mp_fts'] = np.array(outs['mp_fts']) 265 | max_attn_len, dim_attn_ft = data[0]['attn_fts'].shape 266 | # (batch, max_attn_len, dim_attn_ft) 267 | outs['attn_fts'] = np.array(outs['attn_fts']) 268 | outs['attn_masks'] = np.array(outs['attn_masks']) 269 | 270 | if 'caption_ids' in data[0]: 271 | outs['caption_ids'] = np.array(outs['caption_ids'], np.int32) 272 | outs['caption_masks'] = np.array(outs['caption_masks'], np.bool) 273 | max_sent_len = np.max(np.sum(outs['caption_masks'])) 274 | outs['caption_ids'] = outs['caption_ids'][:, :max_sent_len] 275 | outs['caption_masks'] = outs['caption_masks'][:, :max_sent_len] 276 | outs['ref_sents'] = {} 277 | for x in data: 278 | outs['ref_sents'][x['names']] = x['ref_sents'] 279 | 280 | return outs 281 | 282 | 283 | class ImageSceneGraphFlowReader(ImageSceneGraphReader): 284 | def __getitem__(self, idx): 285 | image_id, region_id = self.names[idx] 286 | name = '%s_%s'%(image_id, region_id) 287 | 288 | anno = json.load(open(os.path.join(self.region_anno_dir, '%s.json'%image_id))) 289 | region_graph = anno[region_id] 290 | if self.pred_captions is not None: 291 | region_caption = self.pred_captions[name][0] 292 | else: 293 | region_caption = anno[region_id]['phrase'] 294 | 295 | with h5py.File(os.path.join(self.obj_ft_dir, '%s.jpg.hdf5'%image_id.replace('/', '_')), 'r') as f: 296 | key = '%s.jpg'%image_id.replace('/', '_') 297 | obj_fts = f[key][...] 
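# ---------------------------------------------------------------------------
# Illustrative sketch (not part of imgsgreader.py): how the typed edge list built
# by these readers is packed into NUM_RELS normalized sparse adjacency matrices,
# as in the code above and again below for the flow reader. Each entry is
# 1 / (number of incoming edges of that type at the target node), i.e. the R-GCN
# normalization. The toy edges below are made up for illustration.
import collections
import numpy as np
from scipy import sparse

NUM_RELS, max_attn_len = 6, 4
edges = [[0, 1, 0], [1, 0, 1], [2, 0, 2], [0, 2, 3]]   # (src_node, tgt_node, edge_type)

src_nodes, tgt_nodes, edge_types = map(np.array, zip(*edges))
edge_counter = collections.Counter(zip(tgt_nodes.tolist(), edge_types.tolist()))
edge_norms = np.array([1.0 / edge_counter[(t, e)]
                       for t, e in zip(tgt_nodes.tolist(), edge_types.tolist())], np.float32)

edge_sparse_matrices = []
for i in range(NUM_RELS):
    idxs = (edge_types == i)
    edge_sparse_matrices.append(sparse.coo_matrix(
        (edge_norms[idxs], (tgt_nodes[idxs], src_nodes[idxs])),
        shape=(max_attn_len, max_attn_len)))
# edge_sparse_matrices[0].toarray()[1, 0] == 1.0: node 1 receives one type-0 edge from node 0
# ---------------------------------------------------------------------------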
298 | obj_bboxes = f[key].attrs['boxes'] 299 | obj_box_to_ft = {tuple(box): ft for box, ft in zip(obj_bboxes, obj_fts)} 300 | 301 | attn_fts, node_types, attr_order_idxs = [], [], [] 302 | attn_node_names = [] 303 | edges, flow_edges = [], [] 304 | obj_id_to_box = {} 305 | obj_id_to_graph_id = {} 306 | n = 0 307 | for x in region_graph['objects']: 308 | box = (x['x'], x['y'], x['x']+x['w']-PIXEL_REDUCE, x['y']+x['h']-PIXEL_REDUCE) 309 | obj_id_to_box[x['object_id']] = box 310 | attn_fts.append(obj_box_to_ft[box]) 311 | attn_node_names.append(x['name']) 312 | attr_order_idxs.append(0) 313 | node_types.append(0) 314 | obj_id_to_graph_id[x['object_id']] = n 315 | n += 1 316 | if n >= self.max_attn_len: 317 | break 318 | for ia, attr in enumerate(x['attributes']): 319 | attn_fts.append(obj_box_to_ft[box]) 320 | attn_node_names.append(attr) 321 | attr_order_idxs.append(ia + 1) 322 | node_types.append(1) 323 | self.add_obj_attr_edge(edges, obj_id_to_graph_id[x['object_id']], n) 324 | # bi-directional for obj-attr 325 | flow_edges.append((obj_id_to_graph_id[x['object_id']], n)) 326 | flow_edges.append((n, obj_id_to_graph_id[x['object_id']])) 327 | n += 1 328 | if n >= self.max_attn_len: 329 | break 330 | if n >= self.max_attn_len: 331 | break 332 | 333 | if n < self.max_attn_len: 334 | for x in region_graph['relationships']: 335 | obj_box = obj_id_to_box[x['object_id']] 336 | subj_box = obj_id_to_box[x['subject_id']] 337 | box = (min(obj_box[0], subj_box[0]), min(obj_box[1], subj_box[1]), 338 | max(obj_box[2], subj_box[2]), max(obj_box[3], subj_box[3])) 339 | attn_fts.append(obj_box_to_ft[box]) 340 | attn_node_names.append(x['name']) 341 | attr_order_idxs.append(0) 342 | node_types.append(2) 343 | self.add_rel_subj_edge(edges, n, obj_id_to_graph_id[x['subject_id']]) 344 | self.add_rel_obj_edge(edges, n, obj_id_to_graph_id[x['object_id']]) 345 | flow_edges.append((obj_id_to_graph_id[x['subject_id']], n)) 346 | flow_edges.append((n, obj_id_to_graph_id[x['object_id']])) 347 | n += 1 348 | if n >= self.max_attn_len: 349 | break 350 | 351 | num_nodes = len(node_types) 352 | attn_fts = np.array(attn_fts, np.float32) 353 | attn_fts, attn_masks = self.pad_or_trim_feature(attn_fts, self.max_attn_len) 354 | node_types = node_types[:self.max_attn_len] + [0] * max(0, self.max_attn_len - num_nodes) 355 | node_types = np.array(node_types, np.int32) 356 | attr_order_idxs = attr_order_idxs[:self.max_attn_len] + [0] * max(0, self.max_attn_len - num_nodes) 357 | attr_order_idxs = np.array(attr_order_idxs, np.int32) 358 | 359 | if len(edges) > 0: 360 | src_nodes, tgt_nodes, edge_types = tuple(zip(*edges)) 361 | src_nodes = np.array(src_nodes, np.int32) 362 | tgt_nodes = np.array(tgt_nodes, np.int32) 363 | edge_types = np.array(edge_types, np.int32) 364 | edge_counter = collections.Counter([(tgt_node, edge_type) for tgt_node, edge_type in zip(tgt_nodes, edge_types)]) 365 | edge_norms = np.array( 366 | [1 / edge_counter[(tgt_node, edge_type)] for tgt_node, edge_type in zip(tgt_nodes, edge_types)], 367 | np.float32) 368 | else: 369 | tgt_nodes = src_nodes = edge_types = edge_norms = np.array([]) 370 | 371 | # build python sparse matrix 372 | edge_sparse_matrices = [] 373 | for i in range(NUM_RELS): 374 | idxs = (edge_types == i) 375 | edge_sparse_matrices.append( 376 | sparse.coo_matrix((edge_norms[idxs], (tgt_nodes[idxs], src_nodes[idxs])), 377 | shape=(self.max_attn_len, self.max_attn_len))) 378 | 379 | # add end flow loop 380 | flow_src_nodes = set([x[0] for x in flow_edges]) 381 | for k in range(n): 382 | if k not 
in flow_src_nodes: 383 | flow_edges.append((k, k)) # end loop 384 | # flow order graph 385 | flow_src_nodes, flow_tgt_nodes = tuple(zip(*flow_edges)) 386 | flow_src_nodes = np.array(flow_src_nodes, np.int32) 387 | flow_tgt_nodes = np.array(flow_tgt_nodes, np.int32) 388 | # normalize by src (collumn) 389 | flow_counter = collections.Counter(flow_src_nodes) 390 | flow_edge_norms = np.array( 391 | [1 / flow_counter[src_node] for src_node in flow_src_nodes]) 392 | 393 | flow_sparse_matrix = sparse.coo_matrix((flow_edge_norms, (flow_tgt_nodes, flow_src_nodes)), 394 | shape=(self.max_attn_len, self.max_attn_len)) 395 | 396 | out = { 397 | 'names': name, 398 | 'mp_fts': self.mp_fts[self.img_id_to_ftidx_name[image_id][0]], 399 | 'attn_fts': attn_fts, 400 | 'attn_masks': attn_masks, 401 | 'node_types': node_types, 402 | 'attr_order_idxs': attr_order_idxs, 403 | 'edge_sparse_matrices': edge_sparse_matrices, 404 | 'flow_sparse_matrix': flow_sparse_matrix, 405 | } 406 | if self.is_train or self.return_label: 407 | sent = region_caption 408 | caption_ids, caption_masks = self.pad_sents(self.sent2int(sent)) 409 | out.update({ 410 | 'caption_ids': caption_ids, 411 | 'caption_masks': caption_masks, 412 | 'ref_sents': [sent], 413 | 'attn_node_names': attn_node_names, 414 | }) 415 | return out 416 | 417 | 418 | def sg_sparse_flow_collate_fn(data): 419 | outs = sg_sparse_collate_fn(data) 420 | outs['flow_sparse_matrix'] = [x['flow_sparse_matrix'] for x in data] 421 | return outs 422 | -------------------------------------------------------------------------------- /figures/method_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshizhe/asg2cap/9d3a8a6312935cb1b033835c1edc522dcf9f6061/figures/method_framework.png -------------------------------------------------------------------------------- /figures/user_intention_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshizhe/asg2cap/9d3a8a6312935cb1b033835c1edc522dcf9f6061/figures/user_intention_examples.png -------------------------------------------------------------------------------- /framework/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshizhe/asg2cap/9d3a8a6312935cb1b033835c1edc522dcf9f6061/framework/__init__.py -------------------------------------------------------------------------------- /framework/configbase.py: -------------------------------------------------------------------------------- 1 | import json 2 | import enum 3 | 4 | import numpy as np 5 | 6 | class ModuleConfig(object): 7 | """config of a module 8 | basic attributes: 9 | [freeze] boolean, whether to freeze the weights in this module in training. 10 | [lr_mult] float, the multiplier to the base learning rate for weights in this modules. 
11 | [opt_alg] string, 'Adam|SGD|RMSProp', optimizer 12 | """ 13 | def __init__(self): 14 | self.freeze = False 15 | self.lr_mult = 1.0 16 | self.opt_alg = 'Adam' 17 | self.weight_decay = 0 18 | 19 | def load_from_dict(self, cfg_dict): 20 | for key, value in cfg_dict.items(): 21 | if key in self.__dict__: 22 | setattr(self, key, value) 23 | self._assert() 24 | 25 | def save_to_dict(self): 26 | out = {} 27 | for attr in self.__dict__: 28 | val = self.__dict__[attr] 29 | out[attr] = val 30 | return out 31 | 32 | def _assert(self): 33 | """check compatibility between configs 34 | """ 35 | # raise NotImplementedError("""please customize %s._assert"""%(self.__class__.__name__)) 36 | pass 37 | 38 | 39 | class ModelConfig(object): 40 | def __init__(self): 41 | self.subcfgs = {} # save configure of submodules 42 | 43 | self.trn_batch_size = 128 44 | self.tst_batch_size = 128 45 | self.num_epoch = 100 46 | self.val_per_epoch = True 47 | self.save_per_epoch = True 48 | self.val_iter = -1 49 | self.save_iter = -1 50 | self.monitor_iter = -1 51 | self.summary_iter = -1 # tensorboard summary 52 | 53 | self.base_lr = 1e-4 54 | self.decay_schema = None #'MultiStepLR' 55 | self.decay_boundarys = [] 56 | self.decay_rate = 1 57 | 58 | def load(self, cfg_file): 59 | with open(cfg_file) as f: 60 | data = json.load(f) 61 | for key, value in data.items(): 62 | if key == 'subcfgs': 63 | for subname, subcfg in data[key].items(): 64 | self.subcfgs[subname].load_from_dict(subcfg) 65 | else: 66 | setattr(self, key, value) 67 | 68 | def save(self, out_file): 69 | out = {} 70 | for key in self.__dict__: 71 | if key == 'subcfgs': 72 | out['subcfgs'] = {} 73 | for subname, subcfg in self.__dict__['subcfgs'].items(): 74 | out['subcfgs'][subname] = subcfg.save_to_dict() 75 | else: 76 | out[key] = self.__dict__[key] 77 | with open(out_file, 'w') as f: 78 | json.dump(out, f, indent=2) 79 | 80 | 81 | class PathCfg(object): 82 | def __init__(self): 83 | self.log_dir = '' 84 | self.model_dir = '' 85 | self.pred_dir = '' 86 | 87 | self.log_file = '' 88 | self.val_metric_file = '' 89 | self.model_file = '' 90 | self.predict_file = '' 91 | 92 | def load(self, config_dict): 93 | for key, value in config_dict.items(): 94 | setattr(self, key, value) 95 | 96 | def save(self, output_path): 97 | data = {} 98 | for key in self.__dict__: 99 | data[key] = self.__getattribute__(key) 100 | with open(output_path, 'w') as f: 101 | json.dump(data, f, indent=2) 102 | -------------------------------------------------------------------------------- /framework/logbase.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | import numpy as np 5 | 6 | def set_logger(log_path, log_name='training'): 7 | if log_path is None: 8 | print('log_path is empty') 9 | return None 10 | 11 | if os.path.exists(log_path): 12 | print('%s already exists'%log_path) 13 | return None 14 | 15 | logger = logging.getLogger(log_name) 16 | logger.setLevel(logging.DEBUG) 17 | 18 | logfile = logging.FileHandler(log_path) 19 | console = logging.StreamHandler() 20 | logfile.setLevel(logging.INFO) 21 | logfile.setFormatter(logging.Formatter('%(asctime)s %(message)s')) 22 | console.setLevel(logging.DEBUG) 23 | console.setFormatter(logging.Formatter('%(asctime)s %(message)s')) 24 | logger.addHandler(logfile) 25 | logger.addHandler(console) 26 | 27 | return logger 28 | -------------------------------------------------------------------------------- /framework/modelbase.py: 
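# ---------------------------------------------------------------------------
# Usage sketch for the configbase classes listed above (not part of the repository
# sources): a ModelConfig owns named ModuleConfig sub-configs and round-trips them
# through the model.json file consumed by asg2caption.py. The names below
# (ToyEncoderConfig, dim_embed, 'encoder') are made-up assumptions; the real fields
# live in the caption/controlimcap config classes.
import framework.configbase

class ToyEncoderConfig(framework.configbase.ModuleConfig):
    def __init__(self):
        super().__init__()
        self.dim_embed = 512        # example module-specific field

class ToyModelConfig(framework.configbase.ModelConfig):
    def __init__(self):
        super().__init__()
        self.subcfgs['encoder'] = ToyEncoderConfig()

cfg = ToyModelConfig()
cfg.base_lr = 2e-4
cfg.subcfgs['encoder'].freeze = False
cfg.save('model.json')              # each sub-config is serialized as a nested dict
cfg.load('model.json')              # restores top-level fields and sub-configs in place
# ---------------------------------------------------------------------------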
-------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch import optim 9 | import torch.nn.functional as F 10 | 11 | import framework.logbase 12 | 13 | 14 | class ModelBase(object): 15 | def __init__(self, config, _logger=None, gpu_id=0): 16 | '''initialize model 17 | (support single GPU, otherwise need to be customized) 18 | ''' 19 | self.device = torch.device("cuda:%d"%gpu_id if torch.cuda.is_available() else "cpu") 20 | self.config = config 21 | if _logger is None: 22 | self.print_fn = print 23 | else: 24 | self.print_fn = _logger.info 25 | 26 | self.submods = self.build_submods() 27 | for submod in self.submods.values(): 28 | submod.to(self.device) 29 | self.criterion = self.build_loss() 30 | self.params, self.optimizer, self.lr_scheduler = self.build_optimizer() 31 | 32 | num_params, num_weights = 0, 0 33 | for key, submod in self.submods.items(): 34 | for varname, varvalue in submod.state_dict().items(): 35 | self.print_fn('%s: %s, shape=%s, num:%d' % ( 36 | key, varname, str(varvalue.size()), np.prod(varvalue.size()))) 37 | num_params += 1 38 | num_weights += np.prod(varvalue.size()) 39 | self.print_fn('num params %d, num weights %d'%(num_params, num_weights)) 40 | self.print_fn('trainable: num params %d, num weights %d'%( 41 | len(self.params), sum([np.prod(param.size()) for param in self.params]))) 42 | 43 | def build_submods(self): 44 | raise NotImplementedError('implement build_submods function: return submods') 45 | 46 | def build_loss(self): 47 | raise NotImplementedError('implement build_loss function: return criterion') 48 | 49 | def forward_loss(self, batch_data, step=None): 50 | raise NotImplementedError('implement forward_loss function: return loss and additional outs') 51 | 52 | def validate(self, val_reader, step=None): 53 | self.eval_start() 54 | # raise NotImplementedError('implement validate function: return metrics') 55 | 56 | def test(self, tst_reader, tst_pred_file, tst_model_file=None): 57 | if tst_model_file is not None: 58 | self.load_checkpoint(tst_model_file) 59 | self.eval_start() 60 | # raise NotImplementedError('implement test function') 61 | 62 | ########################## boilerpipe functions ######################## 63 | def build_optimizer(self): 64 | trn_params = [] 65 | trn_param_ids = set() 66 | per_param_opts = [] 67 | for key, submod in self.submods.items(): 68 | if self.config.subcfgs[key].freeze: 69 | for param in submod.parameters(): 70 | param.requires_grad = False 71 | else: 72 | params = [] 73 | for param in submod.parameters(): 74 | # sometimes we share params in different submods 75 | if param.requires_grad and id(param) not in trn_param_ids: 76 | params.append(param) 77 | trn_param_ids.add(id(param)) 78 | per_param_opts.append({ 79 | 'params': params, 80 | 'lr': self.config.base_lr * self.config.subcfgs[key].lr_mult, 81 | 'weight_decay': self.config.subcfgs[key].weight_decay, 82 | }) 83 | trn_params.extend(params) 84 | if len(trn_params) > 0: 85 | optimizer = optim.Adam(per_param_opts, lr=self.config.base_lr) 86 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, 87 | milestones=self.config.decay_boundarys, gamma=self.config.decay_rate) 88 | else: 89 | optimizer, lr_scheduler = None, None 90 | print('no traiable parameters') 91 | return trn_params, optimizer, lr_scheduler 92 | 93 | def train_start(self): 94 | for key, submod in self.submods.items(): 95 | submod.train() 96 | 
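# ---------------------------------------------------------------------------
# Illustrative sketch (not part of modelbase.py): what build_optimizer above
# effectively sets up -- one Adam parameter group per non-frozen sub-module,
# with the learning rate scaled by that module's lr_mult and its own weight_decay.
# Module names and config values below are toy assumptions.
import torch.nn as nn
from torch import optim

base_lr = 1e-4
submods = {'mp_encoder': nn.Linear(4, 4), 'decoder': nn.Linear(4, 4)}
subcfg = {'mp_encoder': dict(freeze=True, lr_mult=1.0, weight_decay=0.0),
          'decoder':    dict(freeze=False, lr_mult=0.5, weight_decay=1e-5)}

per_param_opts, trn_params = [], []
for key, submod in submods.items():
    if subcfg[key]['freeze']:
        for p in submod.parameters():
            p.requires_grad = False      # frozen modules get no parameter group
    else:
        params = [p for p in submod.parameters() if p.requires_grad]
        per_param_opts.append({'params': params,
                               'lr': base_lr * subcfg[key]['lr_mult'],
                               'weight_decay': subcfg[key]['weight_decay']})
        trn_params.extend(params)        # kept for logging trainable parameter counts

optimizer = optim.Adam(per_param_opts, lr=base_lr)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20], gamma=0.5)
# ---------------------------------------------------------------------------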
torch.set_grad_enabled(True) 97 | 98 | def eval_start(self): 99 | for key, submod in self.submods.items(): 100 | submod.eval() 101 | torch.set_grad_enabled(False) 102 | 103 | def save_checkpoint(self, ckpt_file, submods=None): 104 | if submods is None: 105 | submods = self.submods 106 | state_dicts = {} 107 | for key, submod in submods.items(): 108 | state_dicts[key] = {} 109 | for varname, varvalue in submod.state_dict().items(): 110 | state_dicts[key][varname] = varvalue.cpu() 111 | torch.save(state_dicts, ckpt_file) 112 | 113 | def load_checkpoint(self, ckpt_file, submods=None): 114 | if submods is None: 115 | submods = self.submods 116 | state_dicts = torch.load(ckpt_file, map_location=lambda storage, loc: storage) 117 | 118 | num_resumed_vars = 0 119 | for key, state_dict in state_dicts.items(): 120 | if key in submods: 121 | own_state_dict = submods[key].state_dict() 122 | new_state_dict = {} 123 | for varname, varvalue in state_dict.items(): 124 | if varname in own_state_dict: 125 | new_state_dict[varname] = varvalue 126 | num_resumed_vars += 1 127 | own_state_dict.update(new_state_dict) 128 | submods[key].load_state_dict(own_state_dict) 129 | self.print_fn('number of resumed variables: %d'%num_resumed_vars) 130 | 131 | def pretty_print_metrics(self, prefix, metrics): 132 | metric_str = [] 133 | for measure, score in metrics.items(): 134 | metric_str.append('%s %.4f'%(measure, score)) 135 | metric_str = ' '.join(metric_str) 136 | self.print_fn('%s: %s' % (prefix, metric_str)) 137 | 138 | def get_current_base_lr(self): 139 | return self.optimizer.param_groups[0]['lr'] 140 | 141 | def train_one_batch(self, batch_data, step): 142 | self.optimizer.zero_grad() 143 | loss = self.forward_loss(batch_data, step=step) 144 | loss.backward() 145 | self.optimizer.step() 146 | 147 | loss_value = loss.data.item() 148 | if step is not None and self.config.monitor_iter > 0 and step % self.config.monitor_iter == 0: 149 | self.print_fn('\ttrn step %d lr %.8f %s: %.4f' % (step, self.get_current_base_lr(), 'loss', loss_value)) 150 | return {'loss': loss_value} 151 | 152 | def train_one_epoch(self, step, trn_reader, val_reader, model_dir, log_dir): 153 | self.train_start() 154 | 155 | avg_loss, n_batches = {}, {} 156 | for batch_data in trn_reader: 157 | loss = self.train_one_batch(batch_data, step) 158 | for loss_key, loss_value in loss.items(): 159 | avg_loss.setdefault(loss_key, 0) 160 | n_batches.setdefault(loss_key, 0) 161 | avg_loss[loss_key] += loss_value 162 | n_batches[loss_key] += 1 163 | step += 1 164 | 165 | if self.config.save_iter > 0 and step % self.config.save_iter == 0: 166 | self.save_checkpoint(os.path.join(model_dir, 'step.%d.th'%step)) 167 | 168 | if (self.config.save_iter > 0 and step % self.config.save_iter == 0) \ 169 | or (self.config.val_iter > 0 and step % self.config.val_iter == 0): 170 | metrics = self.validate(val_reader, step=step) 171 | with open(os.path.join(log_dir, 'val.step.%d.json'%step), 'w') as f: 172 | json.dump(metrics, f, indent=2) 173 | self.pretty_print_metrics('\tval step %d'%step, metrics) 174 | self.train_start() 175 | 176 | for loss_key, loss_value in avg_loss.items(): 177 | avg_loss[loss_key] = loss_value / n_batches[loss_key] 178 | return avg_loss, step 179 | 180 | def epoch_postprocess(self, epoch): 181 | if self.lr_scheduler is not None: 182 | self.lr_scheduler.step() 183 | 184 | def train(self, trn_reader, val_reader, model_dir, log_dir, resume_file=None): 185 | assert self.optimizer is not None 186 | 187 | if resume_file is not None: 188 | 
self.load_checkpoint(resume_file) 189 | 190 | # first validate 191 | metrics = self.validate(val_reader) 192 | self.pretty_print_metrics('init val', metrics) 193 | 194 | # training 195 | step = 0 196 | for epoch in range(self.config.num_epoch): 197 | avg_loss, step = self.train_one_epoch( 198 | step, trn_reader, val_reader, model_dir, log_dir) 199 | self.pretty_print_metrics('epoch (%d/%d) trn'%(epoch, self.config.num_epoch), avg_loss) 200 | self.epoch_postprocess(epoch) 201 | 202 | if self.config.save_per_epoch: 203 | self.save_checkpoint(os.path.join(model_dir, 'epoch.%d.th'%epoch)) 204 | 205 | if self.config.val_per_epoch: 206 | metrics = self.validate(val_reader, step=step) 207 | with open(os.path.join(log_dir, 208 | 'val.epoch.%d.step.%d.json'%(epoch, step)), 'w') as f: 209 | json.dump(metrics, f, indent=2) 210 | self.pretty_print_metrics('epoch (%d/%d) val' % (epoch, self.config.num_epoch), metrics) 211 | 212 | -------------------------------------------------------------------------------- /framework/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshizhe/asg2cap/9d3a8a6312935cb1b033835c1edc522dcf9f6061/framework/modules/__init__.py -------------------------------------------------------------------------------- /framework/modules/embeddings.py: -------------------------------------------------------------------------------- 1 | """ Embeddings module """ 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class PositionalEncoding(nn.Module): 9 | """ 10 | Implements the sinusoidal positional encoding for 11 | non-recurrent neural networks. 12 | 13 | Implementation based on "Attention Is All You Need" 14 | 15 | Args: 16 | dim_embed (int): embedding size (even number) 17 | """ 18 | 19 | def __init__(self, dim_embed, max_len=100): 20 | super(PositionalEncoding, self).__init__() 21 | 22 | pe = torch.zeros(max_len, dim_embed) 23 | position = torch.arange(0, max_len).unsqueeze(1) 24 | div_term = torch.exp((torch.arange(0, dim_embed, 2, dtype=torch.float) * 25 | -(math.log(10000.0) / dim_embed))) 26 | pe[:, 0::2] = torch.sin(position.float() * div_term) 27 | pe[:, 1::2] = torch.cos(position.float() * div_term) 28 | 29 | self.pe = pe # size=(max_len, dim_embed) 30 | self.dim_embed = dim_embed 31 | 32 | def forward(self, emb, step=None): 33 | if emb.device != self.pe.device: 34 | self.pe = self.pe.to(emb.device) 35 | if step is None: 36 | # emb.size = (batch, seq_len, dim_embed) 37 | emb = emb + self.pe[:emb.size(1)] 38 | else: 39 | # emb.size = (batch, dim_embed) 40 | emb = emb + self.pe[step] 41 | return emb 42 | 43 | 44 | class Embedding(nn.Module): 45 | """Words embeddings for encoder/decoder. 46 | Args: 47 | word_vec_size (int): size of the dictionary of embeddings. 48 | word_vocab_size (int): size of dictionary of embeddings for words. 49 | position_encoding (bool): see :obj:`modules.PositionalEncoding` 50 | """ 51 | def __init__(self, word_vocab_size, word_vec_size, 52 | position_encoding=False, fix_word_embed=False, max_len=100): 53 | super(Embedding, self).__init__() 54 | 55 | self.word_vec_size = word_vec_size 56 | self.we = nn.Embedding(word_vocab_size, word_vec_size) 57 | if fix_word_embed: 58 | self.we.weight.requires_grad = False 59 | self.init_weight() 60 | 61 | self.position_encoding = position_encoding 62 | if self.position_encoding: 63 | self.pe = PositionalEncoding(word_vec_size, max_len=max_len) 64 | 65 | def init_weight(self): 66 | std = 1. 
/ (self.word_vec_size**0.5) 67 | nn.init.uniform_(self.we.weight, -std, std) 68 | 69 | def forward(self, word_idxs, step=None): 70 | """Computes the embeddings for words. 71 | Args: 72 | word_idxs (`LongTensor`): index tensor 73 | size = (batch, seq_len) or (batch, ) 74 | Return: 75 | embeds: `FloatTensor`, 76 | size = (batch, seq_len, dim_embed) or (batch, dim_embed) 77 | """ 78 | embeds = self.we(word_idxs) 79 | if self.position_encoding: 80 | embeds = self.pe(embeds, step=step) 81 | return embeds 82 | -------------------------------------------------------------------------------- /framework/modules/global_attention.py: -------------------------------------------------------------------------------- 1 | """ Global attention modules (Luong / Bahdanau) """ 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class GlobalAttention(nn.Module): 7 | ''' 8 | Global attention takes a matrix and a query vector. It 9 | then computes a parameterized convex combination of the matrix 10 | based on the input query. 11 | 12 | Constructs a unit mapping a query `q` of size `dim` 13 | and a source matrix `H` of size `n x dim`, 14 | to an output of size `dim`. 15 | 16 | All models compute the output as 17 | :math:`c = sum_{j=1}^{SeqLength} a_j H_j` where 18 | :math:`a_j` is the softmax of a score function. 19 | 20 | However they differ on how they compute the attention score. 21 | 22 | * Luong Attention (dot, general): 23 | * dot: :math:`score(H_j,q) = H_j^T q` 24 | * general: :math:`score(H_j, q) = H_j^T W_a q` 25 | 26 | * Bahdanau Attention (mlp): 27 | * :math:`score(H_j, q) = w_a^T tanh(W_a q + U_a h_j)` 28 | 29 | Args: 30 | attn_size (int): dimensionality of query and key 31 | attn_type (str): type of attention to use, options [dot,general,mlp] 32 | ''' 33 | 34 | def __init__(self, query_size, attn_size, attn_type='dot'): 35 | super(GlobalAttention, self).__init__() 36 | 37 | self.query_size = query_size 38 | self.attn_size = attn_size 39 | self.attn_type = attn_type 40 | 41 | if self.attn_type == 'general': 42 | self.linear_in = nn.Linear(query_size, attn_size, bias=False) 43 | elif self.attn_type == 'mlp': 44 | self.linear_query = nn.Linear(query_size, attn_size, bias=True) 45 | self.attn_w = nn.Linear(attn_size, 1, bias=False) 46 | elif self.attn_type == 'dot': 47 | assert self.query_size == self.attn_size 48 | 49 | def forward(self, query, memory_keys, memory_values, memory_masks): 50 | """ 51 | Args: 52 | query (`FloatTensor`): (batch, query_size) 53 | memory_keys (`FloatTensor`): (batch, seq_len, attn_size) 54 | memory_values (`FloatTensor`): (batch, seq_len, attn_size) 55 | memory_masks (`LongTensor`): (batch, seq_len) 56 | 57 | Returns: 58 | attn_score: attention distributions (batch, seq_len) 59 | attn_memory: computed context vector, (batch, attn_size) 60 | """ 61 | batch_size, seq_len, attn_size = memory_keys.size() 62 | 63 | if self.attn_type == 'mlp': 64 | query_hidden = self.linear_query(query.unsqueeze(1)).expand( 65 | batch_size, seq_len, attn_size) 66 | # attn_hidden: # (batch, seq_len, attn_size) 67 | attn_hidden = torch.tanh(query_hidden + memory_keys) 68 | # attn_score: (batch, seq_len, 1) 69 | attn_score = self.attn_w(attn_hidden) 70 | elif self.attn_type == 'dot': 71 | # attn_score: (batch, seq_len, 1) 72 | attn_score = torch.bmm(memory_keys, query.unsqueeze(2)) 73 | elif self.attn_type == 'general': 74 | query_hidden = self.linear_in(query) 75 | attn_score = torch.bmm(memory_keys, query_hidden.unsqueeze(2)) 76 | 77 | # attn_score: (batch, 
seq_len) 78 | attn_score = attn_score.squeeze(2) 79 | if memory_masks is not None: 80 | attn_score = attn_score * memory_masks # memory mask [0, 1] 81 | attn_score = attn_score.masked_fill(memory_masks == 0, -1e18) 82 | attn_score = F.softmax(attn_score, dim=1) 83 | # make sure no item is attended when all memory_masks are all zeros 84 | if memory_masks is not None: 85 | attn_score = attn_score.masked_fill(memory_masks == 0, 0) 86 | attn_memory = torch.sum(attn_score.unsqueeze(2) * memory_values, 1) 87 | return attn_score, attn_memory 88 | 89 | #TODO 90 | class AdaptiveAttention(nn.Module): 91 | def __init__(self, query_size, attn_size): 92 | super(AdaptiveAttention, self).__init__() 93 | self.query_size = query_size 94 | self.attn_size = attn_size 95 | 96 | self.query_attn_conv = nn.Conv1d(query_size, 97 | attn_size, kernel_size=1, stride=1, padding=0, bias=True) 98 | self.sentinel_attn_conv = nn.Conv1d(query_size, 99 | attn_size, kernel_size=1, stride=1, padding=0, bias=False) 100 | self.attn_w = nn.Conv1d(attn_size, 1, kernel_size=1, 101 | stride=1, padding=0, bias=False) 102 | 103 | def forward(self, query, memory_keys, memory_values, memory_masks, sentinel): 104 | batch_size, _, enc_seq_len = memory_keys.size() 105 | 106 | query_hidden = self.query_attn_conv(query.unsqueeze(2)) 107 | sentinel_hidden = self.sentinel_attn_conv(sentinel.unsqueeze(2)) 108 | 109 | memory_keys_sentinel = torch.cat([memory_keys, sentinel_hidden], dim=2) 110 | attn_score = self.attn_w(F.tanh(query_hidden + memory_keys_sentinel)).squeeze(1) 111 | attn_score = F.softmax(attn_score, dim=1) 112 | masks = torch.cat([memory_masks, torch.ones(batch_size, 1).to(memory_masks.device)], dim=1) 113 | attn_score = attn_score * masks 114 | attn_score = attn_score / (torch.sum(attn_score, 1, keepdim=True) + 1e-10) 115 | attn_memory = torch.sum(attn_score[:, :-1].unsqueeze(1) * memory_values, 2) 116 | attn_memory = attn_memory + attn_score[:, -1:] * sentinel 117 | return attn_score, attn_memory 118 | -------------------------------------------------------------------------------- /framework/ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from torch.nn.utils.rnn import pack_padded_sequence 8 | from torch.nn.utils.rnn import pad_packed_sequence 9 | 10 | def l2norm(inputs, dim=-1): 11 | # inputs: (batch, dim_ft) 12 | norm = torch.norm(inputs, p=2, dim=dim, keepdim=True) 13 | inputs = inputs / norm.clamp(min=1e-10) 14 | return inputs 15 | 16 | def sequence_mask(lengths, max_len=None, inverse=False): 17 | ''' Creates a boolean mask from sequence lengths. 18 | ''' 19 | # lengths: LongTensor, (batch, ) 20 | batch_size = lengths.size(0) 21 | max_len = max_len or lengths.max() 22 | mask = torch.arange(0, max_len).type_as(lengths).repeat(batch_size, 1) 23 | if inverse: 24 | mask = mask.ge(lengths.unsqueeze(1)) 25 | else: 26 | mask = mask.lt(lengths.unsqueeze(1)) 27 | return mask 28 | 29 | def subsequent_mask(size): 30 | '''Mask out subsequent position. 
31 | Args 32 | size: the length of tgt words''' 33 | attn_shape = (1, size, size) 34 | # set the values below the 1th diagnose as 0 35 | mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') 36 | mask = torch.from_numpy(mask) == 0 37 | return mask 38 | 39 | def rnn_factory(rnn_type, **kwargs): 40 | rnn = getattr(nn, rnn_type.upper())(**kwargs) 41 | return rnn 42 | 43 | -------------------------------------------------------------------------------- /framework/run_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import datetime 4 | import numpy as np 5 | import glob 6 | 7 | import framework.configbase 8 | 9 | 10 | def gen_common_pathcfg(path_cfg_file, is_train=False): 11 | path_cfg = framework.configbase.PathCfg() 12 | path_cfg.load(json.load(open(path_cfg_file))) 13 | 14 | output_dir = path_cfg.output_dir 15 | 16 | path_cfg.log_dir = os.path.join(output_dir, 'log') 17 | path_cfg.model_dir = os.path.join(output_dir, 'model') 18 | path_cfg.pred_dir = os.path.join(output_dir, 'pred') 19 | if not os.path.exists(path_cfg.log_dir): 20 | os.makedirs(path_cfg.log_dir) 21 | if not os.path.exists(path_cfg.model_dir): 22 | os.makedirs(path_cfg.model_dir) 23 | if not os.path.exists(path_cfg.pred_dir): 24 | os.makedirs(path_cfg.pred_dir) 25 | 26 | if is_train: 27 | timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 28 | path_cfg.log_file = os.path.join(path_cfg.log_dir, 'log-' + timestamp) 29 | else: 30 | path_cfg.log_file = None 31 | 32 | return path_cfg 33 | 34 | 35 | def find_best_val_models(log_dir, model_dir): 36 | step_jsons = glob.glob(os.path.join(log_dir, 'val.step.*.json')) 37 | epoch_jsons = glob.glob(os.path.join(log_dir, 'val.epoch.*.json')) 38 | 39 | val_names, val_scores = [], [] 40 | for i, json_file in enumerate(step_jsons + epoch_jsons): 41 | json_name = os.path.basename(json_file) 42 | scores = json.load(open(json_file)) 43 | val_names.append(json_name) 44 | val_scores.append(scores) 45 | 46 | measure_names = list(val_scores[0].keys()) 47 | model_files = {} 48 | for measure_name in measure_names: 49 | # for metrics: the lower the better 50 | if 'loss' in measure_name or 'medr' in measure_name or 'meanr' in measure_name: 51 | idx = np.argmin([scores[measure_name] for scores in val_scores]) 52 | # for metrics: the higher the better 53 | else: 54 | idx = np.argmax([scores[measure_name] for scores in val_scores]) 55 | json_name = val_names[idx] 56 | model_file = os.path.join(model_dir, 57 | 'epoch.%s.th'%(json_name.split('.')[2]) if 'epoch' in json_name \ 58 | else 'step.%s.th'%(json_name.split('.')[2])) 59 | model_files.setdefault(model_file, []) 60 | model_files[model_file].append(measure_name) 61 | 62 | name2file = {'-'.join(measure_name): model_file for model_file, measure_name in model_files.items()} 63 | 64 | return name2file 65 | --------------------------------------------------------------------------------
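# ---------------------------------------------------------------------------
# Usage sketch (not in the original sources): after training, run_utils can be used
# to resolve the output directories from the path config and to pick the best
# checkpoint per validation metric from the val.*.json logs written during training.
# 'path.json' below stands for the path config produced by the prepare_*_config step.
import framework.run_utils

path_cfg = framework.run_utils.gen_common_pathcfg('path.json', is_train=False)
name2file = framework.run_utils.find_best_val_models(path_cfg.log_dir, path_cfg.model_dir)
for measure, model_file in name2file.items():
    print(measure, '->', model_file)   # e.g. a CIDEr-best epoch.*.th or step.*.th checkpoint
# ---------------------------------------------------------------------------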