├── .gitignore
├── LICENSE
├── README.md
├── dataloader.py
├── detect_concepts.py
├── eval_cls_rnn.py
├── eval_ppl.py
├── eval_senti.py
├── method_figs
│   ├── Architecture.png
│   ├── Fine-tuning.png
│   ├── Image Captioning with Inherent Sentiment.pdf
│   └── Pre-training.png
├── models
│   ├── __init__.py
│   ├── captioner.py
│   ├── concept_detector.py
│   ├── decoder.py
│   ├── encoder.py
│   ├── sent_senti_cls.py
│   ├── sentiment_detector.py
│   └── sentiment_detector_full.py
├── opts.py
├── preprocess.py
├── self_critical
│   ├── __init__.py
│   ├── bleu
│   │   ├── LICENSE
│   │   ├── __init__.py
│   │   ├── bleu.py
│   │   └── bleu_scorer.py
│   ├── cider
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── license.txt
│   │   └── pyciderevalcap
│   │       ├── __init__.py
│   │       └── ciderD
│   │           ├── __init__.py
│   │           ├── ciderD.py
│   │           └── ciderD_scorer.py
│   ├── utils.py
│   └── utils_bac.py
├── test_cpt.py
├── train_cpt.py
├── train_rl.py
├── train_sent_senti_cls_rnn.py
├── train_senti.py
└── train_xe.py

/.gitignore:
--------------------------------------------------------------------------------
1 | data/*
2 | checkpoint/*
3 | result/*
4 | .idea/*
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 |
122 | # Rope project settings
123 | .ropeproject
124 |
125 | # mkdocs documentation
126 | /site
127 |
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 |
133 | # Pyre type checker
134 | .pyre/
135 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Li Tong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # InSentiCap_model
2 | A PyTorch implementation of our paper [Image Captioning with Inherent Sentiment (ICME 2021 Oral)](./method_figs/Image%20Captioning%20with%20Inherent%20Sentiment.pdf).
3 |
4 | ### Citation
5 | ```
6 | @inproceedings{li2021image,
7 |   title={Image Captioning with Inherent Sentiment},
8 |   author={Li, Tong and Hu, Yunhui and Wu, Xinxiao},
9 |   booktitle={2021 IEEE International Conference on Multimedia and Expo (ICME)},
10 |   year={2021},
11 |   organization={IEEE}
12 | }
13 | ```
14 |
15 | ## Environment
16 | - Python 3.7
17 | - PyTorch 1.3.1
18 |
19 | ## Method
20 | ### 1. Architecture
21 | ![Architecture](./method_figs/Architecture.png)
22 |
23 | ### 2. Training Strategy
24 | - Pre-training stage
25 | ![Pre-training](./method_figs/Pre-training.png)
26 | - Fine-tuning stage
27 | ![Fine-tuning](./method_figs/Fine-tuning.png)
28 |
29 | ## Result
30 | ### Evaluation metrics
31 |
32 | |Sentiment|BLEU-1|BLEU-3|METEOR|CIDEr|ppl (↓)|cls acc. (%)|
33 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|
34 | |positive|59.7|25.3|20.9|61.3|13.0|98.5|
35 | |negative|59.1|24.3|19.4|53.3|12.3|95.5|
36 | |neutral|73.5|41.2|24.7|97.5|8.4|98.9|
37 |
38 |
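The `ppl` and `cls` columns come from `eval_ppl.py` (perplexity under the per-sentiment language models in `./data/captions/.../lm/`, queried through the SRILM `ngram` tool) and `eval_cls_rnn.py` (accuracy of the pre-trained sentence sentiment classifier); both scripts take a caption-file prefix and a data split on the command line. Below is a minimal sketch of calling them from Python — the prefix and split names are hypothetical placeholders, not files shipped with this repo:

```python
# Illustrative sketch only: prefix/split are hypothetical placeholders.
# compute_ppl expects <prefix>_<sentiment>_<split>_w.txt files plus the SRILM
# language models on disk; compute_cls expects <prefix>_<sentiment>_<split>.txt
# files plus the trained sent_senti_cls checkpoint (positive/negative/neutral).
from eval_ppl import compute_ppl
from eval_cls_rnn import compute_cls

prefix = './result/insenticap_captions'  # hypothetical captions_file_prefix
split = 'test'                           # hypothetical data_type
compute_ppl(prefix, split)   # prints per-sentiment perplexity scores
compute_cls(prefix, split)   # prints per-sentiment classification accuracy
```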
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | # coding:utf8 2 | import torch 3 | from torch.utils import data 4 | import numpy as np 5 | import h5py 6 | import random 7 | 8 | 9 | def create_collate_fn(name, pad_index=0, max_seq_len=17, num_concepts=5, 10 | num_sentiments=10): 11 | def caption_collate_fn(dataset): 12 | tmp = [] 13 | for fn, fc_feat, att_feat, caps_idx, cpts_idx in dataset: 14 | for cap in caps_idx: 15 | tmp.append([fn, fc_feat, att_feat, cap, cpts_idx]) 16 | dataset = tmp 17 | dataset.sort(key=lambda p: len(p[3]), reverse=True) 18 | fns, fc_feats, att_feats, caps, cpts = zip(*dataset) 19 | fc_feats = torch.FloatTensor(np.array(fc_feats)) 20 | att_feats = torch.FloatTensor(np.array(att_feats)) 21 | 22 | lengths = [min(len(c), max_seq_len) for c in caps] 23 | caps_tensor = torch.LongTensor(len(caps), lengths[0]).fill_(pad_index) 24 | for i, c in enumerate(caps): 25 | end = lengths[i] 26 | caps_tensor[i, :end] = torch.LongTensor(c[:end]) 27 | lengths = [l-1 for l in lengths] 28 | 29 | cpts_tensor = torch.LongTensor(len(cpts), num_concepts).fill_(pad_index) 30 | for i, c in enumerate(cpts): 31 | end = min(len(c), num_concepts) 32 | cpts_tensor[i, :end] = torch.LongTensor(c[:end]) 33 | 34 | return fns, fc_feats, att_feats, (caps_tensor, lengths), cpts_tensor 35 | 36 | def scs_collate_fn(dataset): 37 | dataset.sort(key=lambda p: len(p[0]), reverse=True) 38 | caps, cpts, sentis, senti_ids = zip(*dataset) 39 | senti_ids = torch.LongTensor(np.array(senti_ids)) 40 | 41 | lengths = [min(len(c), max_seq_len) for c in caps] 42 | caps_tensor = torch.LongTensor(len(caps), lengths[0]).fill_(pad_index) 43 | for i, c in enumerate(caps): 44 | end = lengths[i] 45 | caps_tensor[i, :end] = torch.LongTensor(c[:end]) 46 | lengths = [l-1 for l in lengths] 47 | 48 | cpts_tensor = torch.LongTensor(len(cpts), num_concepts).fill_(pad_index) 49 | for i, c in enumerate(cpts): 50 | end = min(len(c), num_concepts) 51 | cpts_tensor[i, :end] = torch.LongTensor(c[:end]) 52 | 53 | sentis_tensor = torch.LongTensor(len(sentis), num_sentiments).fill_(pad_index) 54 | for i, c in enumerate(sentis): 55 | end = min(len(c), num_sentiments) 56 | sentis_tensor[i, :end] = torch.LongTensor(c[:end]) 57 | 58 | return (caps_tensor, lengths), cpts_tensor, sentis_tensor, senti_ids 59 | 60 | def rl_fact_collate_fn(dataset): 61 | ground_truth = {} 62 | tmp = [] 63 | for fn, caps_idx, fc_feat, att_feat, cpts_idx, sentis_idx in dataset: 64 | ground_truth[fn] = [c[:max_seq_len] for c in caps_idx] 65 | cap = random.sample(caps_idx, 1)[0] 66 | tmp.append([fn, cap, fc_feat, att_feat, cpts_idx, sentis_idx]) 67 | dataset = tmp 68 | dataset.sort(key=lambda p: len(p[1]), reverse=True) 69 | 70 | fns, caps, fc_feats, att_feats, cpts, sentis = zip(*dataset) 71 | fc_feats = torch.FloatTensor(np.array(fc_feats)) 72 | att_feats = torch.FloatTensor(np.array(att_feats)) 73 | 74 | lengths = [min(len(c), max_seq_len) for c in caps] 75 | caps_tensor
= torch.LongTensor(len(caps), lengths[0]).fill_(pad_index) 76 | for i, c in enumerate(caps): 77 | end = lengths[i] 78 | caps_tensor[i, :end] = torch.LongTensor(c[:end]) 79 | lengths = [l - 1 for l in lengths] 80 | 81 | cpts_tensor = torch.LongTensor(len(cpts), num_concepts).fill_(pad_index) 82 | for i, c in enumerate(cpts): 83 | end = min(len(c), num_concepts) 84 | cpts_tensor[i, :end] = torch.LongTensor(c[:end]) 85 | 86 | sentis_tensor = torch.LongTensor(len(sentis), num_sentiments).fill_(pad_index) 87 | for i, s in enumerate(sentis): 88 | end = min(len(s), num_sentiments) 89 | sentis_tensor[i, :end] = torch.LongTensor(s[:end]) 90 | 91 | return fns, fc_feats, att_feats, (caps_tensor, lengths), cpts_tensor, sentis_tensor, ground_truth 92 | 93 | def rl_senti_collate_fn(dataset): 94 | fns, fc_feats, att_feats, cpts, sentis, senti_labels = zip(*dataset) 95 | fc_feats = torch.FloatTensor(np.array(fc_feats)) 96 | att_feats = torch.FloatTensor(np.array(att_feats)) 97 | senti_labels = torch.LongTensor(np.array(senti_labels)) 98 | 99 | cpts_tensor = torch.LongTensor(len(cpts), num_concepts).fill_(pad_index) 100 | for i, c in enumerate(cpts): 101 | end = min(len(c), num_concepts) 102 | cpts_tensor[i, :end] = torch.LongTensor(c[:end]) 103 | 104 | sentis_tensor = torch.LongTensor(len(sentis), num_sentiments).fill_(pad_index) 105 | for i, s in enumerate(sentis): 106 | end = min(len(s), num_sentiments) 107 | sentis_tensor[i, :end] = torch.LongTensor(s[:end]) 108 | 109 | return fns, fc_feats, att_feats, cpts_tensor, sentis_tensor, senti_labels 110 | 111 | def concept_collate_fn(dataset): 112 | fns, fc_feats, cpts = zip(*dataset) 113 | fc_feats = torch.FloatTensor(np.array(fc_feats)) 114 | cpts_tensors = torch.LongTensor(np.array(cpts)) 115 | return fns, fc_feats, cpts_tensors 116 | 117 | def senti_image_collate_fn(dataset): 118 | fns, att_feats, labels = zip(*dataset) 119 | att_feats = torch.FloatTensor(np.array(att_feats)) 120 | labels = torch.LongTensor(np.array(labels)) 121 | return fns, att_feats, labels 122 | 123 | def senti_sents_collate_fn(dataset): 124 | dataset.sort(key=lambda p: len(p[1]), reverse=True) 125 | sentis, caps = zip(*dataset) 126 | sentis = torch.LongTensor(np.array(sentis)) 127 | 128 | lengths = [min(len(c), max_seq_len) for c in caps] 129 | caps_tensor = torch.LongTensor(len(caps), lengths[0]).fill_(pad_index) 130 | for i, c in enumerate(caps): 131 | end = lengths[i] 132 | caps_tensor[i, :end] = torch.LongTensor(c[:end]) 133 | 134 | return sentis, (caps_tensor, lengths) 135 | 136 | if name == 'caption': 137 | return caption_collate_fn 138 | elif name == 'senti_sents': 139 | return senti_sents_collate_fn 140 | elif name == 'concept': 141 | return concept_collate_fn 142 | elif name == 'senti_image': 143 | return senti_image_collate_fn 144 | elif name == 'rl_fact': 145 | return rl_fact_collate_fn 146 | elif name == 'rl_senti': 147 | return rl_senti_collate_fn 148 | elif name == 'senti_corpus_with_sentis': 149 | return scs_collate_fn 150 | 151 | 152 | class SCSDataset(data.Dataset): 153 | def __init__(self, senti_corpus_with_sentis): 154 | self.senti_corpus_with_sentis = senti_corpus_with_sentis 155 | 156 | def __getitem__(self, index): 157 | cap, cpts, sentis, senti_id = self.senti_corpus_with_sentis[index] 158 | return cap, cpts, sentis, senti_id 159 | 160 | def __len__(self): 161 | return len(self.senti_corpus_with_sentis) 162 | 163 | 164 | class CaptionDataset(data.Dataset): 165 | def __init__(self, fc_feats, att_feats, img_captions, img_det_concepts): 166 | self.fc_feats = 
fc_feats 167 | self.att_feats = att_feats 168 | self.captions = list(img_captions.items()) # [(fn, [[1, 2],[3, 4],...]),...] 169 | self.det_concepts = img_det_concepts # {fn: [1,2,...])} 170 | 171 | def __getitem__(self, index): 172 | fn, caps = self.captions[index] 173 | f_fc = h5py.File(self.fc_feats, mode='r') 174 | f_att = h5py.File(self.att_feats, mode='r') 175 | fc_feat = f_fc[fn][:] 176 | att_feat = f_att[fn][:] 177 | cpts = self.det_concepts[fn] 178 | return fn, np.array(fc_feat), np.array(att_feat), caps, cpts 179 | 180 | def __len__(self): 181 | return len(self.captions) 182 | 183 | 184 | class RLFactDataset(data.Dataset): 185 | def __init__(self, fc_feats, att_feats, img_captions, 186 | img_det_concepts, img_det_sentiments): 187 | self.fc_feats = fc_feats 188 | self.att_feats = att_feats 189 | self.captions = list(img_captions.items()) 190 | self.det_concepts = img_det_concepts # {fn: [1,2,...])} 191 | self.det_sentiments = img_det_sentiments # {fn: [5,10,...])} 192 | 193 | def __getitem__(self, index): 194 | fn, caps = self.captions[index] 195 | f_fc = h5py.File(self.fc_feats, mode='r') 196 | f_att = h5py.File(self.att_feats, mode='r') 197 | fc_feat = f_fc[fn][:] 198 | att_feat = f_att[fn][:] 199 | cpts = self.det_concepts[fn] 200 | sentis = self.det_sentiments[fn] 201 | return fn, caps, np.array(fc_feat), np.array(att_feat), cpts, sentis 202 | 203 | def __len__(self): 204 | return len(self.captions) 205 | 206 | 207 | class RLSentiDataset(data.Dataset): 208 | def __init__(self, fc_feats, att_feats, img_det_concepts, 209 | img_det_sentiments, img_senti_labels): 210 | self.fc_feats = fc_feats 211 | self.att_feats = att_feats 212 | self.det_concepts = img_det_concepts # {fn: [1,2,...])} 213 | self.det_sentiments = img_det_sentiments # {fn: [5,10,...])} 214 | self.img_senti_labels = img_senti_labels # [(fn, senti_label),...] 215 | 216 | def __getitem__(self, index): 217 | fn, senti_label = self.img_senti_labels[index] 218 | f_fc = h5py.File(self.fc_feats, mode='r') 219 | f_att = h5py.File(self.att_feats, mode='r') 220 | fc_feat = f_fc[fn][:] 221 | att_feat = f_att[fn][:] 222 | cpts = self.det_concepts[fn] 223 | sentis = self.det_sentiments[fn] 224 | return fn, np.array(fc_feat), np.array(att_feat), cpts, sentis, senti_label 225 | 226 | def __len__(self): 227 | return len(self.img_senti_labels) 228 | 229 | 230 | class ConceptDataset(data.Dataset): 231 | def __init__(self, fc_feats, img_concepts, num_cpts): 232 | self.fc_feats = fc_feats 233 | self.concepts = list(img_concepts.items()) 234 | self.num_cpts = num_cpts 235 | 236 | def __getitem__(self, index): 237 | fn, cpts_idx = self.concepts[index] 238 | f_fc = h5py.File(self.fc_feats, mode='r') 239 | fc_feat = f_fc[fn][:] 240 | cpts = np.zeros(self.num_cpts, dtype=np.int16) 241 | cpts[cpts_idx] = 1 242 | return fn, np.array(fc_feat), cpts 243 | 244 | def __len__(self): 245 | return len(self.concepts) 246 | 247 | 248 | class SentiImageDataset(data.Dataset): 249 | def __init__(self, senti_att_feats, img_senti_labels): 250 | self.att_feats = senti_att_feats 251 | self.img_senti_labels = img_senti_labels # [(fn, senti_label),...] 
252 | 253 | def __getitem__(self, index): 254 | fn, senti_label = self.img_senti_labels[index] 255 | f_att = h5py.File(self.att_feats, mode='r') 256 | att_feat = f_att[fn][:] 257 | return fn, np.array(att_feat), senti_label 258 | 259 | def __len__(self): 260 | return len(self.img_senti_labels) 261 | 262 | 263 | class SentiSentDataset(data.Dataset): 264 | def __init__(self, senti_sentences): 265 | self.senti_sentences = senti_sentences 266 | 267 | def __getitem__(self, index): 268 | senti, sent = self.senti_sentences[index] 269 | return senti, np.array(sent) 270 | 271 | def __len__(self): 272 | return len(self.senti_sentences) 273 | 274 | 275 | def get_caption_dataloader(fc_feats, att_feats, img_captions, img_det_concepts, 276 | pad_index, max_seq_len, num_concepts, 277 | batch_size, num_workers=0, shuffle=True): 278 | dataset = CaptionDataset(fc_feats, att_feats, img_captions, img_det_concepts) 279 | dataloader = data.DataLoader(dataset, 280 | batch_size=batch_size, 281 | shuffle=shuffle, 282 | num_workers=num_workers, 283 | collate_fn=create_collate_fn( 284 | 'caption', pad_index, max_seq_len + 1, 285 | num_concepts)) 286 | return dataloader 287 | 288 | 289 | def get_senti_corpus_with_sentis_dataloader(senti_corpus_with_sentis, 290 | pad_index, max_seq_len, num_concepts, num_sentiments, 291 | batch_size, num_workers=0, shuffle=True): 292 | dataset = SCSDataset(senti_corpus_with_sentis) 293 | dataloader = data.DataLoader(dataset, 294 | batch_size=batch_size, 295 | shuffle=shuffle, 296 | num_workers=num_workers, 297 | collate_fn=create_collate_fn( 298 | 'senti_corpus_with_sentis', pad_index, max_seq_len + 1, 299 | num_concepts=num_concepts, 300 | num_sentiments=num_sentiments)) 301 | return dataloader 302 | 303 | 304 | def get_rl_fact_dataloader(fc_feats, att_feats, img_captions, img_det_concepts, 305 | img_det_sentiments, pad_index, max_seq_len, num_concepts, 306 | num_sentiments, batch_size, num_workers=0, shuffle=True): 307 | dataset = RLFactDataset(fc_feats, att_feats, img_captions, 308 | img_det_concepts, img_det_sentiments) 309 | dataloader = data.DataLoader(dataset, 310 | batch_size=batch_size, 311 | shuffle=shuffle, 312 | num_workers=num_workers, 313 | collate_fn=create_collate_fn( 314 | 'rl_fact', pad_index=pad_index, 315 | max_seq_len=max_seq_len + 1, 316 | num_concepts=num_concepts, 317 | num_sentiments=num_sentiments)) 318 | return dataloader 319 | 320 | 321 | def get_rl_senti_dataloader(fc_feats, att_feats, img_det_concepts, 322 | img_det_sentiments, img_senti_labels, pad_index, 323 | num_concepts, num_sentiments, batch_size, 324 | num_workers=0, shuffle=True): 325 | dataset = RLSentiDataset(fc_feats, att_feats, img_det_concepts, 326 | img_det_sentiments, img_senti_labels) 327 | dataloader = data.DataLoader(dataset, 328 | batch_size=batch_size, 329 | shuffle=shuffle, 330 | num_workers=num_workers, 331 | collate_fn=create_collate_fn( 332 | 'rl_senti', pad_index=pad_index, 333 | num_concepts=num_concepts, 334 | num_sentiments=num_sentiments)) 335 | return dataloader 336 | 337 | 338 | def get_concept_dataloader(fc_feats, img_concepts, num_cpts, 339 | batch_size, num_workers=0, shuffle=True): 340 | dataset = ConceptDataset(fc_feats, img_concepts, num_cpts) 341 | dataloader = data.DataLoader(dataset, 342 | batch_size=batch_size, 343 | shuffle=shuffle, 344 | num_workers=num_workers, 345 | collate_fn=create_collate_fn('concept')) 346 | return dataloader 347 | 348 | 349 | def get_senti_image_dataloader(senti_att_feats, img_senti_labels, 350 | batch_size, num_workers=0, shuffle=True): 
351 | dataset = SentiImageDataset(senti_att_feats, img_senti_labels) 352 | dataloader = data.DataLoader(dataset, 353 | batch_size=batch_size, 354 | shuffle=shuffle, 355 | num_workers=num_workers, 356 | collate_fn=create_collate_fn('senti_image')) 357 | return dataloader 358 | 359 | 360 | def get_senti_sents_dataloader(senti_sentences, pad_index, max_seq_len, 361 | batch_size=80, num_workers=2, shuffle=True): 362 | dataset = SentiSentDataset(senti_sentences) 363 | dataloader = data.DataLoader(dataset, 364 | batch_size=batch_size, 365 | shuffle=shuffle, 366 | num_workers=num_workers, 367 | collate_fn=create_collate_fn( 368 | 'senti_sents', pad_index=pad_index, 369 | max_seq_len=max_seq_len)) 370 | return dataloader 371 | -------------------------------------------------------------------------------- /detect_concepts.py: -------------------------------------------------------------------------------- 1 | # coding:utf8 2 | import torch 3 | import json 4 | import tqdm 5 | import os 6 | import h5py 7 | import numpy as np 8 | 9 | from opts import parse_opt 10 | from models.concept_detector import ConceptDetector 11 | from dataloader import get_concept_dataloader 12 | 13 | 14 | opt = parse_opt() 15 | print("====> loading checkpoint '{}'".format(opt.eval_model)) 16 | chkpoint = torch.load(opt.eval_model, map_location=lambda s, l: s) 17 | idx2concept = chkpoint['idx2concept'] 18 | settings = chkpoint['settings'] 19 | dataset_name = chkpoint['dataset_name'] 20 | model = ConceptDetector(idx2concept, settings) 21 | model.to(opt.device) 22 | model.load_state_dict(chkpoint['model']) 23 | model.eval() 24 | _, criterion = model.get_optim_criterion(0) 25 | print("====> loaded checkpoint '{}', epoch: {}, dataset_name: {}". 26 | format(opt.eval_model, chkpoint['epoch'], dataset_name)) 27 | 28 | 29 | fact_fc = h5py.File(os.path.join(opt.feats_dir, dataset_name, '%s_fc.h5' % dataset_name), 'r') 30 | senti_fc = h5py.File(os.path.join(opt.feats_dir, 'sentiment', 'feats_fc.h5'), 'r') 31 | 32 | predict_result = {} 33 | for fc in [fact_fc, senti_fc]: 34 | fns = list(fc.keys()) 35 | for i in tqdm.tqdm(range(0, len(fns), 100)): 36 | cur_fns = fns[i:i + 100] 37 | feats = [] 38 | for fn in cur_fns: 39 | feats.append(fc[fn][:]) 40 | feats = torch.FloatTensor(np.array(feats)).to(opt.device) 41 | _, concepts, _ = model.sample(feats, num=opt.num_concepts) 42 | for j, fn in enumerate(cur_fns): 43 | predict_result[fn] = concepts[j] 44 | 45 | json.dump(predict_result, open(os.path.join(opt.captions_dir, dataset_name, 'img_det_concepts.json'), 'w')) 46 | -------------------------------------------------------------------------------- /eval_cls_rnn.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import tqdm 4 | import numpy as np 5 | import os 6 | 7 | from models.sent_senti_cls import SentenceSentimentClassifier 8 | from dataloader import get_senti_sents_dataloader 9 | 10 | device = torch.device('cuda:0') 11 | max_seq_len = 16 12 | 13 | 14 | def compute_cls(captions_file_prefix, data_type): 15 | dataset_name = 'coco' 16 | if 'flickr30k' in captions_file_prefix: 17 | dataset_name = 'flickr30k' 18 | corpus_type = 'part' 19 | if 'full' in captions_file_prefix: 20 | corpus_type = 'full' 21 | 22 | ss_cls_file = os.path.join('./checkpoint', 'sent_senti_cls', dataset_name, corpus_type, 'model-best.pth') 23 | print("====> loading checkpoint '{}'".format(ss_cls_file)) 24 | chkpoint = torch.load(ss_cls_file, map_location=lambda s, l: s) 25 | settings = 
chkpoint['settings'] 26 | idx2word = chkpoint['idx2word'] 27 | sentiment_categories = chkpoint['sentiment_categories'] 28 | assert dataset_name == chkpoint['dataset_name'], \ 29 | 'dataset_name and resume model dataset_name are different' 30 | assert corpus_type == chkpoint['corpus_type'], \ 31 | 'corpus_type and resume model corpus_type are different' 32 | model = SentenceSentimentClassifier(idx2word, sentiment_categories, settings) 33 | model.load_state_dict(chkpoint['model']) 34 | model.eval() 35 | model.to(device) 36 | 37 | val_sets = {} 38 | val_sets['all'] = [] 39 | for senti_id, senti in enumerate(sentiment_categories): 40 | val_sets[senti] = [] 41 | fn = '%s_%s_%s.txt' % (captions_file_prefix, senti, data_type) 42 | with open(fn, 'r') as f: 43 | lines = f.readlines() 44 | for line in lines: 45 | line = line.split() 46 | line = [int(l) for l in line] 47 | val_sets[senti].append([senti_id, line]) 48 | val_sets['all'].append([senti_id, line]) 49 | 50 | val_datas = {} 51 | for senti in val_sets: 52 | val_datas[senti] = get_senti_sents_dataloader(val_sets[senti], idx2word.index(''), max_seq_len, 53 | shuffle=False) 54 | 55 | for senti, val_data in val_datas.items(): 56 | all_num = 0 57 | wrong_num = 0 58 | with torch.no_grad(): 59 | for sentis, (caps_tensor, lengths) in tqdm.tqdm(val_data): 60 | sentis = sentis.to(device) 61 | caps_tensor = caps_tensor.to(device) 62 | 63 | rest, _, _ = model.sample(caps_tensor, lengths) 64 | rest = torch.LongTensor(np.array(rest)).to(device) 65 | all_num += int(sentis.size(0)) 66 | wrong_num += int((sentis != rest).sum()) 67 | wrong_rate = wrong_num / all_num 68 | print('%s acc_rate: %.6f' % (senti, 1 - wrong_rate)) 69 | 70 | 71 | if __name__ == "__main__": 72 | compute_cls(sys.argv[1], sys.argv[2]) 73 | -------------------------------------------------------------------------------- /eval_ppl.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import numpy as np 4 | import os 5 | # import kenlm 6 | 7 | sentis = ['positive', 'negative', 'neutral'] 8 | lm_cmd = 'ngram -ppl %s -lm ./data/captions/%s/%s/lm/%s_w.sri' 9 | 10 | 11 | def compute_ppl(captions_file_prefix, data_type): 12 | dataset_name = 'coco' 13 | if 'flickr30k' in captions_file_prefix: 14 | dataset_name = 'flickr30k' 15 | corpus_type = 'part' 16 | if 'full' in captions_file_prefix: 17 | corpus_type = 'full' 18 | 19 | lm_cmds = {} 20 | for senti in sentis: 21 | lm_cmds[senti] = lm_cmd % ('%s_%s_%s_w.txt' % (captions_file_prefix, senti, data_type), dataset_name, corpus_type, senti) 22 | # print('lm cms:', lm_cmds) 23 | scores = {} 24 | for senti, cmd in lm_cmds.items(): 25 | out = os.popen(cmd).read().split() 26 | try: 27 | scores[senti] = float(out[out.index('ppl=') + 1]) 28 | except Exception: 29 | scores[senti] = 0 30 | 31 | print('ppl scores:', scores) 32 | print('ppl scores sum:', sum(scores.values())) 33 | return scores 34 | 35 | 36 | if __name__ == "__main__": 37 | compute_ppl(sys.argv[1], sys.argv[2]) 38 | -------------------------------------------------------------------------------- /eval_senti.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import os 4 | from collections import defaultdict 5 | 6 | from models.sentiment_detector import SentimentDetector 7 | from dataloader import get_senti_image_dataloader 8 | from opts import parse_opt 9 | 10 | 11 | labeled_file = './data/labeled_data/at_most_one_disagree.json' 12 | labeled_data = 
json.load(open(labeled_file, 'r')) 13 | 14 | opt = parse_opt() 15 | print("====> loading rl_senti_resume '{}'".format(opt.rl_senti_resume)) 16 | ch = torch.load(opt.rl_senti_resume, map_location=lambda s, l: s) 17 | settings = ch['settings'] 18 | sentiment_categories = ch['sentiment_categories'] 19 | model = SentimentDetector(sentiment_categories, settings) 20 | model.load_state_dict(ch['model']) 21 | model.to(opt.device) 22 | model.eval() 23 | 24 | senti_label2idx = {} 25 | for i, w in enumerate(sentiment_categories): 26 | senti_label2idx[w] = i 27 | neu_idx = senti_label2idx['neutral'] 28 | img_senti_labels = {} 29 | for senti, fns in labeled_data.items(): 30 | senti_id = senti_label2idx[senti] 31 | img_senti_labels[senti] = [[fn, senti_id] for fn in fns] 32 | 33 | dataset_name = 'coco' 34 | att_feats = os.path.join(opt.feats_dir, dataset_name, '%s_att.h5' % dataset_name) 35 | eval_datas = {} 36 | for senti in img_senti_labels: 37 | data = get_senti_image_dataloader( 38 | att_feats, img_senti_labels[senti], batch_size=len(img_senti_labels[senti]), 39 | num_workers=2, shuffle=False) 40 | eval_datas[senti] = next(iter(data)) 41 | 42 | for THRESHOLD in range(11): 43 | THRESHOLD = THRESHOLD / 10 44 | print('THRESHOLD:', THRESHOLD) 45 | all_num = 0 46 | all_cor_num = 0 47 | for senti, (_, att_feats, labels) in eval_datas.items(): 48 | att_feats = att_feats.to(opt.device) 49 | labels = labels.to(opt.device) 50 | with torch.no_grad(): 51 | preds, _, _, scores = model.sample(att_feats) 52 | replace_idx = (scores < THRESHOLD).nonzero(as_tuple=False).view(-1) 53 | preds.index_copy_(0, replace_idx, preds.new_zeros(len(replace_idx)).fill_(neu_idx)) 54 | num = int(preds.size(0)) 55 | cor_num = int(sum(preds == labels)) 56 | print('%s accuracy: %s' % (senti, cor_num / num)) 57 | # print('%s scores mean: %s' % (senti, scores.mean())) 58 | all_num += num 59 | all_cor_num += cor_num 60 | print('all accuracy:', all_cor_num / all_num) 61 | 62 | 63 | for THRESHOLD in range(10): 64 | THRESHOLD = THRESHOLD / 10 65 | print('THRESHOLD:', THRESHOLD) 66 | all_num = defaultdict(int) 67 | all_cor_num = defaultdict(int) 68 | for senti, (_, att_feats, labels) in eval_datas.items(): 69 | att_feats = att_feats.to(opt.device) 70 | labels = labels.to(opt.device) 71 | with torch.no_grad(): 72 | preds, _, _, scores = model.sample(att_feats) 73 | replace_idx = (scores < THRESHOLD).nonzero(as_tuple=False).view(-1) 74 | preds.index_copy_(0, replace_idx, preds.new_zeros(len(replace_idx)).fill_(neu_idx)) 75 | for idx in [0, 1, 2]: 76 | all_num[idx] += int(sum(preds == idx)) 77 | label = int(labels[0]) 78 | all_cor_num[label] += int(sum(preds == label)) 79 | for senti_id in all_num: 80 | senti = sentiment_categories[senti_id] 81 | print('%s precision: %s' % (senti, all_cor_num[senti_id] / (all_num[senti_id] + 1e-9))) 82 | print('all precision:', sum(all_cor_num.values()) / sum(all_num.values())) 83 | for senti_id in all_num: 84 | senti = sentiment_categories[senti_id] 85 | print('%s all num: %s, cor num: %s' % (senti, all_num[senti_id], all_cor_num[senti_id])) 86 | -------------------------------------------------------------------------------- /method_figs/Architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ezeli/InSentiCap_model/0e11ba1494633e83770d52805f513eab2339ddfe/method_figs/Architecture.png -------------------------------------------------------------------------------- /method_figs/Fine-tuning.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ezeli/InSentiCap_model/0e11ba1494633e83770d52805f513eab2339ddfe/method_figs/Fine-tuning.png -------------------------------------------------------------------------------- /method_figs/Image Captioning with Inherent Sentiment.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ezeli/InSentiCap_model/0e11ba1494633e83770d52805f513eab2339ddfe/method_figs/Image Captioning with Inherent Sentiment.pdf -------------------------------------------------------------------------------- /method_figs/Pre-training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ezeli/InSentiCap_model/0e11ba1494633e83770d52805f513eab2339ddfe/method_figs/Pre-training.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ezeli/InSentiCap_model/0e11ba1494633e83770d52805f513eab2339ddfe/models/__init__.py -------------------------------------------------------------------------------- /models/captioner.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.utils.rnn import pack_padded_sequence 5 | import torch.nn.functional as F 6 | 7 | 8 | BeamCandidate = namedtuple('BeamCandidate', 9 | ['state', 'log_prob_sum', 'log_prob_seq', 'last_word_id', 'word_id_seq']) 10 | 11 | 12 | class ContentAttention(nn.Module): 13 | def __init__(self, settings): 14 | super(ContentAttention, self).__init__() 15 | self.h2att = nn.Linear(settings['rnn_hid_dim'], settings['att_hid_dim']) 16 | self.att_alpha = nn.Linear(settings['att_hid_dim'], 1) 17 | 18 | self.weights = [] 19 | 20 | def _reset_weights(self): 21 | self.weights = [] 22 | 23 | def forward(self, h, att_feats, p_att_feats): 24 | # The p_att_feats/p_cpt_feats here are already projected 25 | h_att = self.h2att(h) # [bs, att_hid] 26 | h_att = h_att.unsqueeze(1).expand_as(p_att_feats) # [bs, num_atts, att_hid] 27 | p_att_feats = p_att_feats + h_att # [bs, num_atts, att_hid] 28 | p_att_feats = p_att_feats.tanh() 29 | p_att_feats = self.att_alpha(p_att_feats).squeeze(-1) # [bs, num_atts] 30 | # p_att_feats = p_att_feats.view(-1, att_size) # [bs, num_atts] 31 | weight = p_att_feats.softmax(-1) 32 | self.weights.append(weight) 33 | 34 | att_res = weight.unsqueeze(1).bmm(att_feats).squeeze(1) # [bs, feat_emb] 35 | return att_res 36 | 37 | 38 | class SentiAttention(nn.Module): 39 | def __init__(self, settings): 40 | super(SentiAttention, self).__init__() 41 | self.h2word = nn.Linear(settings['rnn_hid_dim'], settings['att_hid_dim']) 42 | self.label2word = nn.Linear(settings['word_emb_dim'], settings['att_hid_dim']) 43 | self.word_alpha = nn.Linear(settings['att_hid_dim'], 1) 44 | 45 | self.weights = [] 46 | 47 | def _reset_weights(self): 48 | self.weights = [] 49 | 50 | def forward(self, h, senti_word_feats, p_senti_word_feats, senti_labels): 51 | h_word = self.h2word(h) # [bs, att_hid] 52 | senti_labels_word = self.label2word(senti_labels) # [bs, att_hid] 53 | h_word = h_word.unsqueeze(1).expand_as(p_senti_word_feats) # [bs, num_stmts, att_hid] 54 | senti_labels_word = senti_labels_word.unsqueeze(1).expand_as(p_senti_word_feats) # [bs, num_stmts, 
att_hid] 55 | p_senti_word_feats = p_senti_word_feats + h_word + senti_labels_word # [bs, num_stmts, att_hid] 56 | p_senti_word_feats = p_senti_word_feats.tanh() 57 | p_senti_word_feats = self.word_alpha(p_senti_word_feats).squeeze(-1) # [bs, num_stmts] 58 | weight = p_senti_word_feats.softmax(-1) 59 | self.weights.append(weight) 60 | 61 | word_res = weight.unsqueeze(1).bmm(senti_word_feats).squeeze(1) # [bs, word_emb] 62 | return word_res 63 | 64 | 65 | class Attention(nn.Module): 66 | def __init__(self, settings): 67 | super(Attention, self).__init__() 68 | self.cont_att = ContentAttention(settings) 69 | self.senti_att = SentiAttention(settings) 70 | 71 | self.h2att = nn.Linear(settings['rnn_hid_dim'], settings['att_hid_dim']) 72 | self.cont2att = nn.Linear(settings['feat_emb_dim'], settings['att_hid_dim']) 73 | self.senti2att = nn.Linear(settings['feat_emb_dim'], settings['att_hid_dim']) 74 | self.att_alpha = nn.Linear(settings['att_hid_dim'], 1) 75 | 76 | self.weights = [] 77 | 78 | def _reset_weights(self): 79 | self.weights = [] 80 | self.cont_att._reset_weights() 81 | self.senti_att._reset_weights() 82 | 83 | def _get_weights(self): 84 | cont_weights = self.cont_att.weights 85 | if cont_weights: 86 | cont_weights = torch.cat(cont_weights, dim=1) 87 | senti_weights = self.senti_att.weights 88 | if senti_weights: 89 | senti_weights = torch.cat(senti_weights, dim=1) 90 | cont_senti_weights = self.weights 91 | if cont_senti_weights: 92 | cont_senti_weights = torch.cat(cont_senti_weights, dim=1) 93 | self._reset_weights() 94 | return cont_weights, senti_weights, cont_senti_weights 95 | 96 | def forward(self, h, att_feats, p_att_feats, senti_word_feats, 97 | p_senti_word_feats, senti_labels): 98 | if att_feats is None: # for seq2seq 99 | senti_res = self.senti_att(h, senti_word_feats, p_senti_word_feats, senti_labels) # [bs, feat_emb] 100 | return senti_res 101 | cont_res = self.cont_att(h, att_feats, p_att_feats) # [bs, feat_emb] 102 | if senti_word_feats is None: # for xe 103 | return cont_res 104 | 105 | # for rl 106 | senti_res = self.senti_att(h, senti_word_feats, p_senti_word_feats, senti_labels) # [bs, feat_emb] 107 | 108 | h_att = self.h2att(h) # [bs, att_hid] 109 | cont_att = self.cont2att(cont_res) # [bs, att_hid] 110 | senti_att = self.senti2att(senti_res) # [bs, att_hid] 111 | weight = cont_att + senti_att + h_att # [bs, att_hid] 112 | weight = weight.tanh() 113 | weight = self.att_alpha(weight).sigmoid() # [bs, 1] 114 | # weight = (weight > 0.5).type(weight.dtype) 115 | self.weights.append(weight) 116 | 117 | res = weight * cont_res + (1 - weight) * senti_res 118 | return res 119 | 120 | 121 | class Captioner(nn.Module): 122 | def __init__(self, idx2word, sentiment_categories, settings): 123 | super(Captioner, self).__init__() 124 | self.idx2word = idx2word 125 | self.pad_id = idx2word.index('') 126 | self.unk_id = idx2word.index('') 127 | self.sos_id = idx2word.index('') if '' in idx2word else self.pad_id 128 | self.eos_id = idx2word.index('') if '' in idx2word else self.pad_id 129 | self.neu_idx = sentiment_categories.index('neutral') 130 | 131 | self.vocab_size = len(idx2word) 132 | self.drop = nn.Dropout(settings['dropout_p']) 133 | self.word_embed = nn.Sequential(nn.Embedding(self.vocab_size, settings['word_emb_dim'], 134 | padding_idx=self.pad_id), 135 | nn.ReLU()) 136 | self.senti_label_embed = nn.Sequential(nn.Embedding(len(sentiment_categories), settings['word_emb_dim']), 137 | nn.ReLU()) 138 | self.fc_embed = nn.Sequential(nn.Linear(settings['fc_feat_dim'], 
settings['feat_emb_dim']), 139 | nn.ReLU()) 140 | self.cpt2fc = nn.Sequential(nn.Linear(settings['word_emb_dim'], settings['feat_emb_dim']), 141 | nn.ReLU()) 142 | self.att_embed = nn.Sequential(nn.Linear(settings['att_feat_dim'], settings['feat_emb_dim']), 143 | nn.ReLU()) 144 | # self.senti_embed = nn.Sequential(nn.Linear(settings['sentiment_feat_dim'], settings['feat_emb_dim']), 145 | # nn.LayerNorm(settings['feat_emb_dim'])) 146 | 147 | self.att_lstm = nn.LSTMCell(settings['rnn_hid_dim'] + settings['feat_emb_dim'] + settings['word_emb_dim'], 148 | settings['rnn_hid_dim']) # h^2_t-1, fc, we 149 | self.att2att = nn.Sequential(nn.Linear(settings['feat_emb_dim'], settings['att_hid_dim']), 150 | nn.ReLU()) 151 | self.senti2att = nn.Sequential(nn.Linear(settings['word_emb_dim'], settings['att_hid_dim']), 152 | nn.ReLU()) 153 | self.attention = Attention(settings) 154 | 155 | # TODO now: word_emb_dim == feat_emb_dim 156 | # self.senti2feat = nn.Sequential(nn.Linear(settings['word_emb_dim'], settings['feat_emb_dim']), 157 | # nn.ReLU()) 158 | self.lang_lstm = nn.LSTMCell(settings['rnn_hid_dim'] + settings['feat_emb_dim'], 159 | settings['rnn_hid_dim']) # \hat v, h^1_t 160 | 161 | self.classifier = nn.Linear(settings['rnn_hid_dim'], self.vocab_size) 162 | 163 | def init_hidden(self, bsz): 164 | weight = next(self.parameters()) 165 | return (weight.new_zeros([2, bsz, self.att_lstm.hidden_size]), # h_att, h_lang 166 | weight.new_zeros([2, bsz, self.att_lstm.hidden_size])) # c_att, c_lang 167 | 168 | def forward_step(self, it, state, fc_feats, att_feats=None, p_att_feats=None, 169 | senti_word_feats=None, p_senti_word_feats=None, senti_labels=None): 170 | xt = self.word_embed(it) 171 | if senti_labels is not None: 172 | xt = xt + senti_labels 173 | prev_h = state[0][1] # [bs, rnn_hid] 174 | att_lstm_input = torch.cat([prev_h, fc_feats, xt], 1) # [bs, rnn_hid+feat_emb+word_emb] 175 | h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0])) # [bs, rnn_hid] 176 | 177 | att = self.attention(h_att, att_feats, p_att_feats, senti_word_feats, 178 | p_senti_word_feats, senti_labels) # [bs, feat_emb+word_emb] 179 | 180 | lang_lstm_input = torch.cat([att, h_att], 1) # [bs, feat_emb+rnn_hid] 181 | h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1])) # bs*rnn_hid 182 | output = self.drop(h_lang) # [bs, rnn_hid] 183 | logprobs = F.log_softmax(self.classifier(output), dim=1) # [bs, vocab] 184 | 185 | state = (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang])) 186 | return logprobs, state 187 | 188 | def forward(self, *args, **kwargs): 189 | mode = kwargs.get('mode', 'xe') 190 | if 'mode' in kwargs: 191 | del kwargs['mode'] 192 | return getattr(self, 'forward_' + mode)(*args, **kwargs) 193 | 194 | def forward_xe(self, fc_feats, att_feats, cpt_words, captions, senti_labels, ss_prob=0.0): 195 | batch_size = fc_feats.size(0) 196 | outputs = [] 197 | 198 | fc_feats = self.fc_embed(fc_feats) # [bs, feat_emb] 199 | self.fc_feats = fc_feats 200 | fc_feats = self.drop(fc_feats) 201 | cpt_feats = self.word_embed(cpt_words) # [bs, num_cpts, word_emb] 202 | cpt_feats = cpt_feats.mean(dim=1) # [bs, word_emb] 203 | cpt_feats = self.cpt2fc(cpt_feats) # [bs, feat_emb] 204 | self.cpt_feats = cpt_feats 205 | # TODO 206 | # cpt_feats = self.drop(cpt_feats) 207 | 208 | att_feats = att_feats.view(batch_size, -1, att_feats.shape[-1]) # [bs, num_atts, att_feat] 209 | att_feats = self.att_embed(att_feats) # [bs, num_atts, feat_emb] 210 | att_feats = self.drop(att_feats) 211 | p_att_feats 
= self.att2att(att_feats) # [bs, num_atts, att_hid] 212 | 213 | senti_labels = self.senti_label_embed(senti_labels) # [bs, word_emb] 214 | senti_labels = self.drop(senti_labels) 215 | 216 | state = self.init_hidden(batch_size) 217 | 218 | for i in range(captions.size(1) - 1): 219 | if self.training and i >= 1 and ss_prob > 0.0: # otherwise no need to sample 220 | sample_prob = fc_feats.new(batch_size).uniform_(0, 1) 221 | sample_mask = sample_prob < ss_prob 222 | if sample_mask.sum() == 0: 223 | it = captions[:, i].clone() # bs 224 | else: 225 | sample_ind = sample_mask.nonzero(as_tuple=False).view(-1) 226 | it = captions[:, i].clone() # bs 227 | prob_prev = outputs[i - 1].detach().exp() # bs*vocab_size, fetch prev distribution 228 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 229 | else: 230 | it = captions[:, i].clone() # bs 231 | 232 | output, state = self.forward_step( 233 | it, state, fc_feats, att_feats, p_att_feats, senti_labels=senti_labels) 234 | outputs.append(output) 235 | 236 | self.cont_weights, self.senti_weights, self.cont_senti_weights = \ 237 | self.attention._get_weights() 238 | 239 | outputs = torch.stack(outputs, dim=1) # [bs, max_len, vocab_size] 240 | return outputs 241 | 242 | def forward_seq2seq(self, senti_captions, cpt_words, senti_words, senti_labels, 243 | ss_prob=0.0): 244 | batch_size = senti_captions.size(0) 245 | outputs = [] 246 | 247 | cpt_feats = self.word_embed(cpt_words) # [bs, num_cpts, word_emb] 248 | cpt_feats = cpt_feats.mean(dim=1) # [bs, word_emb] 249 | cpt_feats = self.cpt2fc(cpt_feats) # [bs, feat_emb] 250 | cpt_feats = self.drop(cpt_feats) 251 | fc_feats = cpt_feats 252 | 253 | senti_words = torch.cat( 254 | [senti_words.new_zeros(batch_size, 1).fill_(self.pad_id), senti_words], 255 | dim=1) # [bs, num_stmts] 256 | senti_word_feats = self.word_embed(senti_words) # [bs, num_stmts, word_emb] 257 | senti_word_feats = self.drop(senti_word_feats) 258 | p_senti_word_feats = self.senti2att(senti_word_feats) # [bs, num_stmts, att_hid] 259 | 260 | senti_labels = self.senti_label_embed(senti_labels) # [bs, word_emb] 261 | senti_labels = self.drop(senti_labels) 262 | 263 | state = self.init_hidden(batch_size) 264 | 265 | for i in range(senti_captions.size(1) - 1): 266 | if self.training and i >= 1 and ss_prob > 0.0: # otherwise no need to sample 267 | sample_prob = fc_feats.new(batch_size).uniform_(0, 1) 268 | sample_mask = sample_prob < ss_prob 269 | if sample_mask.sum() == 0: 270 | it = senti_captions[:, i].clone() # bs 271 | else: 272 | sample_ind = sample_mask.nonzero(as_tuple=False).view(-1) 273 | it = senti_captions[:, i].clone() # bs 274 | prob_prev = outputs[i - 1].detach().exp() # bs*vocab_size, fetch prev distribution 275 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 276 | else: 277 | it = senti_captions[:, i].clone() # bs 278 | 279 | output, state = self.forward_step( 280 | it, state, fc_feats, senti_word_feats=senti_word_feats, 281 | p_senti_word_feats=p_senti_word_feats, senti_labels=senti_labels) 282 | outputs.append(output) 283 | 284 | self.cont_weights, self.senti_weights, self.cont_senti_weights = \ 285 | self.attention._get_weights() 286 | 287 | outputs = torch.stack(outputs, dim=1) # [bs, max_len, vocab_size] 288 | return outputs 289 | 290 | def forward_rl(self, fc_feats, att_feats, cpt_words, senti_words, senti_labels, 291 | max_seq_len, sample_max): 292 | batch_size = fc_feats.shape[0] 293 | 294 | fc_feats = 
self.fc_embed(fc_feats) # [bs, feat_emb] 295 | self.fc_feats = fc_feats 296 | fc_feats = self.drop(fc_feats) 297 | cpt_feats = self.word_embed(cpt_words) # [bs, num_cpts, word_emb] 298 | cpt_feats = cpt_feats.mean(dim=1) # [bs, word_emb] 299 | cpt_feats = self.cpt2fc(cpt_feats) # [bs, feat_emb] 300 | self.cpt_feats = cpt_feats 301 | 302 | att_feats = att_feats.view(batch_size, -1, att_feats.shape[-1]) # [bs, num_atts, att_feat] 303 | att_feats = self.att_embed(att_feats) # [bs, num_atts, feat_emb] 304 | att_feats = self.drop(att_feats) 305 | p_att_feats = self.att2att(att_feats) # [bs, num_atts, att_hid] 306 | 307 | senti_words = torch.cat( 308 | [senti_words.new_zeros(batch_size, 1).fill_(self.pad_id), senti_words], 309 | dim=1) # [bs, num_stmts] 310 | senti_word_feats = self.word_embed(senti_words) # [bs, num_stmts, word_emb] 311 | senti_word_feats = self.drop(senti_word_feats) 312 | p_senti_word_feats = self.senti2att(senti_word_feats) # [bs, num_stmts, att_hid] 313 | 314 | senti_labels = self.senti_label_embed(senti_labels) # [bs, word_emb] 315 | senti_labels = self.drop(senti_labels) 316 | 317 | state = self.init_hidden(batch_size) 318 | seq = fc_feats.new_zeros((batch_size, max_seq_len), dtype=torch.long) 319 | seq_logprobs = fc_feats.new_zeros((batch_size, max_seq_len)) 320 | seq_masks = fc_feats.new_zeros((batch_size, max_seq_len)) 321 | it = fc_feats.new_zeros(batch_size, dtype=torch.long).fill_(self.sos_id) # first input 322 | unfinished = it == self.sos_id 323 | for t in range(max_seq_len): 324 | logprobs, state = self.forward_step( 325 | it, state, fc_feats, att_feats, p_att_feats, 326 | senti_word_feats, p_senti_word_feats, senti_labels) 327 | 328 | if sample_max: 329 | sample_logprobs, it = torch.max(logprobs, 1) 330 | else: 331 | prob_prev = torch.exp(logprobs) 332 | it = torch.multinomial(prob_prev, 1) 333 | sample_logprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions 334 | it = it.view(-1).long() 335 | sample_logprobs = sample_logprobs.view(-1) 336 | 337 | seq_masks[:, t] = unfinished 338 | it = it * unfinished.type_as(it) # bs 339 | seq[:, t] = it 340 | seq_logprobs[:, t] = sample_logprobs 341 | 342 | unfinished = unfinished * (it != self.eos_id) 343 | if unfinished.sum() == 0: 344 | break 345 | 346 | self.cont_weights, self.senti_weights, self.cont_senti_weights = \ 347 | self.attention._get_weights() 348 | 349 | return seq, seq_logprobs, seq_masks 350 | 351 | def sample(self, fc_feat, att_feat, senti_words=None, senti_label=None, 352 | beam_size=3, decoding_constraint=1, max_seq_len=16): 353 | self.eval() 354 | fc_feats = fc_feat.view(1, -1) # [1, fc_feat] 355 | att_feats = att_feat.view(1, -1, att_feat.shape[-1]) # [1, num_atts, att_feat] 356 | 357 | fc_feats = self.fc_embed(fc_feats) # [bs, feat_emb] 358 | fc_feats = self.drop(fc_feats) 359 | 360 | att_feats = self.att_embed(att_feats) # [bs, num_atts, feat_emb] 361 | att_feats = self.drop(att_feats) 362 | p_att_feats = self.att2att(att_feats) # [bs, num_atts, att_hid] 363 | 364 | if senti_words is not None: 365 | senti_words = senti_words.view(1, -1) 366 | senti_words = torch.cat( 367 | [senti_words.new_zeros(1, 1).fill_(self.pad_id), senti_words], 368 | dim=1) # [bs, num_stmts] 369 | senti_word_feats = self.word_embed(senti_words) # [bs, num_stmts, word_emb] 370 | senti_word_feats = self.drop(senti_word_feats) 371 | p_senti_word_feats = self.senti2att(senti_word_feats) # [bs, num_stmts, att_hid] 372 | 373 | senti_labels = self.senti_label_embed(senti_label) # [bs, word_emb] 374 | senti_labels 
= self.drop(senti_labels) 375 | else: 376 | senti_word_feats = p_senti_word_feats = senti_labels = None 377 | 378 | state = self.init_hidden(1) 379 | candidates = [BeamCandidate(state, 0., [], self.sos_id, [])] 380 | for t in range(max_seq_len): 381 | tmp_candidates = [] 382 | end_flag = True 383 | for candidate in candidates: 384 | state, log_prob_sum, log_prob_seq, last_word_id, word_id_seq = candidate 385 | if t > 0 and last_word_id == self.eos_id: 386 | tmp_candidates.append(candidate) 387 | else: 388 | end_flag = False 389 | it = fc_feats.type(torch.long).new_tensor([last_word_id]) 390 | logprobs, state = self.forward_step( 391 | it, state, fc_feats, att_feats, p_att_feats, 392 | senti_word_feats, p_senti_word_feats, senti_labels) # [1, vocab_size] 393 | logprobs = logprobs.squeeze(0) # vocab_size 394 | if self.pad_id != self.eos_id: 395 | logprobs[self.pad_id] = float('-inf') # do not generate and 396 | logprobs[self.sos_id] = float('-inf') 397 | logprobs[self.unk_id] = float('-inf') 398 | if decoding_constraint: # do not generate last step word 399 | logprobs[last_word_id] = float('-inf') 400 | 401 | output_sorted, index_sorted = torch.sort(logprobs, descending=True) 402 | for k in range(beam_size): 403 | log_prob, word_id = output_sorted[k], index_sorted[k] # tensor, tensor 404 | log_prob = float(log_prob) 405 | word_id = int(word_id) 406 | tmp_candidates.append(BeamCandidate(state, log_prob_sum + log_prob, 407 | log_prob_seq + [log_prob], 408 | word_id, word_id_seq + [word_id])) 409 | candidates = sorted(tmp_candidates, key=lambda x: x.log_prob_sum, reverse=True)[:beam_size] 410 | if end_flag: 411 | break 412 | 413 | self.cont_weights, self.senti_weights, self.cont_senti_weights = \ 414 | self.attention._get_weights() 415 | 416 | # captions, scores 417 | captions = [' '.join([self.idx2word[idx] for idx in candidate.word_id_seq if idx != self.eos_id]) 418 | for candidate in candidates] 419 | scores = [candidate.log_prob_sum for candidate in candidates] 420 | return captions, scores 421 | 422 | def get_optim_criterion(self, lr, weight_decay=0): 423 | return torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay),\ 424 | XECriterion(), nn.MSELoss() # xe, domain align 425 | 426 | 427 | class XECriterion(nn.Module): 428 | def __init__(self): 429 | super(XECriterion, self).__init__() 430 | 431 | def forward(self, pred, target, lengths): 432 | max_len = max(lengths) 433 | mask = pred.new_zeros(len(lengths), max_len) 434 | for i, l in enumerate(lengths): 435 | mask[i, :l] = 1 436 | 437 | loss = - pred.gather(2, target.unsqueeze(2)).squeeze(2) * mask 438 | loss = torch.sum(loss) / torch.sum(mask) 439 | 440 | return loss 441 | -------------------------------------------------------------------------------- /models/concept_detector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ConceptDetector(nn.Module): 6 | def __init__(self, idx2concept, settings): 7 | super(ConceptDetector, self).__init__() 8 | self.idx2concept = idx2concept 9 | 10 | self.output = nn.Sequential( 11 | nn.Linear(settings['fc_feat_dim'], settings['concept_mid_him']), 12 | nn.ReLU(), 13 | nn.Linear(settings['concept_mid_him'], settings['concept_mid_him']), 14 | nn.ReLU(), 15 | nn.Dropout(settings['dropout_p']), 16 | nn.Linear(settings['concept_mid_him'], len(idx2concept)), 17 | nn.Sigmoid(), 18 | ) 19 | 20 | def forward(self, features): 21 | # [bz, fc_feat_dim] 22 | return self.output(features) # [bz, num_cpts] 23 
| 24 | def sample(self, features, num): 25 | # [bz, fc_feat_dim] 26 | self.eval() 27 | out = self.output(features) # [bz, num_cpts] 28 | scores, idx = out.sort(dim=-1, descending=True) 29 | scores = scores[:, :num] 30 | idx = idx[:, :num] 31 | concepts = [] 32 | for batch in idx: 33 | tmp = [] 34 | for i in batch: 35 | tmp.append(self.idx2concept[i]) 36 | concepts.append(tmp) 37 | return out, concepts, scores 38 | 39 | def get_optim_criterion(self, lr, weight_decay=0): 40 | return torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay),\ 41 | MultiLabelClsLoss() 42 | 43 | 44 | class MultiLabelClsLoss(nn.Module): 45 | def __init__(self): 46 | super(MultiLabelClsLoss, self).__init__() 47 | 48 | def forward(self, result, target): 49 | # result/target: [bz, num_cpts] 50 | target = target.type(result.type()) 51 | 52 | output = target * result.log() 53 | output = - output.mean(dim=-1).mean(dim=-1) # scalar 54 | out = (1 - target) * (1 - result).log() 55 | out = - out.mean(dim=-1).mean(dim=-1) # scalar 56 | 57 | output += out 58 | return output 59 | -------------------------------------------------------------------------------- /models/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import tqdm 4 | from collections import defaultdict 5 | 6 | from .captioner import Captioner 7 | from .sentiment_detector import SentimentDetector 8 | from .sent_senti_cls import SentenceSentimentClassifier 9 | 10 | from self_critical.utils import get_ciderd_scorer, get_self_critical_reward, \ 11 | get_lm_reward, RewardCriterion, get_cls_reward, get_senti_words_reward 12 | 13 | 14 | def clip_gradient(optimizer, grad_clip=0.1): 15 | for group in optimizer.param_groups: 16 | for param in group['params']: 17 | if param.grad is not None: 18 | param.grad.data.clamp_(-grad_clip, grad_clip) 19 | 20 | 21 | class Detector(nn.Module): 22 | def __init__(self, idx2word, max_seq_len, sentiment_categories, lrs, settings): 23 | super(Detector, self).__init__() 24 | self.idx2word = idx2word 25 | self.pad_id = idx2word.index('') 26 | self.max_seq_len = max_seq_len 27 | 28 | self.captioner = Captioner(idx2word, sentiment_categories, settings) 29 | self.senti_detector = SentimentDetector(sentiment_categories, settings) 30 | self.sent_senti_cls = SentenceSentimentClassifier(idx2word, sentiment_categories, settings) 31 | self.senti_detector.eval() 32 | self.sent_senti_cls.eval() 33 | 34 | self.cap_optim, self.cap_xe_crit, self.cap_da_crit = self.captioner.get_optim_criterion(lrs['cap_lr']) 35 | self.cap_rl_crit = RewardCriterion() 36 | # self.senti_optim, self.senti_crit = self.senti_detector.get_optim_criterion(lrs['senti_lr']) 37 | # self.sent_optim, self.sent_crit = self.sent_senti_cls.get_optim_and_crit(lrs['sent_lr']) 38 | 39 | self.cls_flag = 0.4 40 | self.seq_flag = 1.0 41 | self.senti_threshold = 0.7 42 | 43 | def set_ciderd_scorer(self, captions): 44 | self.ciderd_scorer = get_ciderd_scorer(captions, self.captioner.sos_id, self.captioner.eos_id) 45 | 46 | def set_sentiment_words(self, sentiment_words): 47 | self.sentiment_words = sentiment_words 48 | 49 | def set_lms(self, lms): 50 | self.lms = lms 51 | 52 | def forward(self, data, data_type, training): 53 | self.captioner.train(training) 54 | all_losses = defaultdict(float) 55 | device = next(self.parameters()).device 56 | 57 | # if data_type == 'senti': 58 | # cls_flag = 1 59 | # else: 60 | # cls_flag = self.cls_flag 61 | if training: 62 | seq2seq_data = iter(data[1]) 63 | 
# accur_ws = defaultdict(set) 64 | caption_data = iter(data[0]) 65 | for _ in tqdm.tqdm(range(min(500, len(data[0])))): 66 | data_item = next(caption_data) 67 | if data_type == 'fact': 68 | fns, fc_feats, att_feats, (caps_tensor, lengths), cpts_tensor, sentis_tensor, ground_truth = data_item 69 | caps_tensor = caps_tensor.to(device) 70 | elif data_type == 'senti': 71 | fns, fc_feats, att_feats, cpts_tensor, sentis_tensor, senti_labels = data_item 72 | senti_labels = senti_labels.to(device) 73 | else: 74 | raise Exception('data_type(%s) is wrong!' % data_type) 75 | 76 | fc_feats = fc_feats.to(device) 77 | att_feats = att_feats.to(device) 78 | cpts_tensor = cpts_tensor.to(device) 79 | sentis_tensor = sentis_tensor.to(device) 80 | del data_item 81 | 82 | if data_type == 'fact' or not training: 83 | senti_labels, _, _, _ = self.senti_detector.sample(att_feats, self.senti_threshold) 84 | senti_labels = senti_labels.detach() 85 | 86 | sample_captions, sample_logprobs, seq_masks = self.captioner( 87 | fc_feats, att_feats, cpts_tensor, sentis_tensor, senti_labels, 88 | self.max_seq_len, sample_max=0, mode='rl') 89 | da_loss = self.cap_da_crit(self.captioner.cpt_feats, self.captioner.fc_feats.detach()) 90 | all_losses['da_loss'] += float(da_loss) 91 | 92 | self.captioner.eval() 93 | with torch.no_grad(): 94 | greedy_captions, _, greedy_masks = self.captioner( 95 | fc_feats, att_feats, cpts_tensor, sentis_tensor, senti_labels, 96 | self.max_seq_len, sample_max=1, mode='rl') 97 | self.captioner.train(training) 98 | 99 | if data_type == 'fact': 100 | fact_reward = get_self_critical_reward( 101 | sample_captions, greedy_captions, fns, ground_truth, 102 | self.captioner.sos_id, self.captioner.eos_id, self.ciderd_scorer) 103 | fact_reward = torch.from_numpy(fact_reward).float().to(device) 104 | all_losses['fact_reward'] += float(fact_reward[:, 0].mean()) 105 | else: 106 | fact_reward = 0 107 | 108 | cls_reward = get_cls_reward( 109 | sample_captions, seq_masks, greedy_captions, greedy_masks, 110 | senti_labels, self.sent_senti_cls) # [bs, num_sentis] 111 | cls_reward = torch.from_numpy(cls_reward).float().to(device) 112 | all_losses['cls_reward'] += float(cls_reward.mean(-1).mean(-1)) 113 | 114 | # lm_reward = get_lm_reward( 115 | # sample_captions, greedy_captions, senti_labels, 116 | # self.captioner.sos_id, self.captioner.eos_id, self.lms) 117 | # lm_reward = torch.from_numpy(lm_reward).float().to(device) 118 | # all_losses['lm_reward'] += float(lm_reward[:, 0].sum()) 119 | 120 | # senti_words_reward, accur_w = get_senti_words_reward(sample_captions, senti_labels, self.sentiment_words) 121 | # for senti, words in accur_w.items(): 122 | # accur_ws[senti].update(words) 123 | # senti_words_reward = torch.from_numpy(senti_words_reward).float().to(device) 124 | # all_losses['senti_words_reward'] += float(senti_words_reward.sum()) 125 | 126 | rewards = fact_reward + self.cls_flag * cls_reward # + 0.05 * senti_words_reward 127 | all_losses['all_rewards'] += float(rewards.mean(-1).mean(-1)) 128 | cap_loss = self.cap_rl_crit(sample_logprobs, seq_masks, rewards) 129 | all_losses['cap_loss'] += float(cap_loss) 130 | 131 | xe_loss = 0.0 132 | if data_type == 'fact': 133 | with torch.no_grad(): 134 | xe_senti_labels, _ = self.sent_senti_cls(caps_tensor[:, 1:], lengths) 135 | xe_senti_labels = xe_senti_labels.softmax(dim=-1) 136 | xe_senti_labels = xe_senti_labels.argmax(dim=-1).detach() 137 | 138 | pred = self.captioner(fc_feats, att_feats, cpts_tensor, caps_tensor, 139 | xe_senti_labels, ss_prob=0.5, mode='xe') 
140 | xe_loss = self.cap_xe_crit(pred, caps_tensor[:, 1:], lengths) 141 | all_losses['xe_loss'] += float(xe_loss) 142 | 143 | seq2seq_loss = 0.0 144 | if training: 145 | try: 146 | (caps_tensor, lengths), cpts_tensor, sentis_tensor, senti_labels = next(seq2seq_data) 147 | except: 148 | seq2seq_data = iter(data[1]) 149 | (caps_tensor, lengths), cpts_tensor, sentis_tensor, senti_labels = next(seq2seq_data) 150 | caps_tensor = caps_tensor.to(device) 151 | cpts_tensor = cpts_tensor.to(device) 152 | sentis_tensor = sentis_tensor.to(device) 153 | senti_labels = senti_labels.to(device) 154 | 155 | pred = self.captioner(caps_tensor, cpts_tensor, sentis_tensor, senti_labels, ss_prob=0.25, 156 | mode='seq2seq') 157 | seq2seq_loss = self.cap_xe_crit(pred, caps_tensor[:, 1:], lengths) 158 | seq2seq_loss = self.seq_flag * seq2seq_loss 159 | all_losses['seq2seq_loss'] += float(seq2seq_loss) 160 | 161 | cap_loss = cap_loss + xe_loss + da_loss + seq2seq_loss 162 | 163 | if training: 164 | self.cap_optim.zero_grad() 165 | cap_loss.backward() 166 | clip_gradient(self.cap_optim) 167 | self.cap_optim.step() 168 | 169 | # if training and data_type == 'fact': 170 | # self.cls_flag = self.cls_flag * 2 171 | # if self.cls_flag > 1.0: 172 | # self.cls_flag = 1.0 173 | 174 | # for senti, words in accur_ws.items(): 175 | # for w in words: 176 | # self.sentiment_words[senti][w] *= 0.9 177 | 178 | for k, v in all_losses.items(): 179 | all_losses[k] = v / len(data) 180 | return all_losses 181 | 182 | def sample(self, fc_feats, att_feats, sentis_tensor, 183 | beam_size=3, decoding_constraint=1): 184 | self.eval() 185 | att_feats = att_feats.unsqueeze(0) 186 | senti_label, _, det_img_sentis, _ = self.senti_detector.sample(att_feats, self.senti_threshold) 187 | 188 | captions, _ = self.captioner.sample( 189 | fc_feats, att_feats, sentis_tensor, senti_label, 190 | beam_size, decoding_constraint, self.max_seq_len) 191 | 192 | return captions, det_img_sentis 193 | -------------------------------------------------------------------------------- /models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | from torchvision.models.resnet import Bottleneck 6 | import numpy as np 7 | 8 | 9 | class ResNet(torchvision.models.resnet.ResNet): 10 | def __init__(self, block, layers, num_classes=1000): 11 | super(ResNet, self).__init__(block, layers, num_classes) 12 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change 13 | for i in range(2, 5): 14 | getattr(self, 'layer%d'%i)[0].conv1.stride = (2,2) 15 | getattr(self, 'layer%d'%i)[0].conv2.stride = (1,1) 16 | 17 | 18 | class Encoder(nn.Module): 19 | def __init__(self, resnet101_file): 20 | super(Encoder, self).__init__() 21 | resnet = ResNet(Bottleneck, [3, 4, 23, 3]) 22 | ckpt = torch.load(resnet101_file, map_location=lambda s, l: s) 23 | resnet.load_state_dict(ckpt) 24 | self.resnet = resnet 25 | self.transforms = torchvision.transforms.Compose([ 26 | torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 27 | ]) 28 | 29 | def preprocess(self, image): 30 | if len(image.shape) == 2: 31 | image = image[:, :, np.newaxis] 32 | image = np.concatenate((image, image, image), axis=2) 33 | 34 | image = image.astype('float32') / 255.0 35 | image = torch.from_numpy(image.transpose(2, 0, 1)) 36 | image = self.transforms(image) 37 | return image 38 | 39 | def forward(self, img, att_size=14): 40 
| x = img.unsqueeze(0) 41 | 42 | x = self.resnet.conv1(x) 43 | x = self.resnet.bn1(x) 44 | x = self.resnet.relu(x) 45 | x = self.resnet.maxpool(x) 46 | 47 | x = self.resnet.layer1(x) 48 | x = self.resnet.layer2(x) 49 | x = self.resnet.layer3(x) 50 | x = self.resnet.layer4(x) 51 | 52 | fc = x.mean(3).mean(2).squeeze() 53 | att = F.adaptive_avg_pool2d(x, [att_size, att_size]).squeeze().permute(1, 2, 0) 54 | 55 | return fc, att 56 | -------------------------------------------------------------------------------- /models/sent_senti_cls.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 4 | 5 | 6 | class SentenceSentimentClassifier(nn.Module): 7 | def __init__(self, idx2word, sentiment_categories, settings): 8 | super(SentenceSentimentClassifier, self).__init__() 9 | self.sentiment_categories = sentiment_categories 10 | self.pad_id = idx2word.index('') 11 | self.vocab_size = len(idx2word) 12 | self.word_embed = nn.Sequential(nn.Embedding(self.vocab_size, settings['word_emb_dim'], 13 | padding_idx=self.pad_id), 14 | nn.ReLU(), 15 | nn.Dropout(settings['dropout_p'])) 16 | 17 | rnn_bidirectional = False 18 | self.rnn = nn.LSTM(settings['word_emb_dim'], settings['rnn_hid_dim'], bidirectional=rnn_bidirectional) 19 | self.drop = nn.Dropout(settings['dropout_p']) 20 | if rnn_bidirectional: 21 | rnn_hid_dim = 2*settings['rnn_hid_dim'] 22 | else: 23 | rnn_hid_dim = settings['rnn_hid_dim'] 24 | self.excitation = nn.Sequential( 25 | nn.Linear(rnn_hid_dim, rnn_hid_dim), 26 | nn.ReLU(), 27 | nn.Linear(rnn_hid_dim, rnn_hid_dim), 28 | nn.Sigmoid(), 29 | ) 30 | self.squeeze = nn.AdaptiveAvgPool1d(1) 31 | self.sent_senti_cls = nn.Sequential( 32 | nn.Linear(rnn_hid_dim, rnn_hid_dim), 33 | nn.ReLU(), 34 | nn.Dropout(settings['dropout_p']), 35 | nn.Linear(rnn_hid_dim, len(sentiment_categories)), 36 | ) 37 | 38 | def forward(self, seqs, lengths): 39 | seqs = self.word_embed(seqs) # [bs, max_seq_len, word_dim] 40 | seqs = pack_padded_sequence(seqs, lengths, batch_first=True, enforce_sorted=False) 41 | out, _ = self.rnn(seqs) 42 | out = pad_packed_sequence(out, batch_first=True)[0] # [bs, seq_len, rnn_hid] 43 | out = self.drop(out) 44 | 45 | excitation_res = self.excitation(out) # [bs, max_len, rnn_hid] 46 | # excitation_res = self.drop(excitation_res) 47 | excitation_res = pad_packed_sequence( 48 | pack_padded_sequence(excitation_res, lengths, batch_first=True, enforce_sorted=False), 49 | batch_first=True)[0] 50 | squeeze_res = self.squeeze(excitation_res).permute(0, 2, 1) # [bs, 1, max_len] 51 | # squeeze_res = squeeze_res.masked_fill(squeeze_res == 0, -1e10) 52 | # squeeze_res = squeeze_res.softmax(dim=-1) 53 | sent_feats = squeeze_res.bmm(out).squeeze(dim=1) # [bs, rnn_hid] 54 | pred = self.sent_senti_cls(sent_feats) # [bs, 3] 55 | 56 | return pred, squeeze_res.squeeze(dim=1) 57 | 58 | def sample(self, seqs, lengths): 59 | self.eval() 60 | pred, att_weights = self.forward(seqs, lengths) 61 | result = [] 62 | result_w = [] 63 | for p in pred: 64 | res = int(p.argmax(-1)) 65 | result.append(res) 66 | result_w.append(self.sentiment_categories[res]) 67 | 68 | return result, result_w, att_weights 69 | 70 | def get_optim_and_crit(self, lr, weight_decay=0): 71 | return torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay), \ 72 | nn.CrossEntropyLoss() 73 | -------------------------------------------------------------------------------- 
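As a quick orientation for the classifier above, here is a minimal usage sketch. The vocabulary, settings values and input sentence are invented for illustration (the real ones are assembled in `opts.py` and the preprocessing pipeline), and since the weights are untrained the predicted label is arbitrary; only the shapes and the `sample` call signature are meant to be informative.

```python
import torch
from models.sent_senti_cls import SentenceSentimentClassifier

# Toy configuration; the real values are built in opts.py / preprocess.py.
settings = {'word_emb_dim': 512, 'rnn_hid_dim': 512, 'dropout_p': 0.5}
idx2word = ['', 'a', 'dog', 'runs', 'happily']          # index 0 is the padding token
sentiment_categories = ['positive', 'negative', 'neutral']

clf = SentenceSentimentClassifier(idx2word, sentiment_categories, settings)

seqs = torch.tensor([[1, 2, 3, 4, 0]])   # [bs=1, padded_len=5], zero-padded
lengths = torch.tensor([4])              # number of real tokens in each sentence

labels, label_names, att_weights = clf.sample(seqs, lengths)
print(label_names)        # e.g. ['neutral'] (arbitrary here: the weights are untrained)
print(att_weights.shape)  # torch.Size([1, 4]): one weight per non-padded time step
```

In `models/decoder.py` above, this same classifier is kept in eval mode and supplies the `cls_reward` term of the RL objective.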
/models/sentiment_detector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SentimentDetector(nn.Module): 6 | def __init__(self, sentiment_categories, settings): 7 | super(SentimentDetector, self).__init__() 8 | self.sentiment_categories = sentiment_categories 9 | self.neu_idx = sentiment_categories.index('neutral') 10 | 11 | self.convs = nn.Sequential() 12 | in_channels = settings['fc_feat_dim'] 13 | for i in range(settings['sentiment_convs_num']): 14 | self.convs.add_module( 15 | 'conv_%d' % i, nn.Conv2d(in_channels, in_channels // 2, 3, padding=1)) 16 | in_channels //= 2 17 | self.convs.add_module('dropout', nn.Dropout(settings['dropout_p'])) 18 | self.convs.add_module('relu', nn.ReLU()) 19 | 20 | # TODO: Can be modified for multiple kernels per sentiment 21 | num_sentis = len(sentiment_categories) 22 | self.senti_conv = nn.Conv2d(in_channels, num_sentis, 1) 23 | # self.global_pool = nn.AdaptiveMaxPool2d(1) 24 | self.global_pool = nn.AdaptiveAvgPool2d(1) 25 | 26 | self.output = nn.Sequential( 27 | *[nn.Linear(num_sentis, num_sentis) for _ in range(settings['sentiment_fcs_num'])] 28 | ) 29 | 30 | def forward(self, features): 31 | # [bz, 14, 14, fc_feat_dim] 32 | features = features.permute(0, 3, 1, 2) # [bz, fc_feat_dim, 14, 14] 33 | features = self.convs(features) # [bz, channels, 14, 14] 34 | senti_features = self.senti_conv(features) # [bz, num_sentis, 14, 14] 35 | features = self.global_pool(senti_features) # [bz, num_sentis, 1, 1] 36 | features = features.squeeze(-1).squeeze(-1) # [bz, num_sentis] 37 | output = self.output(features) # [bz, num_sentis] 38 | 39 | out = output.softmax(dim=-1) # [bz, num_sentis] 40 | shape = senti_features.shape 41 | senti_features = out.unsqueeze(1).bmm( 42 | senti_features.view(shape[0], shape[1], -1)) # [bz, 1, 14*14] 43 | senti_features = senti_features.view(shape[0], shape[2], shape[3]) # [bz, 14, 14] 44 | 45 | return output, senti_features 46 | 47 | def sample(self, features, senti_threshold=0): 48 | # [bz, 14, 14, fc_feat_dim] 49 | self.eval() 50 | output, senti_features = self.forward(features) 51 | output = output.softmax(dim=-1) 52 | scores, senti_labels = output.max(dim=-1) # bz 53 | replace_idx = (scores < senti_threshold).nonzero(as_tuple=False).view(-1) 54 | senti_labels.index_copy_(0, replace_idx, senti_labels.new_zeros(len(replace_idx)).fill_(self.neu_idx)) 55 | 56 | sentiments = [] 57 | for i in senti_labels: 58 | sentiments.append(self.sentiment_categories[i]) 59 | 60 | return senti_labels, senti_features, sentiments, scores 61 | 62 | def get_optim_criterion(self, lr, weight_decay=0): 63 | return torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay),\ 64 | nn.CrossEntropyLoss() 65 | -------------------------------------------------------------------------------- /models/sentiment_detector_full.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SentimentDetector(nn.Module): 6 | def __init__(self, sentiment_categories, settings): 7 | super(SentimentDetector, self).__init__() 8 | self.sentiment_categories = sentiment_categories 9 | self.neu_idx = sentiment_categories.index('neutral') 10 | 11 | self.convs = nn.Sequential() 12 | in_channels = settings['fc_feat_dim'] 13 | for i in range(settings['sentiment_convs_num']): 14 | self.convs.add_module( 15 | 'conv_%d' % i, nn.Conv2d(in_channels, in_channels // 2, 3, padding=1)) 16 | 
in_channels //= 2 17 | self.convs.add_module('dropout', nn.Dropout(settings['dropout_p'])) 18 | self.convs.add_module('relu', nn.ReLU()) 19 | 20 | num_kernels_per_sentiment = settings['num_kernels_per_sentiment'] 21 | # TODO: Can be modified for multiple kernels per sentiment 22 | num_sentis = len(sentiment_categories) 23 | self.senti_conv = nn.Conv2d(in_channels, num_sentis * num_kernels_per_sentiment, 1) 24 | self.global_max_pool = nn.AdaptiveMaxPool2d(1) 25 | self.senti_pool = nn.AdaptiveAvgPool1d(num_sentis) 26 | 27 | self.senti_feat_pool = nn.AdaptiveAvgPool1d(num_sentis) 28 | self.global_avg_pool = nn.AdaptiveAvgPool2d(1) 29 | 30 | self.cls = nn.Linear(2 * in_channels, num_sentis) 31 | 32 | def forward(self, features): 33 | # [bz, 14, 14, fc_feat_dim] 34 | features = features.permute(0, 3, 1, 2) # [bz, fc_feat_dim, 14, 14] 35 | features = self.convs(features) # [bz, n, 14, 14] 36 | senti_features = self.senti_conv(features) # [bz, k*C, 14, 14] 37 | pooled_vector = self.global_max_pool(senti_features) # [bz, k*C, 1, 1] 38 | pooled_vector = pooled_vector.squeeze(-1).permute(0, 2, 1) # [bz, 1, k*C] 39 | pooled_vector = self.senti_pool(pooled_vector) # [bz, 1, C] 40 | det_out = pooled_vector.squeeze(1) # [bz, C] 41 | 42 | weights = pooled_vector.softmax(dim=-1) # [bz, 1, C] 43 | shape = senti_features.shape 44 | senti_features = senti_features.reshape(shape[0], shape[1], -1).permute(0, 2, 1) # [bz, 14*14, k*C] 45 | senti_features = self.senti_feat_pool(senti_features) # [bz, 14*14, C] 46 | senti_features = weights.bmm(senti_features.permute(0, 2, 1)) # [bz, 1, 14*14] 47 | senti_features = senti_features.reshape(shape[0], 1, shape[2], shape[3]) # [bz, 1, 14, 14] 48 | 49 | semantic_features = torch.cat([features, features * senti_features.expand_as(features)], dim=1) # [bz, 2n, 14, 14] 50 | semantic_features = self.global_avg_pool(semantic_features) # [bz, 2n, 1, 1] 51 | semantic_features = semantic_features.squeeze(-1).squeeze(-1) # [bz, 2n] 52 | cls_out = self.cls(semantic_features) # [bz, C] 53 | 54 | return (det_out, cls_out), senti_features.squeeze(1) 55 | 56 | def sample(self, features, senti_threshold=0): 57 | # [bz, 14, 14, n] 58 | self.eval() 59 | (_, output), senti_features = self.forward(features) 60 | output = output.softmax(dim=-1) 61 | scores, senti_labels = output.max(dim=-1) # bz 62 | replace_idx = (scores < senti_threshold).nonzero(as_tuple=False).view(-1) 63 | senti_labels.index_copy_(0, replace_idx, senti_labels.new_zeros(len(replace_idx)).fill_(self.neu_idx)) 64 | 65 | sentiments = [] 66 | for i in senti_labels: 67 | sentiments.append(self.sentiment_categories[i]) 68 | 69 | return senti_labels, senti_features, sentiments, scores 70 | 71 | def get_optim_criterion(self, lr, weight_decay=0): 72 | return torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay),\ 73 | nn.CrossEntropyLoss() 74 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import torch 4 | 5 | 6 | def parse_opt(): 7 | parser = argparse.ArgumentParser() 8 | 9 | # train settings 10 | # train concept detector 11 | parser.add_argument('--concept_lr', type=float, default=4e-4) 12 | parser.add_argument('--concept_bs', type=int, default=80) 13 | parser.add_argument('--concept_resume', type=str, default='') 14 | parser.add_argument('--concept_epochs', type=int, default=40) 15 | parser.add_argument('--concept_num_works', type=int, default=2) 
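# The option groups below (sentiment detector, XE training and RL fine-tuning of
# the captioner) largely repeat this pattern: a learning rate, a batch size, a
# checkpoint to resume from, the number of epochs and a *_num_works worker count.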
16 | 17 | # train sentiment detector 18 | parser.add_argument('--senti_lr', type=float, default=4e-4) 19 | parser.add_argument('--senti_bs', type=int, default=80) 20 | parser.add_argument('--senti_resume', type=str, default='') 21 | parser.add_argument('--senti_epochs', type=int, default=30) 22 | parser.add_argument('--senti_num_works', type=int, default=2) 23 | 24 | parser.add_argument('--img_senti_labels', type=str, default='./data/captions/img_senti_labels.json') 25 | parser.add_argument('--sentiment_categories', type=list, default=['positive', 'negative', 'neutral']) 26 | 27 | # train full model 28 | # xe 29 | parser.add_argument('--xe_lr', type=float, default=4e-4) 30 | parser.add_argument('--xe_bs', type=int, default=20) 31 | parser.add_argument('--xe_resume', type=str, default='') 32 | parser.add_argument('--xe_epochs', type=int, default=40) 33 | parser.add_argument('--xe_num_works', type=int, default=2) 34 | 35 | parser.add_argument('--scheduled_sampling_start', type=int, default=0) 36 | parser.add_argument('--scheduled_sampling_increase_every', type=int, default=4) 37 | parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05) 38 | parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25) 39 | 40 | # rl 41 | parser.add_argument('--rl_lrs', type=json.loads, 42 | default='{"cap_lr": 4e-5}') # , "senti_lr": 4e-5, "sent_lr": 1e-3}') 43 | parser.add_argument('--rl_bs', type=int, default=40) 44 | parser.add_argument('--rl_num_works', type=int, default=2) 45 | parser.add_argument('--rl_resume', type=str, default='') 46 | parser.add_argument('--rl_senti_resume', type=str, default='checkpoint/sentiment/model-10.pth') 47 | parser.add_argument('--rl_epochs', type=int, default=40) 48 | parser.add_argument('--rl_fact_times', type=int, default=1) 49 | parser.add_argument('--rl_senti_times', type=int, default=0) 50 | 51 | # common 52 | parser.add_argument('--dataset_name', type=str, default='coco', choices=['coco', 'flickr30k']) 53 | parser.add_argument('--corpus_type', type=str, default='part', choices=['part', 'full']) 54 | parser.add_argument('--captions_dir', type=str, default='./data/captions') 55 | parser.add_argument('--feats_dir', type=str, default='./data/features') 56 | parser.add_argument('--corpus_dir', type=str, default='./data/corpus') 57 | parser.add_argument('--checkpoint', type=str, default='./checkpoint/') 58 | parser.add_argument('--result_dir', type=str, default='./result/') 59 | # parser.add_argument('--sentence_sentiment_classifier_rnn', type=str, default='') 60 | parser.add_argument('--max_seq_len', type=int, default=16) 61 | parser.add_argument('--num_concepts', type=int, default=5) 62 | parser.add_argument('--num_sentiments', type=int, default=10) 63 | parser.add_argument('--grad_clip', type=float, default=0.1) 64 | 65 | # eval settings 66 | parser.add_argument('-e', '--eval_model', type=str, default='') 67 | parser.add_argument('-r', '--result_file', type=str, default='') 68 | parser.add_argument('--beam_size', type=int, default=3) 69 | 70 | # test settings 71 | parser.add_argument('-t', '--test_model', type=str, default='') 72 | parser.add_argument('-i', '--image_file', type=str, default='') 73 | # encoder settings 74 | parser.add_argument('--resnet101_file', type=str, default='./data/pre_models/resnet101.pth', 75 | help='Pre-trained resnet101 network for extracting image features') 76 | 77 | args = parser.parse_args() 78 | 79 | # network settings 80 | settings = dict() 81 | settings['word_emb_dim'] = 512 82 | 
settings['fc_feat_dim'] = 2048 83 | settings['att_feat_dim'] = 2048 84 | settings['feat_emb_dim'] = 512 85 | settings['dropout_p'] = 0.5 86 | settings['rnn_hid_dim'] = 512 87 | settings['att_hid_dim'] = 512 88 | 89 | settings['concept_mid_him'] = 1024 90 | settings['sentiment_convs_num'] = 2 91 | # settings['num_kernels_per_sentiment'] = 4 92 | settings['sentiment_feat_dim'] = 14*14 93 | settings['sentiment_fcs_num'] = 2 94 | settings['text_cnn_filters'] = (3, 4, 5) 95 | settings['text_cnn_out_dim'] = 256 96 | 97 | args.settings = settings 98 | args.use_gpu = torch.cuda.is_available() 99 | args.device = torch.device('cuda:0') if args.use_gpu else torch.device('cpu') 100 | return args 101 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | import sys 6 | import pdb 7 | import traceback 8 | from bdb import BdbQuit 9 | import h5py 10 | import tqdm 11 | from collections import Counter, defaultdict 12 | import skimage.io 13 | import nltk 14 | import torch 15 | from copy import deepcopy 16 | 17 | from models.encoder import Encoder 18 | 19 | 20 | concept_pos = ['VERB', 'NOUN'] 21 | 22 | 23 | def extract_imgs_feat(): 24 | encoder = Encoder(opt.resnet101_file) 25 | encoder.to(opt.device) 26 | encoder.eval() 27 | 28 | imgs = os.listdir(opt.imgs_dir) 29 | imgs.sort() 30 | 31 | if not os.path.exists(opt.feats_dir): 32 | os.makedirs(opt.feats_dir) 33 | with h5py.File(os.path.join(opt.feats_dir, 'feats_fc.h5')) as file_fc, \ 34 | h5py.File(os.path.join(opt.feats_dir, 'feats_att.h5')) as file_att: 35 | try: 36 | for img_nm in tqdm.tqdm(imgs): 37 | img = skimage.io.imread(os.path.join(opt.imgs_dir, img_nm)) 38 | if len(img.shape) == 3 and img.shape[-1] == 4: 39 | img = img[:, :, :3] 40 | with torch.no_grad(): 41 | img = encoder.preprocess(img) 42 | img = img.to(opt.device) 43 | img_fc, img_att = encoder(img) 44 | file_fc.create_dataset(img_nm, data=img_fc.cpu().float().numpy()) 45 | file_att.create_dataset(img_nm, data=img_att.cpu().float().numpy()) 46 | except BaseException as e: 47 | file_fc.close() 48 | file_att.close() 49 | print('--------------------------------------------------------------------') 50 | raise e 51 | 52 | 53 | def process_caption_datasets(): 54 | for dataset_nm in opt.dataset_names: 55 | print('===> process %s dataset' % dataset_nm) 56 | images = json.load(open(os.path.join(opt.caption_datasets_dir, 'dataset_%s.json' % dataset_nm), 'r'))['images'] 57 | img_captions = {'train': {}, 'val': {}, 'test': {}} 58 | img_captions_pos = {'train': {}, 'val': {}, 'test': {}} 59 | img_concepts = {'train': {}, 'val': {}, 'test': {}} 60 | for image in tqdm.tqdm(images): 61 | fn = image['filename'] 62 | split = image['split'] 63 | if split == 'restval': 64 | split = 'train' 65 | img_captions[split][fn] = [] 66 | img_captions_pos[split][fn] = [] 67 | img_concepts[split][fn] = set() 68 | sentences = [] 69 | for sentence in image['sentences']: 70 | raw = sentence['raw'].lower() 71 | words = nltk.word_tokenize(raw) 72 | sentences.append(words) 73 | tagged_sents = nltk.pos_tag_sents(sentences, tagset='universal') 74 | for tagged_tokens in tagged_sents: 75 | words = [] 76 | poses = [] 77 | for w, p in tagged_tokens: 78 | if p == '.': # remove punctuation 79 | continue 80 | words.append(w) 81 | poses.append(p) 82 | if p in concept_pos: 83 | img_concepts[split][fn].add(w) 84 | 
img_captions[split][fn].append(words) 85 | img_captions_pos[split][fn].append(poses) 86 | img_concepts[split][fn] = list(img_concepts[split][fn]) 87 | 88 | json.dump(img_captions, open(os.path.join(opt.captions_dir, dataset_nm, 'img_captions.json'), 'w')) 89 | json.dump(img_captions_pos, open(os.path.join(opt.captions_dir, dataset_nm, 'img_captions_pos.json'), 'w')) 90 | json.dump(img_concepts, open(os.path.join(opt.captions_dir, dataset_nm, 'img_concepts.json'), 'w')) 91 | 92 | 93 | def process_senti_corpus(): 94 | corpus_type = 'part' 95 | senti_corpus = json.load(open(os.path.join(opt.corpus_dir, corpus_type, 'senti_corpus.json'), 'r')) 96 | 97 | tmp_senti_corpus = defaultdict(list) 98 | tmp_senti_corpus_pos = defaultdict(list) 99 | all_sentis = Counter() 100 | sentis = defaultdict(Counter) 101 | sentiment_detector = defaultdict(Counter) 102 | 103 | for senti_label, sents in senti_corpus.items(): 104 | for i in tqdm.tqdm(range(0, len(sents), 100)): 105 | cur_sents = sents[i:i + 100] 106 | tmp_sents = [] 107 | for sent in cur_sents: 108 | tmp_sents.append(nltk.word_tokenize(sent.strip().lower())) 109 | tagged_sents = nltk.pos_tag_sents(tmp_sents, tagset='universal') 110 | for tagged_tokens in tagged_sents: 111 | words = [] 112 | poses = [] 113 | nouns = [] 114 | adjs = [] 115 | for w, p in tagged_tokens: 116 | if p == '.': # remove punctuation 117 | continue 118 | words.append(w) 119 | poses.append(p) 120 | if p == 'ADJ': 121 | adjs.append(w) 122 | elif p == 'NOUN': 123 | nouns.append(w) 124 | tmp_senti_corpus[senti_label].append(words) 125 | tmp_senti_corpus_pos[senti_label].append(poses) 126 | if adjs: 127 | all_sentis.update(adjs) 128 | sentis[senti_label].update(adjs) 129 | for noun in nouns: 130 | sentiment_detector[noun].update(adjs) 131 | 132 | json.dump(tmp_senti_corpus, open(os.path.join(opt.corpus_dir, corpus_type, 'tmp_senti_corpus.json'), 'w')) 133 | json.dump(tmp_senti_corpus_pos, open(os.path.join(opt.corpus_dir, corpus_type, 'tmp_senti_corpus_pos.json'), 'w')) 134 | 135 | all_sentis = all_sentis.most_common() 136 | all_sentis = [w for w in all_sentis if w[1] >= 3] 137 | sentis = {k: v.most_common() for k, v in sentis.items()} 138 | sentiment_detector = {k: v.most_common() for k, v in sentiment_detector.items()} 139 | 140 | all_sentis = {k: v for k, v in all_sentis} 141 | 142 | len_sentis = defaultdict(int) 143 | for k, v in sentis.items(): 144 | for _, n in v: 145 | len_sentis[k] += n 146 | tf_sentis = defaultdict(dict) 147 | tmp_sentis = defaultdict(dict) 148 | for k, v in sentis.items(): 149 | for w, n in v: 150 | tf_sentis[k][w] = n / len_sentis[k] 151 | tmp_sentis[k][w] = n 152 | sentis = tmp_sentis 153 | 154 | sentis_result = defaultdict(dict) 155 | for k, v in tf_sentis.items(): 156 | for w, tf in v.items(): 157 | if w in all_sentis: 158 | sentis_result[k][w] = tf * (sentis[k][w] / all_sentis[w]) 159 | 160 | sentiment_words = {} 161 | for k in sentis_result: 162 | sentiment_words[k] = list(sentis_result[k].items()) 163 | sentiment_words[k].sort(key=lambda p: p[1], reverse=True) 164 | sentiment_words[k] = [w[0] for w in sentiment_words[k]] 165 | 166 | common_rm = [] 167 | pos_rm = [] 168 | neg_rm = [] 169 | for i, w in enumerate(sentiment_words['positive']): 170 | if w in sentiment_words['negative']: 171 | n_idx = sentiment_words['negative'].index(w) 172 | if abs(i - n_idx) < 5: 173 | common_rm.append(w) 174 | elif i > n_idx: 175 | pos_rm.append(w) 176 | else: 177 | neg_rm.append(w) 178 | for w in common_rm: 179 | sentiment_words['positive'].remove(w) 180 | 
sentiment_words['negative'].remove(w) 181 | for w in pos_rm: 182 | sentiment_words['positive'].remove(w) 183 | for w in neg_rm: 184 | sentiment_words['negative'].remove(w) 185 | 186 | tmp_sentiment_words = {} 187 | for senti in sentiment_words: 188 | tmp_sentiment_words[senti] = {} 189 | for w in sentiment_words[senti]: 190 | tmp_sentiment_words[senti][w] = sentis_result[senti][w] 191 | sentiment_words = tmp_sentiment_words 192 | 193 | json.dump(sentiment_words, open(os.path.join(opt.corpus_dir, corpus_type, 'sentiment_words.json'), 'w')) 194 | 195 | tmp_sentiment_words = {} 196 | tmp_sentiment_words.update(sentiment_words['positive']) 197 | tmp_sentiment_words.update(sentiment_words['negative']) 198 | sentiment_words = tmp_sentiment_words 199 | 200 | tmp_sentiment_detector = defaultdict(list) 201 | for noun, senti_words in sentiment_detector.items(): 202 | number = sum([w[1] for w in senti_words]) 203 | for senti_word in senti_words: 204 | if senti_word[0] in sentiment_words: 205 | tmp_sentiment_detector[noun].append( 206 | (senti_word[0], senti_word[1] / number * sentiment_words[senti_word[0]])) 207 | sentiment_detector = tmp_sentiment_detector 208 | tmp_sentiment_detector = {} 209 | for noun, senti_words in sentiment_detector.items(): 210 | if len(senti_words) <= 50: 211 | tmp_sentiment_detector[noun] = senti_words 212 | 213 | json.dump(tmp_sentiment_detector, open(os.path.join(opt.corpus_dir, corpus_type, 'sentiment_detector.json'), 'w')) 214 | 215 | 216 | def build_idx2concept(): 217 | for dataset_nm in opt.dataset_names: 218 | img_concepts = json.load(open(os.path.join(opt.captions_dir, dataset_nm, 'img_concepts.json'), 'r')) 219 | tc = Counter() 220 | for concepts in img_concepts.values(): 221 | for cs in tqdm.tqdm(concepts.values()): 222 | tc.update(cs) 223 | tc = tc.most_common() 224 | idx2concept = [w[0] for w in tc[:2000]] 225 | json.dump(idx2concept, open(os.path.join(opt.captions_dir, dataset_nm, 'idx2concept.json'), 'w')) 226 | 227 | 228 | def get_img_senti_labels(): 229 | senti_img_fns = os.listdir(opt.senti_imgs_dir) 230 | senti_imgs = defaultdict(list) 231 | for fn in senti_img_fns: 232 | senti = fn.split('_')[0] 233 | senti_imgs[senti].append((fn, senti)) 234 | random.shuffle(senti_imgs['positive']) 235 | random.shuffle(senti_imgs['negative']) 236 | random.shuffle(senti_imgs['neutral']) 237 | img_senti_labels = {'train': [], 'val': [], 'test': []} 238 | img_senti_labels['val'].extend(senti_imgs['positive'][:100]) 239 | img_senti_labels['val'].extend(senti_imgs['negative'][:100]) 240 | img_senti_labels['val'].extend(senti_imgs['neutral'][:50]) 241 | img_senti_labels['test'].extend(senti_imgs['positive'][100:200]) 242 | img_senti_labels['test'].extend(senti_imgs['negative'][100:200]) 243 | img_senti_labels['test'].extend(senti_imgs['neutral'][50:100]) 244 | img_senti_labels['train'].extend(senti_imgs['positive'][200:]) 245 | img_senti_labels['train'].extend(senti_imgs['negative'][200:]) 246 | img_senti_labels['train'].extend(senti_imgs['neutral'][100:]) 247 | json.dump(img_senti_labels, open(opt.img_senti_labels, 'w')) 248 | 249 | 250 | def build_idx2word(): 251 | corpus_type = 'part' 252 | senti_corpus = json.load(open(os.path.join(opt.corpus_dir, corpus_type, 'tmp_senti_corpus.json'), 'r')) 253 | sentiment_words = json.load(open(os.path.join(opt.corpus_dir, corpus_type, 'sentiment_words.json'), 'r')) 254 | idx2sentiment = [] 255 | for v in sentiment_words.values(): 256 | idx2sentiment.extend(list(v.keys())) 257 | 258 | for dataset_nm in opt.dataset_names: 259 | 
img_captions = json.load(open(os.path.join(opt.captions_dir, dataset_nm, 'img_captions.json'), 'r')) 260 | idx2concept = json.load(open(os.path.join(opt.captions_dir, dataset_nm, 'idx2concept.json'), 'r')) 261 | 262 | tc = Counter() 263 | for captions in img_captions.values(): 264 | for caps in captions.values(): 265 | for cap in caps: 266 | tc.update(cap) 267 | for captions in senti_corpus.values(): 268 | for cap in captions: 269 | tc.update(cap) 270 | tc = tc.most_common() 271 | idx2word = [w[0] for w in tc if w[1] > 5] 272 | 273 | idx2word.extend(idx2sentiment) 274 | idx2word.extend(idx2concept) 275 | idx2word = list(set(idx2word)) 276 | idx2word = ['', '', '', ''] + idx2word 277 | json.dump(idx2word, open(os.path.join(opt.captions_dir, dataset_nm, corpus_type, 'idx2word.json'), 'w')) 278 | 279 | 280 | def get_img_det_sentiments(): 281 | corpus_type = 'part' 282 | sentiment_detector = json.load(open(os.path.join(opt.corpus_dir, corpus_type, 'sentiment_detector.json'), 'r')) 283 | 284 | for dataset_nm in opt.dataset_names: 285 | det_concepts = json.load(open(os.path.join(opt.captions_dir, dataset_nm, 'img_det_concepts.json'), 'r')) 286 | det_sentiments = {} 287 | null_sentis = [] 288 | for fn, concepts in tqdm.tqdm(det_concepts.items()): 289 | sentis = [] 290 | for con in concepts: 291 | sentis.extend(sentiment_detector.get(con, [])) 292 | if sentis: 293 | tmp_sentis = defaultdict(float) 294 | for w, s in sentis: 295 | tmp_sentis[w] += s 296 | sentis = list(tmp_sentis.items()) 297 | sentis.sort(key=lambda p: p[1], reverse=True) 298 | sentis = [w[0] for w in sentis] 299 | else: 300 | null_sentis.append(fn) 301 | det_sentiments[fn] = sentis[:20] 302 | json.dump(det_sentiments, open(os.path.join(opt.captions_dir, dataset_nm, corpus_type, 'img_det_sentiments.json'), 'w')) 303 | 304 | 305 | def get_senti_captions(): 306 | corpus_type = 'part' 307 | sentiment_detector = json.load(open(os.path.join(opt.corpus_dir, corpus_type, 'sentiment_detector.json'), 'r')) 308 | senti_corpus = json.load(open(os.path.join(opt.corpus_dir, corpus_type, 'tmp_senti_corpus.json'), 'r')) 309 | senti_corpus_pos = json.load(open(os.path.join(opt.corpus_dir, corpus_type, 'tmp_senti_corpus_pos.json'), 'r')) 310 | sentiment_words = json.load(open(os.path.join(opt.corpus_dir, corpus_type, 'sentiment_words.json'), 'r')) 311 | idx2sentiment = [] 312 | for v in sentiment_words.values(): 313 | idx2sentiment.extend(list(v.keys())) 314 | 315 | senti_captions = defaultdict(list) # len(pos) = 4633, len(neg) = 3760 316 | cpts_len = defaultdict(int) # len = 23, we choose 5 317 | sentis_len = defaultdict(int) # len = 104, we choose 5 or 10 318 | wrong = [] # len = 476 319 | for senti in senti_corpus: 320 | for i, cap in enumerate(senti_corpus[senti]): 321 | pos = senti_corpus_pos[senti][i] 322 | cpts = [] 323 | for j, p in enumerate(pos): 324 | if p in concept_pos: 325 | cpts.append(cap[j]) 326 | cpts = list(set(cpts)) 327 | sentis = [] 328 | for con in cpts: 329 | sentis.extend(sentiment_detector.get(con, [])) 330 | if sentis: 331 | tmp_sentis = defaultdict(float) 332 | for w, s in sentis: 333 | tmp_sentis[w] += s 334 | sentis = list(tmp_sentis.items()) 335 | sentis.sort(key=lambda p: p[1], reverse=True) 336 | sentis = [w[0] for w in sentis] 337 | senti_captions[senti].append([cap, cpts[:20], sentis[:20]]) 338 | cpts_len[len(cpts)] += 1 339 | sentis_len[len(sentis)] += 1 340 | else: 341 | wrong.append([len(cpts), len(sentis)]) 342 | cpts_len = list(cpts_len.items()) 343 | cpts_len.sort() 344 | sentis_len = 
list(sentis_len.items()) 345 | sentis_len.sort() 346 | 347 | for dataset_nm in opt.dataset_names: 348 | cpts_len = defaultdict(int) # len = 23, we choose 5 349 | sentis_len = defaultdict(int) # len = 104, we choose 5 or 10 350 | wrong = [] 351 | img_captions = json.load(open(os.path.join(opt.captions_dir, dataset_nm, 'img_captions.json'), 'r'))['train'] 352 | img_captions_pos = json.load(open(os.path.join(opt.captions_dir, dataset_nm, 'img_captions_pos.json'), 'r'))['train'] 353 | fact_caps = [] 354 | for fn, caps in tqdm.tqdm(img_captions.items()): 355 | for i, cap in enumerate(caps): 356 | flag = True 357 | for w in cap: 358 | if w in idx2sentiment: 359 | flag = False 360 | break 361 | if flag: 362 | pos = img_captions_pos[fn][i] 363 | cpts = [] 364 | for j, p in enumerate(pos): 365 | if p in concept_pos: 366 | cpts.append(cap[j]) 367 | cpts = list(set(cpts)) 368 | sentis = [] 369 | for con in cpts: 370 | sentis.extend(sentiment_detector.get(con, [])) 371 | if sentis: 372 | tmp_sentis = defaultdict(float) 373 | for w, s in sentis: 374 | tmp_sentis[w] += s 375 | sentis = list(tmp_sentis.items()) 376 | sentis.sort(key=lambda p: p[1], reverse=True) 377 | sentis = [w[0] for w in sentis] 378 | fact_caps.append([cap, cpts[:20], sentis[:20]]) 379 | cpts_len[len(cpts)] += 1 380 | sentis_len[len(sentis)] += 1 381 | else: 382 | wrong.append([len(cpts), len(sentis)]) 383 | cpts_len = list(cpts_len.items()) 384 | cpts_len.sort() 385 | sentis_len = list(sentis_len.items()) 386 | sentis_len.sort() 387 | 388 | tmp_senti_captions = deepcopy(senti_captions) 389 | tmp_senti_captions['neutral'] = fact_caps 390 | json.dump(tmp_senti_captions, open(os.path.join(opt.captions_dir, dataset_nm, corpus_type, 'senti_captions.json'), 'w')) 391 | 392 | 393 | def get_anno_captions(): 394 | for dataset_nm in opt.dataset_names: 395 | images = json.load(open(os.path.join(opt.caption_datasets_dir, 'dataset_%s.json' % dataset_nm), 'r'))['images'] 396 | anno_captions = {} 397 | for image in tqdm.tqdm(images): 398 | if image['split'] == 'test': 399 | fn = image['filename'] 400 | sentences = [] 401 | for sentence in image['sentences']: 402 | raw = sentence['raw'].strip().lower() 403 | sentences.append(raw) 404 | anno_captions[fn] = sentences 405 | json.dump(anno_captions, open(os.path.join(opt.captions_dir, dataset_nm, 'anno_captions.json'), 'w')) 406 | 407 | 408 | def get_lm_sents(): 409 | corpus_type = 'part' 410 | for dataset_nm in opt.dataset_names: 411 | senti_captions = json.load(open(os.path.join(opt.captions_dir, dataset_nm, corpus_type, 'senti_captions.json'), 'r')) 412 | for senti in senti_captions: 413 | senti_captions[senti] = [' '.join(c[0]) for c in senti_captions[senti]] 414 | senti_sents = defaultdict(str) 415 | for senti in senti_captions: 416 | for cap in senti_captions[senti]: 417 | senti_sents[senti] += cap + '\n' 418 | 419 | lm_dir = os.path.join(opt.captions_dir, dataset_nm, corpus_type, 'lm') 420 | if not os.path.exists(lm_dir): 421 | os.makedirs(lm_dir) 422 | for senti in senti_sents: 423 | with open(os.path.join(lm_dir, '%s_w.txt' % senti), 'w') as f: 424 | f.write(senti_sents[senti]) 425 | 426 | count_cmd = 'ngram-count -text %s -order 3 -write %s' 427 | lm_cmd = 'ngram-count -read %s -order 3 -lm %s -interpolate -kndiscount' 428 | for dataset_nm in opt.dataset_names: 429 | lm_dir = os.path.join(opt.captions_dir, dataset_nm, corpus_type, 'lm') 430 | fns = os.listdir(lm_dir) 431 | for fn in fns: 432 | if fn.endswith('_w.txt'): 433 | txt_file = os.path.join(lm_dir, fn) 434 | count_file = 
os.path.join(lm_dir, '%s.count' % fn.split('.')[0]) 435 | lm_file = os.path.join(lm_dir, '%s.sri' % fn.split('.')[0]) 436 | out = os.popen(count_cmd % (txt_file, count_file)).read() 437 | print(out) 438 | out = os.popen(lm_cmd % (count_file, lm_file)).read() 439 | print(out) 440 | 441 | # for kenlm 442 | kenlm_cmd = "lmplz -o 3 <%s >%s" 443 | for dataset_nm in opt.dataset_names: 444 | senti_captions = json.load( 445 | open(os.path.join(opt.captions_dir, dataset_nm, corpus_type, 'senti_captions.json'), 'r')) 446 | for senti in senti_captions: 447 | senti_captions[senti] = [c[0] for c in senti_captions[senti]] 448 | idx2word = json.load(open(os.path.join(opt.captions_dir, dataset_nm, corpus_type, 'idx2word.json'), 'r')) 449 | word2idx = {} 450 | for i, w in enumerate(idx2word): 451 | word2idx[w] = i 452 | 453 | senti_captions_id = {} 454 | for senti in senti_captions: 455 | senti_captions_id[senti] = [] 456 | for cap in senti_captions[senti]: 457 | tmp = [word2idx.get(w, None) or word2idx[''] for w in cap] + [word2idx['']] 458 | tmp = ' '.join([str(idx) for idx in tmp]) 459 | senti_captions_id[senti].append(tmp) 460 | lm_dir = os.path.join(opt.captions_dir, dataset_nm, corpus_type, 'lm') 461 | for senti in senti_captions_id: 462 | senti_captions_id[senti] = '\n'.join(senti_captions_id[senti]) 463 | with open(os.path.join(lm_dir, '%s_id.txt' % senti), 'w') as f: 464 | f.write(senti_captions_id[senti]) 465 | out = os.popen(kenlm_cmd % (os.path.join(lm_dir, '%s_id.txt' % senti), os.path.join(lm_dir, '%s_id.kenlm.arpa' % senti))).read() 466 | print(out) 467 | 468 | 469 | if __name__ == '__main__': 470 | parser = argparse.ArgumentParser() 471 | 472 | parser.add_argument('--imgs_dir', type=str, default='./data/images/sentiment') 473 | parser.add_argument('--feats_dir', type=str, default='./data/features/sentiment') 474 | parser.add_argument('--resnet101_file', type=str, 475 | default='./data/pre_models/resnet101.pth') 476 | 477 | parser.add_argument('--caption_datasets_dir', type=str, default='../../dataset/caption/caption_datasets') 478 | parser.add_argument('--dataset_names', type=list, default=['flickr30k', 'coco']) 479 | parser.add_argument('--captions_dir', type=str, default='./data/captions/') 480 | 481 | parser.add_argument('--corpus_dir', type=str, default='./data/corpus') 482 | 483 | parser.add_argument('--senti_imgs_dir', type=str, default='./data/images/sentiment') 484 | parser.add_argument('--img_senti_labels', type=str, default='./data/captions/img_senti_labels.json') 485 | 486 | opt = parser.parse_args() 487 | 488 | opt.use_gpu = torch.cuda.is_available() 489 | opt.device = torch.device('cuda:0') if opt.use_gpu else torch.device('cpu') 490 | 491 | try: 492 | # extract_imgs_feat() 493 | # process_coco_captions() 494 | process_senti_corpus() 495 | # build_idx2concept() 496 | # get_img_senti_labels() 497 | # build_idx2word() 498 | except BdbQuit: 499 | sys.exit(1) 500 | except Exception: 501 | traceback.print_exc() 502 | print('') 503 | pdb.post_mortem() 504 | sys.exit(1) 505 | -------------------------------------------------------------------------------- /self_critical/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ezeli/InSentiCap_model/0e11ba1494633e83770d52805f513eab2339ddfe/self_critical/__init__.py -------------------------------------------------------------------------------- /self_critical/bleu/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 
(c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /self_critical/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /self_critical/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | from .bleu_scorer import BleuScorer 15 | 16 | 17 | class Bleu: 18 | def __init__(self, n=4): 19 | # default compute Blue score up to 4 20 | self._n = n 21 | self._hypo_for_image = {} 22 | self.ref_for_image = {} 23 | 24 | def compute_score(self, gts, res): 25 | 26 | # assert(list(gts.keys()) == list(res.keys())) 27 | # imgIds = list(gts.keys()) 28 | 29 | bleu_scorer = BleuScorer(n=self._n) 30 | # for id in imgIds: 31 | # hypo = res[id] 32 | # ref = gts[id] 33 | # 34 | # # Sanity check. 35 | # assert(type(hypo) is list) 36 | # assert(len(hypo) == 1) 37 | # assert(type(ref) is list) 38 | # assert(len(ref) >= 1) 39 | # 40 | # bleu_scorer += (hypo[0], ref) 41 | 42 | for res_id in res: 43 | hypo = res_id['caption'] 44 | ref = gts[res_id['image_id']] 45 | 46 | # Sanity check. 
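# res is a list of {'image_id': ..., 'caption': [single hypothesis string]} entries
# and gts maps each image_id to its list of reference strings; the assertions
# below enforce exactly that shape.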
47 | assert(type(hypo) is list) 48 | assert(len(hypo) == 1) 49 | assert(type(ref) is list) 50 | assert(len(ref) >= 1) 51 | 52 | bleu_scorer += (hypo[0], ref) 53 | 54 | #score, scores = bleu_scorer.compute_score(option='shortest') 55 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0) 56 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1) 57 | 58 | # return (bleu, bleu_info) 59 | return score, scores 60 | 61 | def method(self): 62 | return "Bleu" 63 | -------------------------------------------------------------------------------- /self_critical/bleu/bleu_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | '''Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 17 | ''' 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import copy 23 | import sys, math, re 24 | from collections import defaultdict 25 | 26 | def precook(s, n=4, out=False): 27 | """Takes a string as input and returns an object that can be given to 28 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 29 | can take string arguments as well.""" 30 | words = s.split() 31 | counts = defaultdict(int) 32 | for k in range(1,n+1): 33 | for i in range(len(words)-k+1): 34 | ngram = tuple(words[i:i+k]) 35 | counts[ngram] += 1 36 | return (len(words), counts) 37 | 38 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 39 | '''Takes a list of reference sentences for a single segment 40 | and returns an object that encapsulates everything that BLEU 41 | needs to know about them.''' 42 | 43 | reflen = [] 44 | maxcounts = {} 45 | for ref in refs: 46 | rl, counts = precook(ref, n) 47 | reflen.append(rl) 48 | for (ngram,count) in counts.items(): 49 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 50 | 51 | # Calculate effective reference sentence length. 52 | if eff == "shortest": 53 | reflen = min(reflen) 54 | elif eff == "average": 55 | reflen = float(sum(reflen))/len(reflen) 56 | 57 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 58 | 59 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 60 | 61 | return (reflen, maxcounts) 62 | 63 | def cook_test(test, xxx_todo_changeme, eff=None, n=4): 64 | '''Takes a test sentence and returns an object that 65 | encapsulates everything that BLEU needs to know about it.''' 66 | (reflen, refmaxcounts) = xxx_todo_changeme 67 | testlen, counts = precook(test, n, True) 68 | 69 | result = {} 70 | 71 | # Calculate effective reference sentence length. 
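# Note: this scorer calls cook_refs/cook_test with eff=None, so the full list of
# reference lengths is kept here and the actual choice of effective length is
# deferred to _single_reflen() inside compute_score(); the selected length later
# drives the brevity penalty exp(1 - reflen/testlen), applied when the hypothesis
# is shorter than the effective reference.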
72 | 73 | if eff == "closest": 74 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1] 75 | else: ## i.e., "average" or "shortest" or None 76 | result["reflen"] = reflen 77 | 78 | result["testlen"] = testlen 79 | 80 | result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)] 81 | 82 | result['correct'] = [0]*n 83 | for (ngram, count) in counts.items(): 84 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count) 85 | 86 | return result 87 | 88 | class BleuScorer(object): 89 | """Bleu scorer. 90 | """ 91 | 92 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 93 | # special_reflen is used in oracle (proportional effective ref len for a node). 94 | 95 | def copy(self): 96 | ''' copy the refs.''' 97 | new = BleuScorer(n=self.n) 98 | new.ctest = copy.copy(self.ctest) 99 | new.crefs = copy.copy(self.crefs) 100 | new._score = None 101 | return new 102 | 103 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 104 | ''' singular instance ''' 105 | 106 | self.n = n 107 | self.crefs = [] 108 | self.ctest = [] 109 | self.cook_append(test, refs) 110 | self.special_reflen = special_reflen 111 | 112 | def cook_append(self, test, refs): 113 | '''called by constructor and __iadd__ to avoid creating new instances.''' 114 | 115 | if refs is not None: 116 | self.crefs.append(cook_refs(refs)) 117 | if test is not None: 118 | cooked_test = cook_test(test, self.crefs[-1]) 119 | self.ctest.append(cooked_test) ## N.B.: -1 120 | else: 121 | self.ctest.append(None) # lens of crefs and ctest have to match 122 | 123 | self._score = None ## need to recompute 124 | 125 | def ratio(self, option=None): 126 | self.compute_score(option=option) 127 | return self._ratio 128 | 129 | def score_ratio(self, option=None): 130 | '''return (bleu, len_ratio) pair''' 131 | return (self.fscore(option=option), self.ratio(option=option)) 132 | 133 | def score_ratio_str(self, option=None): 134 | return "%.4f (%.2f)" % self.score_ratio(option) 135 | 136 | def reflen(self, option=None): 137 | self.compute_score(option=option) 138 | return self._reflen 139 | 140 | def testlen(self, option=None): 141 | self.compute_score(option=option) 142 | return self._testlen 143 | 144 | def retest(self, new_test): 145 | if type(new_test) is str: 146 | new_test = [new_test] 147 | assert len(new_test) == len(self.crefs), new_test 148 | self.ctest = [] 149 | for t, rs in zip(new_test, self.crefs): 150 | self.ctest.append(cook_test(t, rs)) 151 | self._score = None 152 | 153 | return self 154 | 155 | def rescore(self, new_test): 156 | ''' replace test(s) with new test(s), and returns the new score.''' 157 | 158 | return self.retest(new_test).compute_score() 159 | 160 | def size(self): 161 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 162 | return len(self.crefs) 163 | 164 | def __iadd__(self, other): 165 | '''add an instance (e.g., from another sentence).''' 166 | 167 | if type(other) is tuple: 168 | ## avoid creating new BleuScorer instances 169 | self.cook_append(other[0], other[1]) 170 | else: 171 | assert self.compatible(other), "incompatible BLEUs." 
172 | self.ctest.extend(other.ctest) 173 | self.crefs.extend(other.crefs) 174 | self._score = None ## need to recompute 175 | 176 | return self 177 | 178 | def compatible(self, other): 179 | return isinstance(other, BleuScorer) and self.n == other.n 180 | 181 | def single_reflen(self, option="average"): 182 | return self._single_reflen(self.crefs[0][0], option) 183 | 184 | def _single_reflen(self, reflens, option=None, testlen=None): 185 | 186 | if option == "shortest": 187 | reflen = min(reflens) 188 | elif option == "average": 189 | reflen = float(sum(reflens))/len(reflens) 190 | elif option == "closest": 191 | reflen = min((abs(l-testlen), l) for l in reflens)[1] 192 | else: 193 | assert False, "unsupported reflen option %s" % option 194 | 195 | return reflen 196 | 197 | def recompute_score(self, option=None, verbose=0): 198 | self._score = None 199 | return self.compute_score(option, verbose) 200 | 201 | def compute_score(self, option=None, verbose=0): 202 | n = self.n 203 | small = 1e-9 204 | tiny = 1e-15 ## so that if guess is 0 still return 0 205 | bleu_list = [[] for _ in range(n)] 206 | 207 | if self._score is not None: 208 | return self._score 209 | 210 | if option is None: 211 | option = "average" if len(self.crefs) == 1 else "closest" 212 | 213 | self._testlen = 0 214 | self._reflen = 0 215 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} 216 | 217 | # for each sentence 218 | for comps in self.ctest: 219 | testlen = comps['testlen'] 220 | self._testlen += testlen 221 | 222 | if self.special_reflen is None: ## need computation 223 | reflen = self._single_reflen(comps['reflen'], option, testlen) 224 | else: 225 | reflen = self.special_reflen 226 | 227 | self._reflen += reflen 228 | 229 | for key in ['guess','correct']: 230 | for k in range(n): 231 | totalcomps[key][k] += comps[key][k] 232 | 233 | # append per image bleu score 234 | bleu = 1. 235 | for k in range(n): 236 | bleu *= (float(comps['correct'][k]) + tiny) \ 237 | /(float(comps['guess'][k]) + small) 238 | bleu_list[k].append(bleu ** (1./(k+1))) 239 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 240 | if ratio < 1: 241 | for k in range(n): 242 | bleu_list[k][-1] *= math.exp(1 - 1/ratio) 243 | 244 | if verbose > 1: 245 | print(comps, reflen) 246 | 247 | totalcomps['reflen'] = self._reflen 248 | totalcomps['testlen'] = self._testlen 249 | 250 | bleus = [] 251 | bleu = 1. 
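# Corpus-level BLEU-k below: the cumulative product of smoothed modified n-gram
# precisions (correct/guess, with tiny and small guarding zero counts) is taken to
# the 1/k power, then multiplied by the brevity penalty exp(1 - 1/ratio) whenever
# the total hypothesis length is shorter than the total effective reference length.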
252 | for k in range(n): 253 | bleu *= float(totalcomps['correct'][k] + tiny) \ 254 | / (totalcomps['guess'][k] + small) 255 | bleus.append(bleu ** (1./(k+1))) 256 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 257 | if ratio < 1: 258 | for k in range(n): 259 | bleus[k] *= math.exp(1 - 1/ratio) 260 | 261 | if verbose > 0: 262 | print(totalcomps) 263 | print("ratio:", ratio) 264 | 265 | self._score = bleus 266 | return self._score, bleu_list 267 | -------------------------------------------------------------------------------- /self_critical/cider/README.md: -------------------------------------------------------------------------------- 1 | Consensus-based Image Description Evaluation (CIDEr Code) 2 | =================== 3 | 4 | Change from [ruotianluo/cider](https://github.com/ruotianluo/cider/tree/dbb3960165d86202ed3c417b412a000fc8e717f3) 5 | -------------------------------------------------------------------------------- /self_critical/cider/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ezeli/InSentiCap_model/0e11ba1494633e83770d52805f513eab2339ddfe/self_critical/cider/__init__.py -------------------------------------------------------------------------------- /self_critical/cider/license.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015, Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 16 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 17 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 18 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 19 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 20 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 21 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 22 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 23 | 24 | The views and conclusions contained in the software and documentation are those 25 | of the authors and should not be interpreted as representing official policies, 26 | either expressed or implied, of the FreeBSD Project. 
27 | -------------------------------------------------------------------------------- /self_critical/cider/pyciderevalcap/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /self_critical/cider/pyciderevalcap/ciderD/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /self_critical/cider/pyciderevalcap/ciderD/ciderD.py: -------------------------------------------------------------------------------- 1 | # Filename: ciderD.py 2 | # 3 | # Description: Describes the class to compute the CIDEr-D (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | from .ciderD_scorer import CiderScorer 14 | 15 | 16 | class CiderD: 17 | """ 18 | Main Class to compute the CIDEr metric 19 | 20 | """ 21 | def __init__(self, n=4, sigma=6.0, refs=None): 22 | self.cider_scorer = CiderScorer(n=n, sigma=sigma, refs=refs) 23 | 24 | def compute_score(self, gts, res): 25 | """ 26 | Main function to compute CIDEr score 27 | :param hypo_for_image (dict) : dictionary with key and value 28 | ref_for_image (dict) : dictionary with key and value 29 | :return: cider (float) : computed CIDEr score for the corpus 30 | """ 31 | 32 | # clear all the previous hypos and refs 33 | self.cider_scorer.clear() 34 | for res_id in res: 35 | 36 | hypo = res_id['caption'] 37 | ref = gts[res_id['image_id']] 38 | 39 | # Sanity check. 40 | assert(type(hypo) is list) 41 | assert(len(hypo) == 1) 42 | assert(type(ref) is list) 43 | assert(len(ref) > 0) 44 | self.cider_scorer += (hypo[0], ref) 45 | 46 | (score, scores) = self.cider_scorer.compute_score() 47 | 48 | return score, scores 49 | 50 | def method(self): 51 | return "CIDEr-D" 52 | -------------------------------------------------------------------------------- /self_critical/cider/pyciderevalcap/ciderD/ciderD_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | from collections import defaultdict 9 | import numpy as np 10 | import math 11 | 12 | 13 | def precook(s, n=4, out=False): 14 | """ 15 | Takes a string as input and returns an object that can be given to 16 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 17 | can take string arguments as well. 
18 | :param s: string : sentence to be converted into ngrams 19 | :param n: int : number of ngrams for which representation is calculated 20 | :return: term frequency vector for occuring ngrams 21 | """ 22 | words = s.split() 23 | counts = defaultdict(int) 24 | for k in range(1,n+1): 25 | for i in range(len(words)-k+1): 26 | ngram = tuple(words[i:i+k]) 27 | counts[ngram] += 1 28 | return counts 29 | 30 | 31 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 32 | '''Takes a list of reference sentences for a single segment 33 | and returns an object that encapsulates everything that BLEU 34 | needs to know about them. 35 | :param refs: list of string : reference sentences for some image 36 | :param n: int : number of ngrams for which (ngram) representation is calculated 37 | :return: result (list of dict) 38 | ''' 39 | return [precook(ref, n) for ref in refs] 40 | 41 | 42 | def cook_test(test, n=4): 43 | '''Takes a test sentence and returns an object that 44 | encapsulates everything that BLEU needs to know about it. 45 | :param test: list of string : hypothesis sentence for some image 46 | :param n: int : number of ngrams for which (ngram) representation is calculated 47 | :return: result (dict) 48 | ''' 49 | return precook(test, n, True) 50 | 51 | 52 | def compute_doc_freq(crefs): 53 | ''' 54 | Compute term frequency for reference data. 55 | This will be used to compute idf (inverse document frequency later) 56 | The term frequency is stored in the object 57 | ''' 58 | document_frequency = defaultdict(float) 59 | for refs in crefs: 60 | # refs, k ref captions of one image 61 | for ngram in set([ngram for ref in refs for (ngram, count) in ref.items()]): 62 | document_frequency[ngram] += 1 63 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 64 | return document_frequency 65 | 66 | 67 | class CiderScorer(object): 68 | """CIDEr scorer. 69 | """ 70 | 71 | def __init__(self, n=4, sigma=6.0, refs=None): 72 | ''' singular instance ''' 73 | self.n = n 74 | self.sigma = sigma 75 | self.crefs = [] 76 | self.ctest = [] 77 | self.ref_len = None 78 | self.document_frequency = None 79 | if refs: 80 | self.update_df(refs) 81 | 82 | def update_df(self, refs): 83 | crefs = [] 84 | for ref in refs: 85 | # ref is a list of 5 captions 86 | crefs.append(cook_refs(ref)) 87 | self.document_frequency = compute_doc_freq(crefs) 88 | self.ref_len = np.log(float(len(refs))) 89 | 90 | def clear(self): 91 | self.crefs = [] 92 | self.ctest = [] 93 | 94 | def cook_append(self, test, refs): 95 | '''called by constructor and __iadd__ to avoid creating new instances.''' 96 | 97 | if refs is not None: 98 | self.crefs.append(cook_refs(refs)) 99 | if test is not None: 100 | self.ctest.append(cook_test(test)) ## N.B.: -1 101 | else: 102 | self.ctest.append(None) # lens of crefs and ctest have to match 103 | 104 | def size(self): 105 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 106 | return len(self.crefs) 107 | 108 | def __iadd__(self, other): 109 | '''add an instance (e.g., from another sentence).''' 110 | 111 | if type(other) is tuple: 112 | ## avoid creating new CiderScorer instances 113 | self.cook_append(other[0], other[1]) 114 | else: 115 | self.ctest.extend(other.ctest) 116 | self.crefs.extend(other.crefs) 117 | 118 | return self 119 | 120 | def compute_cider(self): 121 | def counts2vec(cnts): 122 | """ 123 | Function maps counts of ngram to vector of tfidf weights. 
124 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 125 | The n-th entry of array denotes length of n-grams. 126 | :param cnts: 127 | :return: vec (array of dict), norm (array of float), length (int) 128 | """ 129 | vec = [defaultdict(float) for _ in range(self.n)] 130 | length = 0 131 | norm = [0.0 for _ in range(self.n)] 132 | for (ngram,term_freq) in cnts.items(): 133 | # give word count 1 if it doesn't appear in reference corpus 134 | df = np.log(max(1.0, self.document_frequency[ngram])) 135 | # ngram index 136 | n = len(ngram)-1 137 | # tf (term_freq) * idf (precomputed idf) for n-grams 138 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 139 | # compute norm for the vector. the norm will be used for computing similarity 140 | norm[n] += pow(vec[n][ngram], 2) 141 | 142 | if n == 1: 143 | length += term_freq 144 | norm = [np.sqrt(n) for n in norm] 145 | return vec, norm, length 146 | 147 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 148 | ''' 149 | Compute the cosine similarity of two vectors. 150 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 151 | :param vec_ref: array of dictionary for vector corresponding to reference 152 | :param norm_hyp: array of float for vector corresponding to hypothesis 153 | :param norm_ref: array of float for vector corresponding to reference 154 | :param length_hyp: int containing length of hypothesis 155 | :param length_ref: int containing length of reference 156 | :return: array of score for each n-grams cosine similarity 157 | ''' 158 | delta = float(length_hyp - length_ref) 159 | # measure consine similarity 160 | val = np.array([0.0 for _ in range(self.n)]) 161 | for n in range(self.n): 162 | # ngram 163 | for (ngram,count) in vec_hyp[n].items(): 164 | # vrama91 : added clipping 165 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 166 | 167 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 168 | val[n] /= (norm_hyp[n]*norm_ref[n]) 169 | 170 | assert(not math.isnan(val[n])) 171 | # vrama91: added a length based gaussian penalty 172 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 173 | return val 174 | 175 | scores = [] 176 | for test, refs in zip(self.ctest, self.crefs): 177 | # compute vector for test captions 178 | vec, norm, length = counts2vec(test) 179 | # compute vector for ref captions 180 | score = np.array([0.0 for _ in range(self.n)]) 181 | for ref in refs: 182 | vec_ref, norm_ref, length_ref = counts2vec(ref) 183 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 184 | # change by vrama91 - mean of ngram scores, instead of sum 185 | score_avg = np.mean(score) 186 | # divide by number of references 187 | score_avg /= len(refs) 188 | # multiply score by 10 189 | score_avg *= 10.0 190 | # append score of an image to the score list 191 | scores.append(score_avg) 192 | return scores 193 | 194 | def compute_score(self): 195 | # compute cider score 196 | score = self.compute_cider() 197 | return np.mean(np.array(score)), np.array(score) 198 | -------------------------------------------------------------------------------- /self_critical/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import tqdm 5 | from collections import defaultdict 6 | 7 | from .cider.pyciderevalcap.ciderD.ciderD import CiderD 8 | from .bleu.bleu import Bleu 9 | 10 | 11 | def _array_to_str(arr, sos_token, 
eos_token): 12 | arr = list(arr) 13 | if arr[0] == sos_token: 14 | arr = arr[1:] 15 | out = '' 16 | for i in range(len(arr)): 17 | if arr[i] == eos_token: 18 | break 19 | out += str(arr[i]) + ' ' 20 | out += str(eos_token) 21 | return out.strip() 22 | 23 | 24 | def _extract_feature(arr, sos_token, eos_token): 25 | arr = list(arr) 26 | if arr[0] == sos_token: 27 | arr = arr[1:] 28 | feature = {} 29 | for i in range(len(arr)): 30 | if arr[i] == eos_token: 31 | break 32 | feature[arr[i]] = True 33 | feature[eos_token] = True 34 | 35 | return feature 36 | 37 | 38 | def get_ciderd_scorer(split_captions, sos_token, eos_token): 39 | print('====> get_ciderd_scorer begin') 40 | captions = {} 41 | for caps in split_captions.values(): 42 | captions.update(caps) 43 | 44 | refs_idxs = [] 45 | for caps in tqdm.tqdm(captions.values()): 46 | ref_idxs = [] 47 | for cap in caps: 48 | ref_idxs.append(_array_to_str(cap, sos_token, eos_token)) 49 | refs_idxs.append(ref_idxs) 50 | 51 | scorer = CiderD(refs=refs_idxs) 52 | print('====> get_ciderd_scorer end') 53 | return scorer 54 | 55 | 56 | def get_self_critical_reward(sample_captions, greedy_captions, fns, ground_truth, 57 | sos_token, eos_token, scorer): 58 | batch_size = len(fns) 59 | sample_captions = sample_captions.cpu().numpy() 60 | greedy_captions = greedy_captions.cpu().numpy() 61 | assert sample_captions.shape[0] == greedy_captions.shape[0] == batch_size 62 | sample_result = [] 63 | greedy_result = [] 64 | gts = {} 65 | for i, fn in enumerate(fns): 66 | sample_result.append({'image_id': fn, 'caption': [_array_to_str(sample_captions[i], sos_token, eos_token)]}) 67 | greedy_result.append({'image_id': fn, 'caption': [_array_to_str(greedy_captions[i], sos_token, eos_token)]}) 68 | caps = [] 69 | for cap in ground_truth[fn]: 70 | caps.append(_array_to_str(cap, sos_token, eos_token)) 71 | gts[fn] = caps 72 | all_result = sample_result + greedy_result 73 | if isinstance(scorer, CiderD): 74 | _, scores = scorer.compute_score(gts, all_result) 75 | elif isinstance(scorer, Bleu): 76 | _, scores = scorer.compute_score(gts, all_result) 77 | scores = np.array(scores[3]) 78 | else: 79 | raise Exception('do not support this scorer: %s' % type(scorer)) 80 | 81 | scores = scores[:batch_size] - scores[batch_size:] 82 | rewards = np.repeat(scores[:, np.newaxis], sample_captions.shape[1], 1) 83 | return rewards 84 | 85 | 86 | def get_lm_reward(sample_captions, greedy_captions, senti_labels, sos_token, eos_token, lms): 87 | batch_size = sample_captions.size(0) 88 | sample_captions = sample_captions.cpu().numpy() 89 | greedy_captions = greedy_captions.cpu().numpy() 90 | senti_labels = senti_labels.cpu().numpy() 91 | scores = [] 92 | for i in range(batch_size): 93 | sample_res = _array_to_str(sample_captions[i], sos_token, eos_token) 94 | greedy_res = _array_to_str(greedy_captions[i], sos_token, eos_token) 95 | senti_lm = lms[senti_labels[i]] 96 | scores.append(np.sign(senti_lm.score(greedy_res) - senti_lm.score(sample_res))) 97 | # scores.append(senti_lm.perplexity(greedy_res) - senti_lm.perplexity(sample_res)) 98 | scores = np.array(scores) 99 | rewards = np.repeat(scores[:, np.newaxis], sample_captions.shape[1], 1) 100 | return rewards 101 | 102 | 103 | # def get_cls_reward(sample_captions, greedy_captions, senti_labels, sos_token, eos_token, sent_senti_cls): 104 | # batch_size = sample_captions.size(0) 105 | # sample_captions = sample_captions.cpu().numpy() 106 | # greedy_captions = greedy_captions.cpu().numpy() 107 | # senti_labels = senti_labels.cpu().numpy() 108 | # 
scores = [] 109 | # for i in range(batch_size): 110 | # sample_feat = _extract_feature(sample_captions[i], sos_token, eos_token) 111 | # greedy_feat = _extract_feature(greedy_captions[i], sos_token, eos_token) 112 | # prob_rests = sent_senti_cls.prob_classify_many([sample_feat, greedy_feat]) 113 | # scores.append(prob_rests[0]._prob_dict[senti_labels[i]] - 114 | # prob_rests[1]._prob_dict[senti_labels[i]]) 115 | # scores = np.array(scores) 116 | # rewards = np.repeat(scores[:, np.newaxis], sample_captions.shape[1], 1) 117 | # return rewards 118 | 119 | 120 | def get_cls_reward(sample_captions, sample_masks, greedy_captions, greedy_masks, senti_labels, sent_senti_cls): 121 | training = sent_senti_cls.training 122 | sample_lens = list(sample_masks.sum(dim=-1).type(torch.int).cpu().numpy()) 123 | # greedy_lens = list(greedy_masks.sum(dim=-1).type(torch.int).cpu().numpy()) 124 | sent_senti_cls.eval() 125 | with torch.no_grad(): 126 | sample_preds, sample_att_weights = sent_senti_cls(sample_captions, sample_lens) 127 | sample_preds = sample_preds.softmax(dim=-1) 128 | sample_preds = sample_preds.argmax(dim=-1) 129 | sample_preds = (sample_preds == senti_labels).type_as(sample_att_weights) 130 | sample_preds = sample_preds.unsqueeze(1) 131 | sample_scores = sample_preds * sample_att_weights 132 | sample_scores = sample_scores.detach().cpu().numpy() 133 | 134 | # greedy_preds, greedy_att_weights = sent_senti_cls(greedy_captions, greedy_lens) 135 | # greedy_preds = greedy_preds.softmax(dim=-1) 136 | # greedy_preds = greedy_preds.argmax(dim=-1) 137 | # greedy_preds = (greedy_preds == senti_labels).type_as(greedy_att_weights) 138 | # greedy_preds = greedy_preds.unsqueeze(1) 139 | # greedy_scores = greedy_preds * greedy_att_weights 140 | # greedy_scores = greedy_scores.detach().cpu().numpy() 141 | sent_senti_cls.train(training) 142 | 143 | max_len = sample_captions.shape[1] 144 | sample_scores = np.pad(sample_scores, ((0, 0), (0, max_len-sample_scores.shape[1]))) 145 | # greedy_scores = greedy_scores[:, :max_len] 146 | # greedy_scores = np.pad(greedy_scores, ((0, 0), (0, max_len - greedy_scores.shape[1]))) 147 | 148 | # rewards = sample_scores - greedy_scores 149 | # rewards = sample_scores - sample_scores.mean(-1).reshape(sample_scores.shape[0], 1) 150 | rewards = sample_scores 151 | return rewards 152 | 153 | 154 | def get_senti_words_reward(sample_captions, senti_labels, sentiment_words): 155 | batch_size = sample_captions.size(0) 156 | sample_captions = sample_captions.cpu().numpy() 157 | rewards = np.zeros(sample_captions.shape, dtype=float) 158 | accur_w = defaultdict(set) 159 | for i in range(batch_size): 160 | senti_id = int(senti_labels[i]) 161 | for j, w in enumerate(sample_captions[i]): 162 | if w in sentiment_words[senti_id]: 163 | rewards[i, j] = sentiment_words[senti_id][w] 164 | accur_w[senti_id].add(w) 165 | 166 | return rewards, accur_w 167 | 168 | 169 | class RewardCriterion(nn.Module): 170 | def __init__(self): 171 | super(RewardCriterion, self).__init__() 172 | 173 | def forward(self, seq_logprobs, seq_masks, reward): 174 | output = - seq_logprobs * seq_masks * reward 175 | output = torch.sum(output) / torch.sum(seq_masks) 176 | 177 | return output 178 | -------------------------------------------------------------------------------- /self_critical/utils_bac.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import tqdm 5 | 6 | from .cider.pyciderevalcap.ciderD.ciderD import 
CiderD 7 | from .bleu.bleu import Bleu 8 | 9 | 10 | def _array_to_str(arr, sos_token, eos_token): 11 | arr = list(arr) 12 | if arr[0] == sos_token: 13 | arr = arr[1:] 14 | out = '' 15 | for i in range(len(arr)): 16 | if arr[i] == eos_token: 17 | break 18 | out += str(arr[i]) + ' ' 19 | out += str(eos_token) 20 | return out.strip() 21 | 22 | 23 | def _extract_feature(arr, sos_token, eos_token): 24 | arr = list(arr) 25 | if arr[0] == sos_token: 26 | arr = arr[1:] 27 | feature = {} 28 | for i in range(len(arr)): 29 | if arr[i] == eos_token: 30 | break 31 | feature[arr[i]] = True 32 | feature[eos_token] = True 33 | 34 | return feature 35 | 36 | 37 | def get_ciderd_scorer(split_captions, sos_token, eos_token): 38 | print('====> get_ciderd_scorer begin') 39 | captions = {} 40 | for caps in split_captions.values(): 41 | captions.update(caps) 42 | 43 | refs_idxs = [] 44 | for caps in tqdm.tqdm(captions.values()): 45 | ref_idxs = [] 46 | for cap in caps: 47 | ref_idxs.append(_array_to_str(cap, sos_token, eos_token)) 48 | refs_idxs.append(ref_idxs) 49 | 50 | scorer = CiderD(refs=refs_idxs) 51 | print('====> get_ciderd_scorer end') 52 | return scorer 53 | 54 | 55 | def get_self_critical_reward(sample_captions, greedy_captions, fns, ground_truth, 56 | sos_token, eos_token, scorer): 57 | batch_size = len(fns) 58 | sample_captions = sample_captions.cpu().numpy() 59 | greedy_captions = greedy_captions.cpu().numpy() 60 | assert sample_captions.shape[0] == greedy_captions.shape[0] == batch_size 61 | sample_result = [] 62 | greedy_result = [] 63 | gts = {} 64 | for i, fn in enumerate(fns): 65 | sample_result.append({'image_id': fn, 'caption': [_array_to_str(sample_captions[i], sos_token, eos_token)]}) 66 | greedy_result.append({'image_id': fn, 'caption': [_array_to_str(greedy_captions[i], sos_token, eos_token)]}) 67 | caps = [] 68 | for cap in ground_truth[fn]: 69 | caps.append(_array_to_str(cap, sos_token, eos_token)) 70 | gts[fn] = caps 71 | all_result = sample_result + greedy_result 72 | _, scores = scorer.compute_score(gts, all_result) 73 | if isinstance(scorer, CiderD): 74 | _, scores = scorer.compute_score(gts, all_result) 75 | elif isinstance(scorer, Bleu): 76 | _, scores = scorer.compute_score(gts, all_result) 77 | scores = np.array(scores[3]) 78 | else: 79 | raise Exception('do not support this scorer: %s' % type(scorer)) 80 | 81 | scores = scores[:batch_size] - scores[batch_size:] 82 | rewards = np.repeat(scores[:, np.newaxis], sample_captions.shape[1], 1) 83 | return rewards 84 | 85 | 86 | def get_lm_reward(sample_captions, greedy_captions, senti_labels, sos_token, eos_token, lms): 87 | batch_size = sample_captions.size(0) 88 | sample_captions = sample_captions.cpu().numpy() 89 | # greedy_captions = greedy_captions.cpu().numpy() 90 | senti_labels = senti_labels.cpu().numpy() 91 | scores = [] 92 | for i in range(batch_size): 93 | sample_res = _array_to_str(sample_captions[i], sos_token, eos_token) 94 | # greedy_res = _array_to_str(greedy_captions[i], sos_token, eos_token) 95 | senti_lm = lms[senti_labels[i]] 96 | # scores.append(np.sign(senti_lm.score(greedy_res) - senti_lm.score(sample_res))) 97 | scores.append(- senti_lm.score(sample_res)) 98 | scores = np.array(scores) 99 | scores = scores - scores.mean() 100 | rewards = np.repeat(scores[:, np.newaxis], sample_captions.shape[1], 1) 101 | return rewards 102 | 103 | 104 | # def get_cls_reward(sample_captions, greedy_captions, senti_labels, sos_token, eos_token, sent_senti_cls): 105 | # batch_size = sample_captions.size(0) 106 | # sample_captions 
= sample_captions.cpu().numpy() 107 | # greedy_captions = greedy_captions.cpu().numpy() 108 | # senti_labels = senti_labels.cpu().numpy() 109 | # scores = [] 110 | # for i in range(batch_size): 111 | # sample_feat = _extract_feature(sample_captions[i], sos_token, eos_token) 112 | # greedy_feat = _extract_feature(greedy_captions[i], sos_token, eos_token) 113 | # prob_rests = sent_senti_cls.prob_classify_many([sample_feat, greedy_feat]) 114 | # scores.append(prob_rests[0]._prob_dict[senti_labels[i]] - 115 | # prob_rests[1]._prob_dict[senti_labels[i]]) 116 | # scores = np.array(scores) 117 | # rewards = np.repeat(scores[:, np.newaxis], sample_captions.shape[1], 1) 118 | # return rewards 119 | 120 | 121 | def get_cls_reward(sample_captions, sample_masks, greedy_captions, greedy_masks, senti_labels, sent_senti_cls): 122 | batch_size = sample_captions.size(0) 123 | sample_lens = list(sample_masks.sum(dim=-1).type(torch.int).cpu().numpy()) 124 | # greedy_lens = list(greedy_masks.sum(dim=-1).type(torch.int).cpu().numpy()) 125 | with torch.no_grad(): 126 | sample_preds = sent_senti_cls(sample_captions, sample_lens) 127 | # greedy_preds = sent_senti_cls(greedy_captions, greedy_lens) 128 | 129 | scores = [] 130 | for i in range(batch_size): 131 | senti_id = senti_labels[i] 132 | # scores.append(float(sample_preds[i][senti_id] - greedy_preds[i][senti_id])) 133 | scores.append(float(sample_preds[i][senti_id])) 134 | scores = np.array(scores) 135 | scores = scores - scores.mean() 136 | rewards = np.repeat(scores[:, np.newaxis], sample_captions.shape[1], 1) 137 | return rewards 138 | 139 | 140 | def get_senti_words_reward(sample_captions, senti_words): 141 | batch_size = sample_captions.size(0) 142 | sample_captions = sample_captions.cpu().numpy() 143 | senti_words = senti_words.cpu().numpy() 144 | rewards = np.zeros(sample_captions.shape, dtype=float) 145 | for i in range(batch_size): 146 | for j, w in enumerate(sample_captions[i]): 147 | if w in senti_words[i]: 148 | rewards[i, j] = 1 149 | return rewards 150 | 151 | 152 | class RewardCriterion(nn.Module): 153 | def __init__(self): 154 | super(RewardCriterion, self).__init__() 155 | 156 | def forward(self, seq_logprobs, seq_masks, reward): 157 | output = - seq_logprobs * seq_masks * reward 158 | output = torch.sum(output) / torch.sum(seq_masks) 159 | 160 | return output 161 | -------------------------------------------------------------------------------- /test_cpt.py: -------------------------------------------------------------------------------- 1 | # coding:utf8 2 | import torch 3 | import h5py 4 | import json 5 | import os 6 | 7 | from opts import parse_opt 8 | from models.concept_detector import ConceptDetector 9 | 10 | opt = parse_opt() 11 | print("====> loading checkpoint '{}'".format(opt.test_model)) 12 | chkpoint = torch.load(opt.test_model, map_location=lambda s, l: s) 13 | idx2concept = chkpoint['idx2concept'] 14 | settings = chkpoint['settings'] 15 | dataset_name = chkpoint['dataset_name'] 16 | model = ConceptDetector(idx2concept, settings) 17 | model.to(opt.device) 18 | model.load_state_dict(chkpoint['model']) 19 | model.eval() 20 | print("====> loaded checkpoint '{}', epoch: {}, dataset_name: {}". 
21 | format(opt.test_model, chkpoint['epoch'], dataset_name)) 22 | 23 | img_concepts = json.load(open(os.path.join(opt.captions_dir, dataset_name, 'img_concepts.json'), 'r')) 24 | f_fc = h5py.File(os.path.join(opt.feats_dir, dataset_name, '%s_fc.h5' % dataset_name), 'r') 25 | test_img = opt.image_file or list(img_concepts['test'].keys())[0] 26 | feat = torch.FloatTensor(f_fc[test_img][:]) 27 | feat = feat.to(opt.device) 28 | feat = feat.unsqueeze(0) 29 | _, concepts, scores = model.sample(feat, num=opt.num_concepts) 30 | concepts = concepts[0] 31 | scores = scores[0] 32 | 33 | print('test_img: ', test_img) 34 | print('concepts: ', concepts) 35 | print('scores: ', scores) 36 | print('ground truth: ', img_concepts['test'][test_img]) 37 | 38 | wrong = [] 39 | for c in concepts: 40 | if c not in img_concepts['test'][test_img]: 41 | wrong.append(c) 42 | print('\nwrong rate:', len(wrong) / len(concepts)) 43 | print('wrong concepts:', wrong) 44 | -------------------------------------------------------------------------------- /train_cpt.py: -------------------------------------------------------------------------------- 1 | # coding:utf8 2 | import tqdm 3 | import os 4 | from copy import deepcopy 5 | import time 6 | import json 7 | import sys 8 | import pdb 9 | import traceback 10 | from bdb import BdbQuit 11 | import torch 12 | 13 | from opts import parse_opt 14 | from models.concept_detector import ConceptDetector 15 | from dataloader import get_concept_dataloader 16 | 17 | 18 | def clip_gradient(optimizer, grad_clip): 19 | for group in optimizer.param_groups: 20 | for param in group['params']: 21 | param.grad.data.clamp_(-grad_clip, grad_clip) 22 | 23 | 24 | def train(): 25 | dataset_name = opt.dataset_name 26 | 27 | idx2concept = json.load(open(os.path.join(opt.captions_dir, dataset_name, 'idx2concept.json'), 'r')) 28 | img_concepts = json.load(open(os.path.join(opt.captions_dir, dataset_name, 'img_concepts.json'), 'r')) 29 | 30 | cpt_detector = ConceptDetector(idx2concept, opt.settings) 31 | cpt_detector.to(opt.device) 32 | lr = opt.concept_lr 33 | optimizer, criterion = cpt_detector.get_optim_criterion(lr) 34 | if opt.concept_resume: 35 | print("====> loading checkpoint '{}'".format(opt.concept_resume)) 36 | chkpoint = torch.load(opt.concept_resume, map_location=lambda s, l: s) 37 | assert opt.settings == chkpoint['settings'], \ 38 | 'opt.settings and resume model settings are different' 39 | assert idx2concept == chkpoint['idx2concept'], \ 40 | 'idx2concept and resume model idx2concept are different' 41 | assert dataset_name == chkpoint['dataset_name'], \ 42 | 'dataset_name and resume model dataset_name are different' 43 | cpt_detector.load_state_dict(chkpoint['model']) 44 | optimizer.load_state_dict(chkpoint['optimizer']) 45 | lr = optimizer.param_groups[0]['lr'] 46 | print("====> loaded checkpoint '{}', epoch: {}" 47 | .format(opt.concept_resume, chkpoint['epoch'])) 48 | 49 | concept2idx = {} 50 | for i, w in enumerate(idx2concept): 51 | concept2idx[w] = i 52 | 53 | ground_truth = deepcopy(img_concepts['test']) 54 | print('====> process image concepts begin') 55 | img_concepts_id = {} 56 | for split, concepts in img_concepts.items(): 57 | print('convert %s concepts to index' % split) 58 | img_concepts_id[split] = {} 59 | for fn, cpts in tqdm.tqdm(concepts.items()): 60 | cpts = [concept2idx[c] for c in cpts if c in concept2idx] 61 | img_concepts_id[split][fn] = cpts 62 | img_concepts = img_concepts_id 63 | print('====> process image concepts end') 64 | 65 | f_fc = 
os.path.join(opt.feats_dir, dataset_name, '%s_fc.h5' % dataset_name) 66 | train_data = get_concept_dataloader( 67 | f_fc, img_concepts['train'], len(idx2concept), 68 | opt.concept_bs, opt.concept_num_works) 69 | val_data = get_concept_dataloader( 70 | f_fc, img_concepts['val'], len(idx2concept), 71 | opt.concept_bs, opt.concept_num_works, shuffle=False) 72 | test_data = get_concept_dataloader( 73 | f_fc, img_concepts['test'], len(idx2concept), 74 | opt.concept_bs, opt.concept_num_works, shuffle=False) 75 | 76 | def forward(data, training=True): 77 | cpt_detector.train(training) 78 | loss_val = 0.0 79 | for _, fc_feats, cpts_tensors in tqdm.tqdm(data): 80 | fc_feats = fc_feats.to(opt.device) 81 | cpts_tensors = cpts_tensors.to(opt.device) 82 | pred = cpt_detector(fc_feats) 83 | loss = criterion(pred, cpts_tensors) 84 | loss_val += loss.item() 85 | if training: 86 | optimizer.zero_grad() 87 | loss.backward() 88 | clip_gradient(optimizer, opt.grad_clip) 89 | optimizer.step() 90 | return loss_val / len(data) 91 | 92 | checkpoint = os.path.join(opt.checkpoint, 'concept', dataset_name) 93 | if not os.path.exists(checkpoint): 94 | os.makedirs(checkpoint) 95 | previous_loss = None 96 | for epoch in range(opt.concept_epochs): 97 | print('--------------------epoch: %d' % epoch) 98 | train_loss = forward(train_data) 99 | with torch.no_grad(): 100 | val_loss = forward(val_data, training=False) 101 | 102 | # test 103 | test_loss = 0.0 104 | pre = 0.0 105 | recall = 0.0 106 | last_score = 0.0 107 | for fns, fc_feats, cpts_tensors in tqdm.tqdm(test_data): 108 | fc_feats = fc_feats.to(opt.device) 109 | cpts_tensors = cpts_tensors.to(opt.device) 110 | pred, concepts, scores = cpt_detector.sample(fc_feats, num=opt.num_concepts) 111 | loss = criterion(pred, cpts_tensors) 112 | test_loss += loss.item() 113 | tmp_pre = 0.0 114 | tmp_rec = 0.0 115 | for i, fn in enumerate(fns): 116 | cpts = concepts[i] 117 | grdt = ground_truth[fn] 118 | jiaoji = len(set(grdt) - (set(grdt) - set(cpts))) 119 | tmp_pre += jiaoji / len(cpts) 120 | tmp_rec += jiaoji / len(grdt) 121 | pre += tmp_pre / len(fns) 122 | recall += tmp_rec / len(fns) 123 | last_score += float(scores[:, -1].mean()) 124 | data_len = len(test_data) 125 | test_loss = test_loss / data_len 126 | pre = pre / data_len 127 | recall = recall / data_len 128 | last_score = last_score / data_len 129 | 130 | if previous_loss is not None and val_loss > previous_loss: 131 | lr = lr * 0.5 132 | for param_group in optimizer.param_groups: 133 | param_group['lr'] = lr 134 | previous_loss = val_loss 135 | 136 | print('train_loss: %.4f, val_loss: %.4f, test_loss: %.4f, ' 137 | 'precision: %.4f, recall: %.4f, last_score: %.4f' % 138 | (train_loss, val_loss, test_loss, pre, recall, last_score)) 139 | if epoch > -1: 140 | chkpoint = { 141 | 'epoch': epoch, 142 | 'model': cpt_detector.state_dict(), 143 | 'optimizer': optimizer.state_dict(), 144 | 'settings': opt.settings, 145 | 'idx2concept': idx2concept, 146 | 'dataset_name': dataset_name, 147 | } 148 | checkpoint_path = os.path.join(checkpoint, 'model_%d_%.4f_%.4f_%s.pth' % ( 149 | epoch, train_loss, val_loss, time.strftime('%m%d-%H%M'))) 150 | torch.save(chkpoint, checkpoint_path) 151 | 152 | 153 | if __name__ == '__main__': 154 | try: 155 | opt = parse_opt() 156 | train() 157 | except BdbQuit: 158 | sys.exit(1) 159 | except Exception: 160 | traceback.print_exc() 161 | print('') 162 | pdb.post_mortem() 163 | sys.exit(1) 164 | -------------------------------------------------------------------------------- /train_rl.py: 
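Note: the next file, train_rl.py, is the self-critical (RL) fine-tuning stage. The reward helpers defined earlier in self_critical/utils.py (get_self_critical_reward, get_cls_reward, get_senti_words_reward, RewardCriterion) are consumed by the Detector model that train_rl.py trains, which receives the CIDEr-D scorer and sentiment words via model.set_ciderd_scorer(...) and model.set_sentiment_words(...). The following is a minimal illustrative sketch, not part of the repository, of how an SCST-style advantage and RewardCriterion typically combine; the function name scst_step_sketch and the tensor shapes are assumptions for illustration only.

# NOTE: illustrative sketch only (not from the repo); names/shapes are assumptions.
import numpy as np
import torch

from self_critical.utils import RewardCriterion  # defined above in this repo


def scst_step_sketch(seq_logprobs, seq_masks, sample_scores, greedy_scores):
    """seq_logprobs/seq_masks: (batch, T) tensors for the sampled captions.
    sample_scores/greedy_scores: 1-D numpy arrays of sentence-level scores
    (e.g. CIDEr-D from CiderD.compute_score) for sampled vs. greedy captions.
    Returns the SCST policy-gradient loss."""
    # advantage = reward of the sampled caption minus the greedy baseline
    advantage = sample_scores - greedy_scores                       # (batch,)
    # broadcast the sentence-level advantage to every time step, mirroring
    # get_self_critical_reward's np.repeat over the time dimension
    rewards = np.repeat(advantage[:, np.newaxis], seq_logprobs.size(1), 1)
    rewards = torch.from_numpy(rewards).type_as(seq_logprobs)       # (batch, T)
    # RewardCriterion: -(log p * mask * reward), averaged over valid tokens
    return RewardCriterion()(seq_logprobs, seq_masks, rewards)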
-------------------------------------------------------------------------------- 1 | # coding:utf8 2 | import tqdm 3 | import os 4 | import time 5 | from collections import defaultdict 6 | import json 7 | import sys 8 | import pdb 9 | import traceback 10 | from bdb import BdbQuit 11 | import torch 12 | import random 13 | 14 | from opts import parse_opt 15 | from models.decoder import Detector 16 | from dataloader import get_rl_fact_dataloader, get_rl_senti_dataloader, get_senti_corpus_with_sentis_dataloader 17 | 18 | 19 | def clip_gradient(optimizer, grad_clip): 20 | for group in optimizer.param_groups: 21 | for param in group['params']: 22 | param.grad.data.clamp_(-grad_clip, grad_clip) 23 | 24 | 25 | def train(): 26 | dataset_name = opt.dataset_name 27 | corpus_type = opt.corpus_type 28 | 29 | idx2word = json.load(open(os.path.join(opt.captions_dir, dataset_name, corpus_type, 'idx2word.json'), 'r')) 30 | img_captions = json.load(open(os.path.join(opt.captions_dir, dataset_name, 'img_captions.json'), 'r')) 31 | img_det_concepts = json.load(open(os.path.join(opt.captions_dir, dataset_name, 'img_det_concepts.json'), 'r')) 32 | img_det_sentiments = json.load(open(os.path.join(opt.captions_dir, dataset_name, corpus_type, 'img_det_sentiments.json'), 'r')) 33 | img_senti_labels = json.load(open(opt.img_senti_labels, 'r')) 34 | senti_captions = json.load(open(os.path.join(opt.captions_dir, dataset_name, corpus_type, 'senti_captions.json'), 'r')) 35 | sentiment_words = json.load(open(os.path.join(opt.corpus_dir, corpus_type, 'sentiment_words.json'), 'r')) 36 | 37 | model = Detector(idx2word, opt.max_seq_len, opt.sentiment_categories, opt.rl_lrs, opt.settings) 38 | model.to(opt.device) 39 | if opt.rl_resume: 40 | print("====> loading checkpoint '{}'".format(opt.rl_resume)) 41 | chkpoint = torch.load(opt.rl_resume, map_location=lambda s, l: s) 42 | assert opt.settings == chkpoint['settings'], \ 43 | 'opt.settings and resume model settings are different' 44 | assert idx2word == chkpoint['idx2word'], \ 45 | 'idx2word and resume model idx2word are different' 46 | assert opt.max_seq_len == chkpoint['max_seq_len'], \ 47 | 'opt.max_seq_len and resume model max_seq_len are different' 48 | assert opt.sentiment_categories == chkpoint['sentiment_categories'], \ 49 | 'opt.sentiment_categories and resume model sentiment_categories are different' 50 | assert dataset_name == chkpoint['dataset_name'], \ 51 | 'dataset_name and resume model dataset_name are different' 52 | assert corpus_type == chkpoint['corpus_type'], \ 53 | 'corpus_type and resume model corpus_type are different' 54 | model.load_state_dict(chkpoint['model']) 55 | print("====> loaded checkpoint '{}', epoch: {}" 56 | .format(opt.rl_resume, chkpoint['epoch'])) 57 | else: 58 | rl_xe_resume = os.path.join(opt.checkpoint, 'xe', dataset_name, corpus_type, 'model-best.pth') 59 | print("====> loading checkpoint '{}'".format(rl_xe_resume)) 60 | chkpoint = torch.load(rl_xe_resume, map_location=lambda s, l: s) 61 | assert opt.settings == chkpoint['settings'], \ 62 | 'opt.settings and resume model settings are different' 63 | assert idx2word == chkpoint['idx2word'], \ 64 | 'idx2word and resume model idx2word are different' 65 | assert opt.sentiment_categories == chkpoint['sentiment_categories'], \ 66 | 'opt.sentiment_categories and resume model sentiment_categories are different' 67 | assert dataset_name == chkpoint['dataset_name'], \ 68 | 'dataset_name and resume model dataset_name are different' 69 | assert corpus_type == chkpoint['corpus_type'], \ 70 | 
'corpus_type and resume model corpus_type are different' 71 | model.captioner.load_state_dict(chkpoint['model']) 72 | print("====> loaded checkpoint '{}', epoch: {}" 73 | .format(rl_xe_resume, chkpoint['epoch'])) 74 | 75 | if opt.rl_senti_resume: 76 | print("====> loading rl_senti_resume '{}'".format(opt.rl_senti_resume)) 77 | ch = torch.load(opt.rl_senti_resume, map_location=lambda s, l: s) 78 | assert opt.settings == ch['settings'], \ 79 | 'opt.settings and rl_senti_resume settings are different' 80 | assert opt.sentiment_categories == ch['sentiment_categories'], \ 81 | 'opt.sentiment_categories and rl_senti_resume sentiment_categories are different' 82 | model.senti_detector.load_state_dict(ch['model']) 83 | 84 | if True: 85 | ss_cls_file = os.path.join(opt.checkpoint, 'sent_senti_cls', dataset_name, corpus_type, 'model-best.pth') 86 | print("====> loading checkpoint '{}'".format(ss_cls_file)) 87 | chkpoint = torch.load(ss_cls_file, map_location=lambda s, l: s) 88 | assert opt.settings == chkpoint['settings'], \ 89 | 'opt.settings and resume model settings are different' 90 | assert idx2word == chkpoint['idx2word'], \ 91 | 'idx2word and resume model idx2word are different' 92 | assert opt.sentiment_categories == chkpoint['sentiment_categories'], \ 93 | 'opt.sentiment_categories and resume model sentiment_categories are different' 94 | assert dataset_name == chkpoint['dataset_name'], \ 95 | 'dataset_name and resume model dataset_name are different' 96 | assert corpus_type == chkpoint['corpus_type'], \ 97 | 'corpus_type and resume model corpus_type are different' 98 | model.sent_senti_cls.load_state_dict(chkpoint['model']) 99 | 100 | word2idx = {} 101 | for i, w in enumerate(idx2word): 102 | word2idx[w] = i 103 | 104 | print('====> process image captions begin') 105 | captions_id = {} 106 | for split, caps in img_captions.items(): 107 | print('convert %s captions to index' % split) 108 | captions_id[split] = {} 109 | for fn, seqs in tqdm.tqdm(caps.items()): 110 | tmp = [] 111 | for seq in seqs: 112 | tmp.append([model.captioner.sos_id] + 113 | [word2idx.get(w, None) or word2idx[''] for w in seq] + 114 | [model.captioner.eos_id]) 115 | captions_id[split][fn] = tmp 116 | img_captions = captions_id 117 | print('====> process image captions end') 118 | 119 | print('====> process image det_concepts begin') 120 | det_concepts_id = {} 121 | for fn, cpts in tqdm.tqdm(img_det_concepts.items()): 122 | det_concepts_id[fn] = [word2idx[w] for w in cpts] 123 | img_det_concepts = det_concepts_id 124 | print('====> process image det_concepts end') 125 | 126 | print('====> process image det_sentiments begin') 127 | det_sentiments_id = {} 128 | for fn, sentis in tqdm.tqdm(img_det_sentiments.items()): 129 | det_sentiments_id[fn] = [word2idx[w] for w in sentis] 130 | img_det_sentiments = det_sentiments_id 131 | print('====> process image det_concepts end') 132 | 133 | senti_label2idx = {} 134 | for i, w in enumerate(opt.sentiment_categories): 135 | senti_label2idx[w] = i 136 | print('====> process image senti_labels begin') 137 | senti_labels_id = {} 138 | for split, senti_labels in img_senti_labels.items(): 139 | print('convert %s senti_labels to index' % split) 140 | senti_labels_id[split] = [] 141 | for fn, senti_label in tqdm.tqdm(senti_labels): 142 | senti_labels_id[split].append([fn, senti_label2idx[senti_label]]) 143 | img_senti_labels = senti_labels_id 144 | print('====> process image senti_labels end') 145 | 146 | print('====> process senti corpus begin') 147 | senti_captions['positive'] = 
senti_captions['positive'] * int(len(senti_captions['neutral']) / len(senti_captions['positive'])) 148 | senti_captions['negative'] = senti_captions['negative'] * int(len(senti_captions['neutral']) / len(senti_captions['negative'])) 149 | senti_captions_id = [] 150 | for senti, caps in senti_captions.items(): 151 | print('convert %s corpus to index' % senti) 152 | senti_id = senti_label2idx[senti] 153 | for cap, cpts, sentis in tqdm.tqdm(caps): 154 | cap = [model.captioner.sos_id] +\ 155 | [word2idx.get(w, None) or word2idx[''] for w in cap] +\ 156 | [model.captioner.eos_id] 157 | cpts = [word2idx[w] for w in cpts if w in word2idx] 158 | sentis = [word2idx[w] for w in sentis] 159 | senti_captions_id.append([cap, cpts, sentis, senti_id]) 160 | random.shuffle(senti_captions_id) 161 | senti_captions = senti_captions_id 162 | print('====> process senti corpus end') 163 | 164 | print('====> process sentiment words begin') 165 | tmp_sentiment_words = {} 166 | for senti in opt.sentiment_categories: 167 | senti_id = senti_label2idx[senti] 168 | if senti not in sentiment_words: 169 | tmp_sentiment_words[senti_id] = dict() 170 | else: 171 | tmp_sentiment_words[senti_id] = {word2idx[w]: 1.0 for w, s in sentiment_words[senti].items()} 172 | sentiment_words = tmp_sentiment_words 173 | print('====> process sentiment words end') 174 | 175 | fc_feats = os.path.join(opt.feats_dir, dataset_name, '%s_fc.h5' % dataset_name) 176 | att_feats = os.path.join(opt.feats_dir, dataset_name, '%s_att.h5' % dataset_name) 177 | fact_train_data = get_rl_fact_dataloader( 178 | fc_feats, att_feats, img_captions['train'], img_det_concepts, 179 | img_det_sentiments, model.captioner.pad_id, opt.max_seq_len, 180 | opt.num_concepts, opt.num_sentiments, opt.rl_bs, opt.rl_num_works) 181 | fact_val_data = get_rl_fact_dataloader( 182 | fc_feats, att_feats, img_captions['val'], img_det_concepts, 183 | img_det_sentiments, model.captioner.pad_id, opt.max_seq_len, 184 | opt.num_concepts, opt.num_sentiments, opt.rl_bs, opt.rl_num_works, shuffle=False) 185 | test_captions = {} 186 | for fn in img_captions['test']: 187 | test_captions[fn] = [[]] 188 | fact_test_data = get_rl_fact_dataloader( 189 | fc_feats, att_feats, test_captions, img_det_concepts, 190 | img_det_sentiments, model.captioner.pad_id, opt.max_seq_len, 191 | opt.num_concepts, opt.num_sentiments, opt.rl_bs, opt.rl_num_works, shuffle=False) 192 | 193 | senti_fc_feats = os.path.join(opt.feats_dir, 'sentiment', 'feats_fc.h5') 194 | senti_att_feats = os.path.join(opt.feats_dir, 'sentiment', 'feats_att.h5') 195 | senti_train_data = get_rl_senti_dataloader( 196 | senti_fc_feats, senti_att_feats, img_det_concepts, 197 | img_det_sentiments, img_senti_labels['train'], model.captioner.pad_id, 198 | opt.num_concepts, opt.num_sentiments, opt.rl_bs, opt.rl_num_works) 199 | senti_val_data = get_rl_senti_dataloader( 200 | senti_fc_feats, senti_att_feats, img_det_concepts, 201 | img_det_sentiments, img_senti_labels['val'], model.captioner.pad_id, 202 | opt.num_concepts, opt.num_sentiments, opt.rl_bs, opt.rl_num_works, shuffle=False) 203 | senti_test_data = get_rl_senti_dataloader( 204 | senti_fc_feats, senti_att_feats, img_det_concepts, 205 | img_det_sentiments, img_senti_labels['test'], model.captioner.pad_id, 206 | opt.num_concepts, opt.num_sentiments, opt.rl_bs, opt.rl_num_works, shuffle=False) 207 | 208 | scs_data = get_senti_corpus_with_sentis_dataloader( 209 | senti_captions, idx2word.index(''), opt.max_seq_len, 210 | opt.num_concepts, opt.num_sentiments, 80, opt.rl_num_works) 211 | 212 
| # lms = {} 213 | # lm_dir = os.path.join(opt.captions_dir, dataset_name, corpus_type, 'lm') 214 | # for senti, i in senti_label2idx.items(): 215 | # lms[i] = kenlm.LanguageModel(os.path.join(lm_dir, '%s_id.kenlm.arpa' % senti)) 216 | # model.set_lms(lms) 217 | 218 | model.set_ciderd_scorer(img_captions) 219 | model.set_sentiment_words(sentiment_words) 220 | 221 | tmp_dir = '' 222 | checkpoint = os.path.join(opt.checkpoint, 'rl', dataset_name, corpus_type, tmp_dir) 223 | if not os.path.exists(checkpoint): 224 | os.makedirs(checkpoint) 225 | result_dir = os.path.join(opt.result_dir, 'rl', dataset_name, corpus_type, tmp_dir) 226 | if not os.path.exists(result_dir): 227 | os.makedirs(result_dir) 228 | for epoch in range(opt.rl_epochs): 229 | print('--------------------epoch: %d' % epoch) 230 | print('tmp_dir:', tmp_dir, 'cls_flag:', model.cls_flag, 'seq_flag:', model.seq_flag) 231 | torch.cuda.empty_cache() 232 | for i in range(opt.rl_senti_times): 233 | print('----------rl_senti_times: %d' % i) 234 | senti_train_loss = model((senti_train_data, scs_data), data_type='senti', training=True) 235 | print('senti_train_loss: %s' % dict(senti_train_loss)) 236 | for i in range(opt.rl_fact_times): 237 | # seq 在前面比较好 238 | # print('----------seq2seq train') 239 | # seq2seq_train_loss = model(scs_data, data_type='seq2seq', training=True) 240 | # print('seq2seq_train_loss: %s' % seq2seq_train_loss) 241 | print('----------rl_fact_times: %d' % i) 242 | fact_train_loss = model((fact_train_data, scs_data), data_type='fact', training=True) 243 | print('fact_train_loss: %s' % dict(fact_train_loss)) 244 | 245 | with torch.no_grad(): 246 | torch.cuda.empty_cache() 247 | print('----------val') 248 | fact_val_loss = model((fact_val_data,), data_type='fact', training=False) 249 | print('fact_val_loss:', dict(fact_val_loss)) 250 | 251 | # test 252 | results = {'fact': defaultdict(list), 'senti': defaultdict(list)} 253 | det_sentis = defaultdict(dict) 254 | senti_imgs_num = 0 255 | senti_imgs_wrong_num = 0 256 | for data_type, data in [('fact', fact_test_data), ('senti', senti_test_data)]: 257 | print('----------test:', data_type) 258 | for data_item in tqdm.tqdm(data): 259 | if data_type == 'fact': 260 | fns, fc_feats, att_feats, _, cpts_tensor, sentis_tensor, ground_truth = data_item 261 | elif data_type == 'senti': 262 | fns, fc_feats, att_feats, cpts_tensor, sentis_tensor, senti_labels = data_item 263 | senti_labels = senti_labels.to(opt.device) 264 | senti_labels = [opt.sentiment_categories[int(idx)] for idx in senti_labels] 265 | else: 266 | raise Exception('data_type(%s) is wrong!' 
% data_type) 267 | fc_feats = fc_feats.to(opt.device) 268 | att_feats = att_feats.to(opt.device) 269 | sentis_tensor = sentis_tensor.to(opt.device) 270 | 271 | for i, fn in enumerate(fns): 272 | captions, det_img_sentis = model.sample( 273 | fc_feats[i], att_feats[i], sentis_tensor[i], beam_size=opt.beam_size) 274 | results[data_type][det_img_sentis[0]].append({'image_id': fn, 'caption': captions[0]}) 275 | det_sentis[data_type][fn] = det_img_sentis[0] 276 | if data_type == 'senti': 277 | senti_imgs_num += 1 278 | if det_img_sentis[0] != senti_labels[i]: 279 | senti_imgs_wrong_num += 1 280 | 281 | det_sentis_wrong_rate = senti_imgs_wrong_num / senti_imgs_num 282 | 283 | for data_type in results: 284 | for senti in results[data_type]: 285 | json.dump(results[data_type][senti], 286 | open(os.path.join(result_dir, 'result_%d_%s_%s.json' % (epoch, senti, data_type)), 'w')) 287 | wr = det_sentis_wrong_rate 288 | if data_type == 'fact': 289 | wr = 0 290 | json.dump(det_sentis[data_type], 291 | open(os.path.join(result_dir, 'result_%d_sentis_%s_%s.json' % (epoch, wr, data_type)), 'w')) 292 | 293 | sents = {'fact': defaultdict(str), 'senti': defaultdict(str)} 294 | sents_w = {'fact': defaultdict(str), 'senti': defaultdict(str)} 295 | for data_type in results: 296 | for senti in results[data_type]: 297 | ress = results[data_type][senti] 298 | for res in ress: 299 | caption = res['caption'] 300 | sents_w[data_type][senti] += caption + '\n' 301 | caption = [str(word2idx[w]) for w in caption.split()] + [str(word2idx[''])] 302 | caption = ' '.join(caption) + '\n' 303 | sents[data_type][senti] += caption 304 | for data_type in sents: 305 | for senti in sents[data_type]: 306 | with open(os.path.join(result_dir, 'result_%d_%s_%s.txt' % (epoch, senti, data_type)), 'w') as f: 307 | f.write(sents[data_type][senti]) 308 | with open(os.path.join(result_dir, 'result_%d_%s_%s_w.txt' % (epoch, senti, data_type)), 'w') as f: 309 | f.write(sents_w[data_type][senti]) 310 | 311 | if epoch > -1: 312 | chkpoint = { 313 | 'epoch': epoch, 314 | 'model': model.state_dict(), 315 | 'settings': opt.settings, 316 | 'idx2word': idx2word, 317 | 'max_seq_len': opt.max_seq_len, 318 | 'sentiment_categories': opt.sentiment_categories, 319 | 'dataset_name': dataset_name, 320 | 'corpus_type': corpus_type, 321 | } 322 | checkpoint_path = os.path.join( 323 | checkpoint, 'model_%d_%s.pth' % ( 324 | epoch, time.strftime('%m%d-%H%M'))) 325 | torch.save(chkpoint, checkpoint_path) 326 | 327 | 328 | if __name__ == '__main__': 329 | try: 330 | opt = parse_opt() 331 | train() 332 | except (BdbQuit, torch.cuda.memory_allocated()): 333 | sys.exit(1) 334 | except Exception: 335 | traceback.print_exc() 336 | print('') 337 | pdb.post_mortem() 338 | sys.exit(1) 339 | -------------------------------------------------------------------------------- /train_sent_senti_cls_rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import time 5 | import sys 6 | import pdb 7 | import traceback 8 | from bdb import BdbQuit 9 | import json 10 | import random 11 | import tqdm 12 | from copy import deepcopy 13 | from collections import defaultdict 14 | 15 | from opts import parse_opt 16 | from models.sent_senti_cls import SentenceSentimentClassifier 17 | from dataloader import get_senti_sents_dataloader 18 | 19 | 20 | random.seed(100) 21 | resume = '' 22 | 23 | 24 | def clip_gradient(optimizer, grad_clip): 25 | for group in optimizer.param_groups: 26 | for param in 
group['params']: 27 | if param.grad is not None: 28 | param.grad.data.clamp_(-grad_clip, grad_clip) 29 | 30 | 31 | def train(): 32 | dataset_name = opt.dataset_name 33 | corpus_type = opt.corpus_type 34 | 35 | idx2word = json.load(open(os.path.join(opt.captions_dir, dataset_name, corpus_type, 'idx2word.json'), 'r')) 36 | senti_captions = json.load(open(os.path.join(opt.captions_dir, dataset_name, corpus_type, 'senti_captions.json'), 'r')) 37 | 38 | model = SentenceSentimentClassifier(idx2word, opt.sentiment_categories, opt.settings) 39 | model.to(opt.device) 40 | lr = 4e-4 41 | optimizer, criterion = model.get_optim_and_crit(lr) 42 | if resume: 43 | print("====> loading checkpoint '{}'".format(resume)) 44 | chkpoint = torch.load(resume, map_location=lambda s, l: s) 45 | assert opt.settings == chkpoint['settings'], \ 46 | 'opt.settings and resume model settings are different' 47 | assert idx2word == chkpoint['idx2word'], \ 48 | 'idx2word and resume model idx2word are different' 49 | assert opt.sentiment_categories == chkpoint['sentiment_categories'], \ 50 | 'sentiment_categories and resume model sentiment_categories are different' 51 | assert dataset_name == chkpoint['dataset_name'], \ 52 | 'dataset_name and resume model dataset_name are different' 53 | assert corpus_type == chkpoint['corpus_type'], \ 54 | 'corpus_type and resume model corpus_type are different' 55 | model.load_state_dict(chkpoint['model']) 56 | optimizer.load_state_dict(chkpoint['optimizer']) 57 | lr = optimizer.param_groups[0]['lr'] 58 | print("====> loaded checkpoint '{}', epoch: {}" 59 | .format(resume, chkpoint['epoch'])) 60 | 61 | word2idx = {} 62 | for i, w in enumerate(idx2word): 63 | word2idx[w] = i 64 | senti_label2idx = {} 65 | for i, w in enumerate(opt.sentiment_categories): 66 | senti_label2idx[w] = i 67 | 68 | print('====> process senti_corpus begin') 69 | for senti in senti_captions: 70 | senti_captions[senti] = [c[0] for c in senti_captions[senti]] 71 | random.shuffle(senti_captions[senti]) 72 | tmp_senti_captions = {'train': {}, 'val': {}} 73 | tmp_senti_captions['train']['neutral'] = deepcopy(senti_captions['neutral'][5000:]) 74 | tmp_senti_captions['val']['neutral'] = deepcopy(senti_captions['neutral'][:5000]) 75 | tmp_senti_captions['train']['positive'] = deepcopy(senti_captions['positive'][1000:]) 76 | tmp_senti_captions['val']['positive'] = deepcopy(senti_captions['positive'][:1000]) 77 | tmp_senti_captions['train']['negative'] = deepcopy(senti_captions['negative'][1000:]) 78 | tmp_senti_captions['val']['negative'] = deepcopy(senti_captions['negative'][:1000]) 79 | tmp_senti_captions['train']['positive'] = tmp_senti_captions['train']['positive'] * int(len(tmp_senti_captions['train']['neutral']) / len(tmp_senti_captions['train']['positive'])) 80 | tmp_senti_captions['train']['negative'] = tmp_senti_captions['train']['negative'] * int(len(tmp_senti_captions['train']['neutral']) / len(tmp_senti_captions['train']['negative'])) 81 | senti_captions = tmp_senti_captions 82 | 83 | train_set = [] 84 | val_set = {} 85 | for senti in opt.sentiment_categories: 86 | print('convert %s corpus to index' % senti) 87 | senti_id = senti_label2idx[senti] 88 | for cap in tqdm.tqdm(senti_captions['train'][senti]): 89 | tmp = [word2idx.get(w, None) or word2idx[''] for w in cap] + [word2idx['']] 90 | train_set.append([senti_id, tmp]) 91 | val_set[senti] = [] 92 | for cap in tqdm.tqdm(senti_captions['val'][senti]): 93 | tmp = [word2idx.get(w, None) or word2idx[''] for w in cap] + [word2idx['']] 94 | 
val_set[senti].append([senti_id, tmp]) 95 | random.shuffle(train_set) 96 | print('====> process senti_corpus end') 97 | 98 | train_data = get_senti_sents_dataloader(train_set, word2idx[''], opt.max_seq_len) 99 | val_data = {} 100 | for senti in val_set: 101 | val_data[senti] = get_senti_sents_dataloader(val_set[senti], word2idx[''], opt.max_seq_len, shuffle=False) 102 | 103 | checkpoint = os.path.join(opt.checkpoint, 'sent_senti_cls', dataset_name, corpus_type) 104 | if not os.path.exists(checkpoint): 105 | os.makedirs(checkpoint) 106 | result_dir = os.path.join(opt.result_dir, 'sent_senti_cls', dataset_name, corpus_type) 107 | if not os.path.exists(result_dir): 108 | os.makedirs(result_dir) 109 | previous_acc_rate = None 110 | for epoch in range(30): 111 | print('--------------------epoch: %d' % epoch) 112 | model.train() 113 | train_loss = 0.0 114 | for sentis, (caps_tensor, lengths) in tqdm.tqdm(train_data): 115 | sentis = sentis.to(opt.device) 116 | caps_tensor = caps_tensor.to(opt.device) 117 | 118 | pred, _ = model(caps_tensor, lengths) 119 | loss = criterion(pred, sentis) 120 | train_loss += loss.item() 121 | 122 | optimizer.zero_grad() 123 | loss.backward() 124 | clip_gradient(optimizer, opt.grad_clip) 125 | optimizer.step() 126 | train_loss /= len(train_data) 127 | 128 | model.eval() 129 | all_num = defaultdict(int) 130 | senti_num = {} 131 | test_case = defaultdict(list) 132 | for senti in opt.sentiment_categories: 133 | senti_num[senti] = defaultdict(int) 134 | with torch.no_grad(): 135 | for senti, data in val_data.items(): 136 | for sentis, (caps_tensor, lengths) in tqdm.tqdm(data): 137 | sentis = sentis.to(opt.device) 138 | caps_tensor = caps_tensor.to(opt.device) 139 | 140 | rest, rest_w, att_weights = model.sample(caps_tensor, lengths) 141 | rest = torch.LongTensor(np.array(rest)).to(opt.device) 142 | total_num = int(sentis.size(0)) 143 | wrong_num = int((sentis != rest).sum()) 144 | all_num['total_num'] += total_num 145 | all_num['wrong_num'] += wrong_num 146 | senti_num[senti]['total_num'] += total_num 147 | senti_num[senti]['wrong_num'] += wrong_num 148 | 149 | random_id = random.randint(0, caps_tensor.size(0)-1) 150 | caption = ' '.join([idx2word[idx] for idx in caps_tensor[random_id]]) 151 | pred_senti = rest_w[random_id] 152 | att_weight = str(att_weights[random_id].detach().cpu().numpy().tolist()) 153 | test_case[senti].append([caption, pred_senti, att_weight]) 154 | 155 | tmp_total_num = 0 156 | tmp_wrong_num = 0 157 | for senti in senti_num: 158 | tmp_total_num += senti_num[senti]['total_num'] 159 | tmp_wrong_num += senti_num[senti]['wrong_num'] 160 | assert tmp_total_num == all_num['total_num'] and tmp_wrong_num == all_num['wrong_num'] 161 | 162 | all_acc_rate = 100 - all_num['wrong_num'] / all_num['total_num'] * 100 163 | senti_acc_rate = {} 164 | for senti in senti_num: 165 | senti_acc_rate[senti] = 100 - senti_num[senti]['wrong_num'] / senti_num[senti]['total_num'] * 100 166 | 167 | json.dump(test_case, open(os.path.join(result_dir, 'test_case_%d_%.4f.json' % (epoch, all_acc_rate)), 'w')) 168 | 169 | if previous_acc_rate is not None and all_acc_rate < previous_acc_rate: 170 | lr = lr * 0.5 171 | for param_group in optimizer.param_groups: 172 | param_group['lr'] = lr 173 | previous_acc_rate = all_acc_rate 174 | 175 | print('train_loss: %.4f, all_acc_rate: %.4f, senti_acc_rate: %s' % 176 | (train_loss, all_acc_rate, senti_acc_rate)) 177 | if epoch > -1: 178 | chkpoint = { 179 | 'epoch': epoch, 180 | 'model': model.state_dict(), 181 | 'optimizer': 
optimizer.state_dict(), 182 | 'settings': opt.settings, 183 | 'idx2word': idx2word, 184 | 'sentiment_categories': opt.sentiment_categories, 185 | 'dataset_name': dataset_name, 186 | 'corpus_type': corpus_type, 187 | } 188 | checkpoint_path = os.path.join(checkpoint, 'model_%d_%.4f_%.4f_%s.pth' % ( 189 | epoch, train_loss, all_acc_rate, time.strftime('%m%d-%H%M'))) 190 | torch.save(chkpoint, checkpoint_path) 191 | 192 | 193 | if __name__ == '__main__': 194 | try: 195 | opt = parse_opt() 196 | train() 197 | except BdbQuit: 198 | sys.exit(1) 199 | except Exception: 200 | traceback.print_exc() 201 | print('') 202 | pdb.post_mortem() 203 | sys.exit(1) 204 | -------------------------------------------------------------------------------- /train_senti.py: -------------------------------------------------------------------------------- 1 | # coding:utf8 2 | import tqdm 3 | import os 4 | import h5py 5 | import time 6 | import json 7 | import sys 8 | import pdb 9 | import traceback 10 | from bdb import BdbQuit 11 | import torch 12 | 13 | from opts import parse_opt 14 | from models.sentiment_detector import SentimentDetector 15 | from dataloader import get_senti_image_dataloader 16 | 17 | 18 | def clip_gradient(optimizer, grad_clip): 19 | for group in optimizer.param_groups: 20 | for param in group['params']: 21 | param.grad.data.clamp_(-grad_clip, grad_clip) 22 | 23 | 24 | def train(): 25 | senti_detector = SentimentDetector(opt.sentiment_categories, opt.settings) 26 | senti_detector.to(opt.device) 27 | lr = opt.senti_lr 28 | optimizer, criterion = senti_detector.get_optim_criterion(lr) 29 | if opt.senti_resume: 30 | print("====> loading checkpoint '{}'".format(opt.senti_resume)) 31 | chkpoint = torch.load(opt.senti_resume, map_location=lambda s, l: s) 32 | assert opt.settings == chkpoint['settings'], \ 33 | 'opt.settings and resume model settings are different' 34 | assert opt.sentiment_categories == chkpoint['sentiment_categories'], \ 35 | 'sentiment_categories and resume model sentiment_categories are different' 36 | senti_detector.load_state_dict(chkpoint['model']) 37 | optimizer.load_state_dict(chkpoint['optimizer']) 38 | lr = optimizer.param_groups[0]['lr'] 39 | print("====> loaded checkpoint '{}', epoch: {}" 40 | .format(opt.senti_resume, chkpoint['epoch'])) 41 | 42 | img_senti_labels = json.load(open(opt.img_senti_labels, 'r')) 43 | 44 | senti_label2idx = {} 45 | for i, w in enumerate(opt.sentiment_categories): 46 | senti_label2idx[w] = i 47 | print('====> process image senti_labels begin') 48 | senti_labels_id = {} 49 | for split, senti_labels in img_senti_labels.items(): 50 | print('convert %s senti_labels to index' % split) 51 | senti_labels_id[split] = [] 52 | for fn, senti_label in tqdm.tqdm(senti_labels): 53 | senti_labels_id[split].append([fn, senti_label2idx[senti_label]]) 54 | img_senti_labels = senti_labels_id 55 | print('====> process image senti_labels end') 56 | 57 | f_senti_att = os.path.join(opt.feats_dir, 'sentiment', 'feats_att.h5') 58 | train_data = get_senti_image_dataloader( 59 | f_senti_att, img_senti_labels['train'], 60 | opt.senti_bs, opt.senti_num_works) 61 | val_data = get_senti_image_dataloader( 62 | f_senti_att, img_senti_labels['val'], 63 | opt.senti_bs, opt.senti_num_works, shuffle=False) 64 | test_data = get_senti_image_dataloader( 65 | f_senti_att, img_senti_labels['test'], 66 | opt.senti_bs, opt.senti_num_works, shuffle=False) 67 | 68 | def forward(data, training=True): 69 | senti_detector.train(training) 70 | loss_val = 0.0 71 | for _, att_feats, labels in 
tqdm.tqdm(data): 72 | att_feats = att_feats.to(opt.device) 73 | labels = labels.to(opt.device) 74 | # (det_out, cls_out), _ = senti_detector(att_feats) 75 | # det_loss = criterion(det_out, labels) 76 | # cls_loss = criterion(cls_out, labels) 77 | # loss = det_loss + cls_loss 78 | pred, _ = senti_detector(att_feats) 79 | loss = criterion(pred, labels) 80 | loss_val += loss.item() 81 | if training: 82 | optimizer.zero_grad() 83 | loss.backward() 84 | clip_gradient(optimizer, opt.grad_clip) 85 | optimizer.step() 86 | return loss_val / len(data) 87 | 88 | checkpoint = os.path.join(opt.checkpoint, 'sentiment') 89 | if not os.path.exists(checkpoint): 90 | os.makedirs(checkpoint) 91 | previous_loss = None 92 | for epoch in range(opt.senti_epochs): 93 | print('--------------------epoch: %d' % epoch) 94 | # torch.cuda.empty_cache() 95 | train_loss = forward(train_data) 96 | with torch.no_grad(): 97 | val_loss = forward(val_data, training=False) 98 | 99 | # test 100 | corr_num = 0 101 | all_num = 0 102 | for _, att_feats, labels in tqdm.tqdm(test_data): 103 | att_feats = att_feats.to(opt.device) 104 | labels = labels.to(opt.device) 105 | idx, _, _, _ = senti_detector.sample(att_feats) 106 | corr_num += int(sum(labels == idx)) 107 | all_num += len(idx) 108 | corr_rate = corr_num / all_num 109 | 110 | if previous_loss is not None and val_loss > previous_loss: 111 | lr = lr * 0.5 112 | for param_group in optimizer.param_groups: 113 | param_group['lr'] = lr 114 | previous_loss = val_loss 115 | 116 | print('train_loss: %.4f, val_loss: %.4f, corr_rate: %.4f' % 117 | (train_loss, val_loss, corr_rate)) 118 | if epoch == 0 or epoch > 5: 119 | chkpoint = { 120 | 'epoch': epoch, 121 | 'model': senti_detector.state_dict(), 122 | 'optimizer': optimizer.state_dict(), 123 | 'settings': opt.settings, 124 | 'sentiment_categories': opt.sentiment_categories, 125 | } 126 | checkpoint_path = os.path.join(checkpoint, 'model_%d_%.4f_%.4f_%s.pth' % ( 127 | epoch, train_loss, val_loss, time.strftime('%m%d-%H%M'))) 128 | torch.save(chkpoint, checkpoint_path) 129 | 130 | 131 | if __name__ == '__main__': 132 | try: 133 | opt = parse_opt() 134 | train() 135 | except BdbQuit: 136 | sys.exit(1) 137 | except Exception: 138 | traceback.print_exc() 139 | print('') 140 | pdb.post_mortem() 141 | sys.exit(1) 142 | -------------------------------------------------------------------------------- /train_xe.py: -------------------------------------------------------------------------------- 1 | # coding:utf8 2 | import tqdm 3 | import os 4 | import time 5 | import json 6 | from collections import defaultdict 7 | import sys 8 | import pdb 9 | import traceback 10 | from bdb import BdbQuit 11 | import torch 12 | 13 | from opts import parse_opt 14 | from models.captioner import Captioner 15 | from models.sent_senti_cls import SentenceSentimentClassifier 16 | from dataloader import get_caption_dataloader, get_senti_corpus_with_sentis_dataloader 17 | 18 | 19 | def clip_gradient(optimizer, grad_clip): 20 | for group in optimizer.param_groups: 21 | for param in group['params']: 22 | if param.grad is not None: 23 | param.grad.data.clamp_(-grad_clip, grad_clip) 24 | 25 | 26 | def train(): 27 | dataset_name = opt.dataset_name 28 | corpus_type = opt.corpus_type 29 | 30 | idx2word = json.load(open(os.path.join(opt.captions_dir, dataset_name, corpus_type, 'idx2word.json'), 'r')) 31 | img_captions = json.load(open(os.path.join(opt.captions_dir, dataset_name, 'img_captions.json'), 'r')) 32 | img_det_concepts = json.load(open(os.path.join(opt.captions_dir, 
dataset_name, 'img_det_concepts.json'), 'r'))
 33 |     senti_captions = json.load(open(os.path.join(opt.captions_dir, dataset_name, corpus_type, 'senti_captions.json'), 'r'))
 34 |
 35 |     captioner = Captioner(idx2word, opt.sentiment_categories, opt.settings)
 36 |     captioner.to(opt.device)
 37 |     lr = opt.xe_lr
 38 |     optimizer, xe_crit, da_crit = captioner.get_optim_criterion(lr)
 39 |     if opt.xe_resume:
 40 |         print("====> loading checkpoint '{}'".format(opt.xe_resume))
 41 |         chkpoint = torch.load(opt.xe_resume, map_location=lambda s, l: s)
 42 |         assert opt.settings == chkpoint['settings'], \
 43 |             'opt.settings and resume model settings are different'
 44 |         assert idx2word == chkpoint['idx2word'], \
 45 |             'idx2word and resume model idx2word are different'
 46 |         assert opt.sentiment_categories == chkpoint['sentiment_categories'], \
 47 |             'sentiment_categories and resume model sentiment_categories are different'
 48 |         assert dataset_name == chkpoint['dataset_name'], \
 49 |             'dataset_name and resume model dataset_name are different'
 50 |         assert corpus_type == chkpoint['corpus_type'], \
 51 |             'corpus_type and resume model corpus_type are different'
 52 |         captioner.load_state_dict(chkpoint['model'])
 53 |         optimizer.load_state_dict(chkpoint['optimizer'])
 54 |         lr = optimizer.param_groups[0]['lr']
 55 |         print("====> loaded checkpoint '{}', epoch: {}"
 56 |               .format(opt.xe_resume, chkpoint['epoch']))
 57 |
 58 |     sent_senti_cls = SentenceSentimentClassifier(idx2word, opt.sentiment_categories, opt.settings)
 59 |     sent_senti_cls.to(opt.device)
 60 |     ss_cls_file = os.path.join(opt.checkpoint, 'sent_senti_cls', dataset_name, corpus_type, 'model-best.pth')
 61 |     print("====> loading checkpoint '{}'".format(ss_cls_file))
 62 |     chkpoint = torch.load(ss_cls_file, map_location=lambda s, l: s)
 63 |     assert opt.settings == chkpoint['settings'], \
 64 |         'opt.settings and resume model settings are different'
 65 |     assert idx2word == chkpoint['idx2word'], \
 66 |         'idx2word and resume model idx2word are different'
 67 |     assert opt.sentiment_categories == chkpoint['sentiment_categories'], \
 68 |         'opt.sentiment_categories and resume model sentiment_categories are different'
 69 |     assert dataset_name == chkpoint['dataset_name'], \
 70 |         'dataset_name and resume model dataset_name are different'
 71 |     assert corpus_type == chkpoint['corpus_type'], \
 72 |         'corpus_type and resume model corpus_type are different'
 73 |     sent_senti_cls.load_state_dict(chkpoint['model'])
 74 |     sent_senti_cls.eval()
 75 |
 76 |     word2idx = {}
 77 |     for i, w in enumerate(idx2word):
 78 |         word2idx[w] = i
 79 |
 80 |     print('====> process image captions begin')
 81 |     captions_id = {}
 82 |     for split, caps in img_captions.items():
 83 |         print('convert %s captions to index' % split)
 84 |         captions_id[split] = {}
 85 |         for fn, seqs in tqdm.tqdm(caps.items()):
 86 |             tmp = []
 87 |             for seq in seqs:
 88 |                 tmp.append([captioner.sos_id] +
 89 |                            [word2idx.get(w, None) or word2idx['<UNK>'] for w in seq] +
 90 |                            [captioner.eos_id])
 91 |             captions_id[split][fn] = tmp
 92 |     img_captions = captions_id
 93 |     print('====> process image captions end')
 94 |
 95 |     print('====> process image det_concepts begin')
 96 |     det_concepts_id = {}
 97 |     for fn, cpts in tqdm.tqdm(img_det_concepts.items()):
 98 |         det_concepts_id[fn] = [word2idx[w] for w in cpts]
 99 |     img_det_concepts = det_concepts_id
100 |     print('====> process image det_concepts end')
101 |
102 |     senti_label2idx = {}
103 |     for i, w in enumerate(opt.sentiment_categories):
104 |         senti_label2idx[w] = i
105 |     print('====> process senti corpus begin')
106 |     senti_captions['positive'] = senti_captions['positive'] * int(len(senti_captions['neutral']) / len(senti_captions['positive']))
107 |     senti_captions['negative'] = senti_captions['negative'] * int(len(senti_captions['neutral']) / len(senti_captions['negative']))
108 |     senti_captions_id = []
109 |     for senti, caps in senti_captions.items():
110 |         print('convert %s corpus to index' % senti)
111 |         senti_id = senti_label2idx[senti]
112 |         for cap, cpts, sentis in tqdm.tqdm(caps):
113 |             cap = [captioner.sos_id] +\
114 |                   [word2idx.get(w, None) or word2idx['<UNK>'] for w in cap] +\
115 |                   [captioner.eos_id]
116 |             cpts = [word2idx[w] for w in cpts if w in word2idx]
117 |             sentis = [word2idx[w] for w in sentis]
118 |             senti_captions_id.append([cap, cpts, sentis, senti_id])
119 |     senti_captions = senti_captions_id
120 |     print('====> process senti corpus end')
121 |
122 |     fc_feats = os.path.join(opt.feats_dir, dataset_name, '%s_fc.h5' % dataset_name)
123 |     att_feats = os.path.join(opt.feats_dir, dataset_name, '%s_att.h5' % dataset_name)
124 |     train_data = get_caption_dataloader(fc_feats, att_feats, img_captions['train'],
125 |                                         img_det_concepts, idx2word.index('<PAD>'),
126 |                                         opt.max_seq_len, opt.num_concepts,
127 |                                         opt.xe_bs, opt.xe_num_works)
128 |     val_data = get_caption_dataloader(fc_feats, att_feats, img_captions['val'],
129 |                                       img_det_concepts, idx2word.index('<PAD>'),
130 |                                       opt.max_seq_len, opt.num_concepts, opt.xe_bs,
131 |                                       opt.xe_num_works, shuffle=False)
132 |     scs_data = get_senti_corpus_with_sentis_dataloader(
133 |         senti_captions, idx2word.index('<PAD>'), opt.max_seq_len,
134 |         opt.num_concepts, opt.num_sentiments, 80, opt.xe_num_works)
135 |
136 |     test_captions = {}
137 |     for fn in img_captions['test']:
138 |         test_captions[fn] = [[]]
139 |     test_data = get_caption_dataloader(fc_feats, att_feats, test_captions,
140 |                                        img_det_concepts, idx2word.index('<PAD>'),
141 |                                        opt.max_seq_len, opt.num_concepts, opt.xe_bs,
142 |                                        opt.xe_num_works, shuffle=False)
143 |
144 |     def forward(data, training=True, ss_prob=0.0):
145 |         captioner.train(training)
146 |         if training:
147 |             seq2seq_data = iter(scs_data)
148 |         loss_val = defaultdict(float)
149 |         for _, fc_feats, att_feats, (caps_tensor, lengths), cpts_tensor in tqdm.tqdm(data):
150 |             fc_feats = fc_feats.to(opt.device)
151 |             att_feats = att_feats.to(opt.device)
152 |             caps_tensor = caps_tensor.to(opt.device)
153 |             cpts_tensor = cpts_tensor.to(opt.device)
154 |
155 |             with torch.no_grad():
156 |                 xe_senti_labels, _ = sent_senti_cls(caps_tensor[:, 1:], lengths)
157 |                 xe_senti_labels = xe_senti_labels.softmax(dim=-1)
158 |                 xe_senti_labels = xe_senti_labels.argmax(dim=-1).detach()
159 |
160 |             pred = captioner(fc_feats, att_feats, cpts_tensor, caps_tensor,
161 |                              xe_senti_labels, ss_prob, mode='xe')
162 |             xe_loss = xe_crit(pred, caps_tensor[:, 1:], lengths)
163 |             da_loss = da_crit(captioner.cpt_feats, captioner.fc_feats.detach())
164 |             cap_loss = xe_loss + da_loss
165 |             loss_val['xe_loss'] += float(xe_loss)
166 |             loss_val['da_loss'] += float(da_loss)
167 |             loss_val['cap_loss'] += float(cap_loss)
168 |
169 |             seq2seq_loss = 0.0
170 |             if training:
171 |                 try:
172 |                     (caps_tensor, lengths), cpts_tensor, sentis_tensor, senti_labels = next(seq2seq_data)
173 |                 except StopIteration:
174 |                     seq2seq_data = iter(scs_data)
175 |                     (caps_tensor, lengths), cpts_tensor, sentis_tensor, senti_labels = next(seq2seq_data)
176 |                 caps_tensor = caps_tensor.to(opt.device)
177 |                 cpts_tensor = cpts_tensor.to(opt.device)
178 |                 sentis_tensor = sentis_tensor.to(opt.device)
179 |                 senti_labels = senti_labels.to(opt.device)
180 |                 pred = captioner(caps_tensor, cpts_tensor, sentis_tensor, senti_labels,
181 |                                  ss_prob, mode='seq2seq')
182 |                 seq2seq_loss = xe_crit(pred, caps_tensor[:, 1:], lengths)
183 |                 loss_val['seq2seq_loss'] += float(seq2seq_loss)
184 |
185 |             all_loss = cap_loss + seq2seq_loss
186 |             loss_val['all_loss'] += float(all_loss)
187 |
188 |             if training:
189 |                 optimizer.zero_grad()
190 |                 all_loss.backward()
191 |                 clip_gradient(optimizer, opt.grad_clip)
192 |                 optimizer.step()
193 |
194 |         for k, v in loss_val.items():
195 |             loss_val[k] = v / len(data)
196 |         return loss_val
197 |
198 |     tmp_dir = ''
199 |     checkpoint = os.path.join(opt.checkpoint, 'xe', dataset_name, corpus_type, tmp_dir)
200 |     if not os.path.exists(checkpoint):
201 |         os.makedirs(checkpoint)
202 |     result_dir = os.path.join(opt.result_dir, 'xe', dataset_name, corpus_type, tmp_dir)
203 |     if not os.path.exists(result_dir):
204 |         os.makedirs(result_dir)
205 |     previous_loss = None
206 |     for epoch in range(opt.xe_epochs):
207 |         print('--------------------epoch: %d' % epoch)
208 |         # torch.cuda.empty_cache()
209 |         ss_prob = 0.0
210 |         if epoch > opt.scheduled_sampling_start >= 0:
211 |             frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
212 |             ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
213 |         print('tmp_dir:', tmp_dir, 'ss_prob:', ss_prob)
214 |         train_loss = forward(train_data, ss_prob=ss_prob)
215 |
216 |         with torch.no_grad():
217 |             val_loss = forward(val_data, training=False)
218 |
219 |         results = []
220 |         fact_txt = ''
221 |         for fns, fc_feats, att_feats, _, _ in tqdm.tqdm(test_data):
222 |             fc_feats = fc_feats.to(opt.device)
223 |             att_feats = att_feats.to(opt.device)
224 |             for i, fn in enumerate(fns):
225 |                 captions, _ = captioner.sample(
226 |                     fc_feats[i], att_feats[i],
227 |                     beam_size=opt.beam_size, max_seq_len=opt.max_seq_len)
228 |                 results.append({'image_id': fn, 'caption': captions[0]})
229 |                 fact_txt += captions[0] + '\n'
230 |         json.dump(results, open(os.path.join(result_dir, 'result_%d.json' % epoch), 'w'))
231 |         with open(os.path.join(result_dir, 'result_%d.txt' % epoch), 'w') as f:
232 |             f.write(fact_txt)
233 |
234 |         if previous_loss is not None and val_loss['all_loss'] > previous_loss:
235 |             lr = lr * 0.5
236 |             for param_group in optimizer.param_groups:
237 |                 param_group['lr'] = lr
238 |         previous_loss = val_loss['all_loss']
239 |
240 |         print('train_loss: %s, val_loss: %s' % (dict(train_loss), dict(val_loss)))
241 |         if epoch in [0, 10, 15, 20, 25, 29, 30, 35, 39]:
242 |             chkpoint = {
243 |                 'epoch': epoch,
244 |                 'model': captioner.state_dict(),
245 |                 'optimizer': optimizer.state_dict(),
246 |                 'settings': opt.settings,
247 |                 'idx2word': idx2word,
248 |                 'sentiment_categories': opt.sentiment_categories,
249 |                 'dataset_name': dataset_name,
250 |                 'corpus_type': corpus_type,
251 |             }
252 |             checkpoint_path = os.path.join(checkpoint, 'model_%d_%.4f_%.4f_%s.pth' % (
253 |                 epoch, train_loss['all_loss'], val_loss['all_loss'], time.strftime('%m%d-%H%M')))
254 |             torch.save(chkpoint, checkpoint_path)
255 |
256 |
257 | if __name__ == '__main__':
258 |     try:
259 |         opt = parse_opt()
260 |         train()
261 |     except BdbQuit:
262 |         sys.exit(1)
263 |     except Exception:
264 |         traceback.print_exc()
265 |         print('')
266 |         pdb.post_mortem()
267 |         sys.exit(1)
268 |
--------------------------------------------------------------------------------
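
For readers who only skim `train_xe.py`, the scheduled-sampling ramp inside the epoch loop is the one piece of arithmetic that is easy to misread. The sketch below restates that schedule as a standalone helper; the function name and the default values are illustrative only (the actual hyper-parameters come from `opts.py` via `parse_opt()`), so treat this as a minimal sketch rather than part of the repository.

```python
# Minimal sketch of the scheduled-sampling ramp used in train_xe.py.
# `scheduled_sampling_prob` and its default values are illustrative;
# the real settings are read from opts.py via parse_opt().

def scheduled_sampling_prob(epoch, start=0, increase_every=5,
                            increase_prob=0.05, max_prob=0.25):
    """Probability of feeding the model its own prediction at `epoch`.

    Scheduling is disabled when start < 0. Otherwise the probability is 0
    up to and including `start`, then grows by `increase_prob` every
    `increase_every` epochs, capped at `max_prob`.
    """
    if epoch > start >= 0:
        frac = (epoch - start) // increase_every
        return min(increase_prob * frac, max_prob)
    return 0.0


if __name__ == '__main__':
    # With the illustrative defaults: 0.0 at epoch 0, 0.05 at epoch 6,
    # and capped at 0.25 from epoch 25 onwards.
    print([scheduled_sampling_prob(e) for e in (0, 6, 12, 25, 39)])
```

Note that the same epoch loop also halves the learning rate whenever the validation loss increases, so the effective training dynamics depend on both this ramp and the validation curve.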