├── .gitmodules ├── ADVANCED.md ├── LICENSE ├── MODEL_ZOO.md ├── README.md ├── captioning ├── __init__.py ├── data │ ├── __init__.py │ ├── dataloader.py │ ├── dataloaderraw.py │ └── pth_loader.py ├── models │ ├── AoAModel.py │ ├── AttEnsemble.py │ ├── AttModel.py │ ├── BertCapModel.py │ ├── CaptionModel.py │ ├── FCModel.py │ ├── M2Transformer.py │ ├── ShowTellModel.py │ ├── TransformerModel.py │ ├── __init__.py │ ├── cachedTransformer.py │ └── utils.py ├── modules │ ├── loss_wrapper.py │ └── losses.py └── utils │ ├── __init__.py │ ├── config.py │ ├── div_utils.py │ ├── eval_multi.py │ ├── eval_utils.py │ ├── misc.py │ ├── opts.py │ ├── resnet.py │ ├── resnet_utils.py │ └── rewards.py ├── configs ├── a2i2.yml ├── a2i2_nsc.yml ├── a2i2_sc.yml ├── aoa.yml ├── aoa_nsc.yml ├── aoa_sc.yml ├── fc.yml ├── fc_nsc.yml ├── fc_rl.yml ├── transformer │ ├── transformer.yml │ ├── transformer_nsc.yml │ ├── transformer_nscl.yml │ ├── transformer_sc.yml │ ├── transformer_scl.yml │ └── transformer_step.yml └── updown │ ├── ud_long_nsc.yml │ ├── ud_long_sc.yml │ ├── updown.yml │ ├── updown_long.yml │ ├── updown_nsc.yml │ └── updown_sc.yml ├── data └── README.md ├── projects ├── Diversity │ ├── README.md │ └── scripts │ │ ├── eval_scripts │ │ ├── only_eval_test_n_bs.sh │ │ ├── only_eval_test_n_dbst.sh │ │ ├── only_eval_test_n_sp.sh │ │ ├── only_eval_test_n_topk.sh │ │ ├── only_eval_test_n_topp.sh │ │ ├── only_gen_test_n_bs.sh │ │ ├── only_gen_test_n_dbst.sh │ │ ├── only_gen_test_n_sp.sh │ │ ├── only_gen_test_n_topk.sh │ │ └── only_gen_test_n_topp.sh │ │ └── train_scripts │ │ ├── run_a2i2.sh │ │ ├── run_a2i2_npg.sh │ │ ├── run_a2i2_pgg.sh │ │ ├── run_a2i2_sf_npg.sh │ │ ├── run_a2i2l.sh │ │ ├── run_a2i2l_npg.sh │ │ ├── run_a2i2l_sf_npg.sh │ │ ├── run_fc.sh │ │ ├── run_fc_npg.sh │ │ ├── run_td.sh │ │ ├── run_td_npg.sh │ │ ├── run_transf.sh │ │ ├── run_transf_npg.sh │ │ └── run_transf_sf_npg.sh └── NewSelfCritical │ └── README.md ├── scripts ├── build_bpe_subword_nmt.py ├── copy_model.sh ├── dump_to_h5df.py ├── dump_to_lmdb.py ├── make_bu_data.py ├── prepro_feats.py ├── prepro_labels.py ├── prepro_ngrams.py └── prepro_reference_json.py ├── setup.py ├── test └── test_pth_loader.py ├── tools ├── eval.py ├── eval_ensemble.py ├── train.py └── train_pl.py └── vis ├── imgs └── dummy ├── index.html └── jquery-1.8.3.min.js /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "cider"] 2 | path = cider 3 | url = https://github.com/ruotianluo/cider.git 4 | [submodule "coco-caption"] 5 | path = coco-caption 6 | url = https://github.com/ruotianluo/coco-caption.git 7 | -------------------------------------------------------------------------------- /ADVANCED.md: -------------------------------------------------------------------------------- 1 | # Advanced 2 | 3 | ## Ensemble 4 | 5 | Current ensemble only supports models which are subclass of AttModel. Here is example of the script to run ensemble models. The `eval_ensemble.py` assumes the model saving under `log_$id`. 
```
python eval_ensemble.py --dump_json 0 --ids model1 model2 model3 --weights 0.3 0.3 0.3 --batch_size 1 --dump_images 0 --num_images 5000 --split test --language_eval 1 --beam_size 5 --temperature 1.0 --sample_method greedy --max_length 30
```

## BPE

```
python scripts/prepro_labels.py --input_json data/dataset_coco.json --output_json data/cocotalk_bpe.json --output_h5 data/cocotalk_bpe --symbol_count 6000
```

BPE does not seem to improve performance.

## Use lmdb instead of a folder of countless files

Some file systems are known to handle folders with a large number of small files poorly. However, in this project, the default way of saving precomputed image features is to save each image feature as an individual file.

Usually, for COCO, once all the features have been cached in memory (basically after the first epoch), the time spent reading data is negligible. However, for a much larger dataset like Conceptual Captions, the number of features is too large to fit in memory, so data loading is extremely slow and stays slow even after the first epoch.

For that dataset, I used lmdb to save all the features. Although loading the data is still slow, it is much better than reading individual files.

To generate an lmdb file from a folder of features, check out `scripts/dump_to_lmdb.py`, which is borrowed from [Lyken17/Efficient-PyTorch](https://github.com/Lyken17/Efficient-PyTorch/tools).

I believe the current way of using lmdb in `dataloader.py` is far from optimal. I tried the methods in tensorpack but failed to make them work. (The idea was to read by chunk, so that lmdb can load one chunk at a time, reducing the number of ad hoc disk accesses.)

## New self critical

This "new self critical" is borrowed from "Variational inference for Monte Carlo objectives". The only difference from the original self critical is the definition of the baseline.

In the original self critical, the baseline is the score of the greedy decoding output. In new self critical, the baseline for each sample is the average score of the other samples (this requires the model to generate multiple samples per image); see the sketch below.

To try new self critical on the UpDown model, you can run

`python train.py --cfg configs/updown_nsc.yml`

This yml file also gives you a hint of what to change in order to use new self critical.
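The baseline computation is simple enough to sketch. Below is a minimal illustration of the new-self-critical advantage, assuming `rewards` holds a `(batch_size, sample_n)` matrix of per-sample scores; this is a schematic helper, not the exact code in `captioning/utils/rewards.py`:

```python
import torch

def new_self_critical_advantage(rewards: torch.Tensor) -> torch.Tensor:
    """rewards: (batch_size, sample_n) scores (e.g. CIDEr) of the sampled captions.

    Original self critical uses the greedy-decoding score as the baseline;
    new self critical uses, for each sample, the mean score of the *other*
    samples drawn for the same image.
    """
    sample_n = rewards.size(1)
    # mean of the other samples: (sum - self) / (n - 1)
    baseline = (rewards.sum(dim=1, keepdim=True) - rewards) / (sample_n - 1)
    return rewards - baseline


# toy example: 2 images, 3 sampled captions each
rewards = torch.tensor([[1.0, 0.5, 0.3],
                        [0.2, 0.8, 0.5]])
advantage = new_self_critical_advantage(rewards)
# the policy-gradient loss is then roughly -(advantage.detach() * sample_logprobs).mean()
```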
## SCST in the Topdown Bottomup paper

In the Topdown Bottomup paper, instead of random sampling during SCST, they use beam search. To do so, you can try:

`python train.py --id fc_tdsc --cfg configs/fc_rl.yml --train_sample_method greedy --train_beam_size 5 --max_epochs 30 --learning_rate 5e-6`

## Sample n captions

When sampling, set `sample_n` to be greater than 0.

## Batch normalization

## Box feature

## Training with pytorch lightning

To run it, you need to install pytorch-lightning, as well as detectron2 (for its utility functions).

The benefit of pytorch-lightning is that I don't need to take care of the distributed data parallel details. (Although I know how to do this, lightning is really convenient. Nevertheless, I dislike the idea that LightningModule is a nn.Module.)

Training script (in fact it's almost identical):
```
python tools/train_pl.py --id trans --cfg configs/transformer.yml
```

Test script:
```
EVALUATE=1 python tools/train_pl.py --id trans --cfg configs/transformer.yml
```
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Ruotian(RT) Luo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-------------------------------------------------------------------------------- /MODEL_ZOO.md: --------------------------------------------------------------------------------
# Models

Results are on the Karpathy test split with beam size 5. The evaluated models are the checkpoints with the highest CIDEr on the validation set. Unless noted otherwise, the numbers shown are not selected from multiple runs. The scores are only meant to verify that you are getting things right: if the scores you get are close to the numbers given here (they could be higher or lower), then it's ok.

# Trained with Resnet101 feature:

Collection: [link](https://drive.google.com/open?id=0B7fNdx_jAqhtcXp0aFlWSnJmb0k)
| Name | CIDEr | SPICE | Download | Note |
| --- | --- | --- | --- | --- |
| FC | 0.953 | 0.1787 | model&metrics | `--caption_model newfc` |
| FC + self_critical | 1.045 | 0.1838 | model&metrics | `--caption_model newfc` |
| FC + new_self_critical | 1.053 | 0.1857 | model&metrics | `--caption_model newfc` |
# Trained with Bottomup feature (10-100 features per image, not 36 features per image):

Collection: [link](https://drive.google.com/open?id=1-RNak8qLUR5LqfItY6OenbRl8sdwODng)
| Name | CIDEr | SPICE | Download | Note |
| --- | --- | --- | --- | --- |
| Att2in | 1.089 | 0.1982 | model&metrics | My replication |
| Att2in + self_critical | 1.173 | 0.2046 | model&metrics | |
| Att2in + new_self_critical | 1.195 | 0.2066 | model&metrics | |
| UpDown | 1.099 | 0.1999 | model&metrics | My replication |
| UpDown + self_critical | 1.227 | 0.2145 | model&metrics | |
| UpDown + new_self_critical | 1.239 | 0.2154 | model&metrics | |
| UpDown + schedule long + new_self_critical | 1.280 | 0.2200 | model&metrics | Best of 5 models; schedule proposed by yangxuntu |
| Transformer | 1.1259 | 0.2063 | model&metrics | |
| Transformer (warmup + step decay) | 1.1496 | 0.2093 | model&metrics | Although this schedule is better, the final self-critical results are similar. |
| Transformer + self_critical | 1.277 | 0.2249 | model&metrics | This could be higher: the reported checkpoint is the one with the highest CIDEr on the validation set, so some other checkpoint may perform better. |
| Transformer + new_self_critical | 1.303 | 0.2289 | model&metrics | |
# Trained with vilbert-12-in-1 feature:

Collection: [link](https://drive.google.com/drive/folders/1QdqRGUoPoQChOq65ecIaSl1yXOosSQm3?usp=sharing)
| Name | CIDEr | SPICE | Download | Note |
| --- | --- | --- | --- | --- |
| Transformer | 1.158 | 0.2114 | model&metrics | The config needs to be changed to use the vilbert feature. |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
# An Image Captioning codebase

This is a codebase for image captioning research.

It supports:
- Self-critical training from [Self-critical Sequence Training for Image Captioning](https://arxiv.org/abs/1612.00563)
- Bottom-up features from [ref](https://arxiv.org/abs/1707.07998).
- Test-time ensemble
- Multi-GPU training. (DistributedDataParallel is now supported with the help of pytorch-lightning; see [ADVANCED.md](ADVANCED.md) for details)
- Transformer captioning model.

A simple demo colab notebook is available [here](https://colab.research.google.com/github/ruotianluo/ImageCaptioning.pytorch/blob/colab/notebooks/captioning_demo.ipynb)

## Requirements
- Python 3
- PyTorch 1.3+ (along with torchvision) (tested with 1.13)
- cider (already added as a submodule)
- coco-caption (already added as a submodule) (**Remember to follow the initialization steps in coco-caption/README.md**)
- yacs
- lmdbdict
- Optional: pytorch-lightning (tested with 2.0)

## Install

If you have difficulty running the training scripts in `tools`, you can try installing this repo as a python package:
```
python -m pip install -e .
```

## Pretrained models

Check out [MODEL_ZOO.md](MODEL_ZOO.md).

If you only want to do evaluation, you can follow [this section](#generate-image-captions) after downloading the pretrained models (and also the pretrained resnet101 or precomputed bottom-up features; see [data/README.md](data/README.md)).

## Train your own network on COCO/Flickr30k

### Prepare data.

We now support both flickr30k and COCO. See details in [data/README.md](data/README.md). (Note: the later sections assume the COCO dataset; it should be trivial to adapt them to flickr30k.)

### Start training

```bash
$ python tools/train.py --id fc --caption_model newfc --input_json data/cocotalk.json --input_fc_dir data/cocotalk_fc --input_att_dir data/cocotalk_att --input_label_h5 data/cocotalk_label.h5 --batch_size 10 --learning_rate 5e-4 --learning_rate_decay_start 0 --scheduled_sampling_start 0 --checkpoint_path log_fc --save_checkpoint_every 6000 --val_images_use 5000 --max_epochs 30
```

or

```bash
$ python tools/train.py --cfg configs/fc.yml --id fc
```

The train script will dump checkpoints into the folder specified by `--checkpoint_path` (default = `log_$id/`). By default, only the best-performing checkpoint on validation and the latest checkpoint are saved, to save disk space. You can also set `--save_history_ckpt` to 1 to save every checkpoint.

To resume training, specify `--start_from` as the path containing `infos.pkl` and `model.pth` (usually you can just set `--start_from` and `--checkpoint_path` to be the same).

To check the training or validation curves, you can use tensorboard. The loss histories are automatically dumped into `--checkpoint_path`.

The command above uses scheduled sampling; you can set `--scheduled_sampling_start` to -1 to turn scheduled sampling off. A sketch of what scheduled sampling does is given below.
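Scheduled sampling gradually replaces some ground-truth input tokens with tokens sampled from the model's own predictions during teacher-forced training, which reduces exposure bias. A minimal sketch of the idea (schematic only; the actual logic lives in the models' `_forward` methods, e.g. in `captioning/models/FCModel.py`, and the function and argument names below are hypothetical):

```python
import torch

def scheduled_sampling_input(gt_tokens, prev_logprobs, ss_prob):
    """Pick the next input token for each sequence in the batch.

    gt_tokens:     (N,) ground-truth tokens for the current time step.
    prev_logprobs: (N, vocab_size) log-probabilities predicted at the previous step.
    ss_prob:       probability of feeding the model's own sample instead of ground truth.
    """
    use_sample = torch.rand_like(gt_tokens, dtype=torch.float) < ss_prob
    sampled = torch.multinomial(prev_logprobs.exp(), 1).view(-1)
    # feed the ground-truth token by default, the model's sample where the coin flip fires
    return torch.where(use_sample, sampled, gt_tokens)
```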
If you'd like to evaluate BLEU/METEOR/CIDEr scores during training, in addition to the validation cross-entropy loss, use the `--language_eval 1` option, but don't forget to pull the submodule `coco-caption`.

All arguments can also be specified in a yaml file and passed with `--cfg`. Options given on the command line overwrite the cfg file if there are conflicts.

For more options, see `opts.py`.

### Train using self critical

First you should preprocess the dataset and build the cache for calculating the CIDEr score:
```
$ python scripts/prepro_ngrams.py --input_json data/dataset_coco.json --dict_json data/cocotalk.json --output_pkl data/coco-train --split train
```

Then, copy the model pretrained with cross entropy. (Copying the model is not mandatory; it's just a back-up.)
```
$ bash scripts/copy_model.sh fc fc_rl
```

Then
```bash
$ python tools/train.py --id fc_rl --caption_model newfc --input_json data/cocotalk.json --input_fc_dir data/cocotalk_fc --input_att_dir data/cocotalk_att --input_label_h5 data/cocotalk_label.h5 --batch_size 10 --learning_rate 5e-5 --start_from log_fc_rl --checkpoint_path log_fc_rl --save_checkpoint_every 6000 --language_eval 1 --val_images_use 5000 --self_critical_after 30 --cached_tokens coco-train-idxs --max_epoch 50 --train_sample_n 5
```

or
```bash
$ python tools/train.py --cfg configs/fc_rl.yml --id fc_rl
```


You will see a large boost in CIDEr score.

**A few notes on training.** Starting self-critical training after 30 epochs, the CIDEr score goes up to about 1.05 after 600k iterations (including the 30 epochs of pretraining).

## Generate image captions

### Evaluate on raw images

**Note**: this doesn't work for models trained with bottom-up features.
Place all the images of interest into a folder, e.g. `blah`, and run
the eval script:

```bash
$ python tools/eval.py --model model.pth --infos_path infos.pkl --image_folder blah --num_images 10
```

This tells the `eval` script to run on up to 10 images from the given folder. If you have a big GPU you can speed up the evaluation by increasing `batch_size`. Use `--num_images -1` to process all images. The eval script will create a `vis.json` file inside the `vis` folder, which can then be visualized with the provided HTML interface:

```bash
$ cd vis
$ python -m http.server 8000
```

Now visit `localhost:8000` in your browser and you should see your predicted captions.

### Evaluate on Karpathy's test split

```bash
$ python tools/eval.py --dump_images 0 --num_images 5000 --model model.pth --infos_path infos.pkl --language_eval 1
```

The default split to evaluate is test. The default inference method is greedy decoding (`--sample_method greedy`); to sample from the posterior instead, set `--sample_method sample`.

**Beam Search**. Beam search can improve over greedy decoding by roughly 5%, at the cost of slower inference. To turn on beam search, use `--beam_size N` with N greater than 1; a minimal sketch of the procedure is shown below.
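For readers unfamiliar with it, beam search keeps the `beam_size` highest-scoring partial captions at every step instead of only the single best one. The sketch below is a generic illustration with hypothetical `step_fn`/`init_state` arguments, not the repo's own implementation (which lives in the `CaptionModel` base class and is considerably more featureful):

```python
import torch

def simple_beam_search(step_fn, init_state, beam_size=5, max_len=20, bos_eos=0):
    """Keep the `beam_size` best partial captions at every step.

    step_fn(prev_token, state) -> (logprobs over the vocab, shape (1, V), new_state)
    bos_eos: this codebase uses token 0 as both the start and end token.
    Schematic only: no length normalization or per-image batching.
    """
    beams = [([], 0.0, init_state)]  # (tokens, cumulative logprob, state)
    for _ in range(max_len):
        candidates = []
        for tokens, score, state in beams:
            if tokens and tokens[-1] == bos_eos:   # beam already finished
                candidates.append((tokens, score, state))
                continue
            prev = torch.tensor([tokens[-1] if tokens else bos_eos])
            logprobs, new_state = step_fn(prev, state)
            top_lp, top_ix = logprobs.view(-1).topk(beam_size)
            for lp, ix in zip(top_lp.tolist(), top_ix.tolist()):
                candidates.append((tokens + [ix], score + lp, new_state))
        # keep only the beam_size best expansions
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
    return beams[0][0]  # highest-scoring caption (list of token ids)
```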
### Evaluate on COCO test set

```bash
$ python tools/eval.py --input_json cocotest.json --input_fc_dir data/cocotest_bu_fc --input_att_dir data/cocotest_bu_att --input_label_h5 none --num_images -1 --model model.pth --infos_path infos.pkl --language_eval 0
```

You can download the preprocessed files `cocotest.json`, `cocotest_bu_att` and `cocotest_bu_fc` from [link](https://drive.google.com/open?id=1eCdz62FAVCGogOuNhy87Nmlo5_I0sH2J).

## Miscellanea
**Using cpu**. The code currently uses the GPU by default, and there is no option to switch. If someone really needs a CPU model, please open an issue; I can potentially create a CPU checkpoint and modify eval.py to run the model on CPU. However, there is no point in using CPUs to train the model.

**Train on another dataset**. Porting should be trivial if you can create a file like `dataset_coco.json` for your own dataset.

**Live demo**. Not supported at the moment. Pull requests are welcome.

## For more advanced features:

Check out [ADVANCED.md](ADVANCED.md).

## Reference

If you find this repo useful, please consider citing (no obligation at all):

```
@article{luo2018discriminability,
  title={Discriminability objective for training descriptive captions},
  author={Luo, Ruotian and Price, Brian and Cohen, Scott and Shakhnarovich, Gregory},
  journal={arXiv preprint arXiv:1803.04376},
  year={2018}
}
```

Of course, please also cite the original papers of the models you are using (you can find the references in the model files).

## Acknowledgements

Thanks to the original [neuraltalk2](https://github.com/karpathy/neuraltalk2) and the awesome PyTorch team.
-------------------------------------------------------------------------------- /captioning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruotianluo/ImageCaptioning.pytorch/4c48a3304932d58c5349434e7b0085f48dcb4be4/captioning/__init__.py -------------------------------------------------------------------------------- /captioning/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruotianluo/ImageCaptioning.pytorch/4c48a3304932d58c5349434e7b0085f48dcb4be4/captioning/data/__init__.py -------------------------------------------------------------------------------- /captioning/data/dataloaderraw.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import h5py 7 | import os 8 | import numpy as np 9 | import random 10 | import torch 11 | import skimage 12 | import skimage.io 13 | import scipy.misc 14 | 15 | from torchvision import transforms as trn 16 | preprocess = trn.Compose([ 17 | #trn.ToTensor(), 18 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 19 | ]) 20 | 21 | from ..utils.resnet_utils import myResnet 22 | from ..utils import resnet 23 | 24 | class DataLoaderRaw(): 25 | 26 | def __init__(self, opt): 27 | self.opt = opt 28 | self.coco_json = opt.get('coco_json', '') 29 | self.folder_path = opt.get('folder_path', '') 30 | 31 | self.batch_size = opt.get('batch_size', 1) 32 | self.seq_per_img = 1 33 | 34 | # Load resnet 35 | self.cnn_model = opt.get('cnn_model', 'resnet101') 36 | self.my_resnet = getattr(resnet, self.cnn_model)() 37 | self.my_resnet.load_state_dict(torch.load('./data/imagenet_weights/'+self.cnn_model+'.pth')) 38 | self.my_resnet = myResnet(self.my_resnet) 39 | self.my_resnet.cuda() 40 | self.my_resnet.eval() 41 | 42 | 43 | 44 | # load the json file which contains additional information about the dataset 45 | print('DataLoaderRaw loading images from folder: ', self.folder_path) 46 | 47 | self.files = [] 48 | self.ids = [] 49 | 50 | print(len(self.coco_json)) 51 | if len(self.coco_json) > 0: 52 | print('reading from ' + opt.coco_json) 53 | # read in filenames from the coco-style json file 54 | self.coco_annotation = json.load(open(self.coco_json)) 55 | for k,v in enumerate(self.coco_annotation['images']): 56 | fullpath = os.path.join(self.folder_path, v['file_name']) 57 | self.files.append(fullpath) 58 | self.ids.append(v['id']) 59 | else: 60 | # read in all the filenames from the folder 61 | print('listing all images in directory ' + self.folder_path) 62 | def isImage(f): 63 | supportedExt = ['.jpg','.JPG','.jpeg','.JPEG','.png','.PNG','.ppm','.PPM'] 64 | for ext in supportedExt: 65 | start_idx = f.rfind(ext) 66 | if start_idx >= 0 and start_idx + len(ext) == len(f): 67 | return True 68 | return False 69 | 70 | n = 1 71 | for root, dirs, files in os.walk(self.folder_path, topdown=False): 72 | for file in files: 73 | fullpath = os.path.join(self.folder_path, file) 74 | if isImage(fullpath): 75 | self.files.append(fullpath) 76 | self.ids.append(str(n)) # just order them sequentially 77 | n = n + 1 78 | 79 | self.N = len(self.files) 80 | print('DataLoaderRaw found ', self.N, ' images') 81 | 82 | self.iterator = 0 83 | 84 | # Nasty 85 | self.dataset = self # to fix the bug in eval 86 | 87 | def get_batch(self, split, 
batch_size=None): 88 | batch_size = batch_size or self.batch_size 89 | 90 | # pick an index of the datapoint to load next 91 | fc_batch = np.ndarray((batch_size, 2048), dtype = 'float32') 92 | att_batch = np.ndarray((batch_size, 14, 14, 2048), dtype = 'float32') 93 | max_index = self.N 94 | wrapped = False 95 | infos = [] 96 | 97 | for i in range(batch_size): 98 | ri = self.iterator 99 | ri_next = ri + 1 100 | if ri_next >= max_index: 101 | ri_next = 0 102 | wrapped = True 103 | # wrap back around 104 | self.iterator = ri_next 105 | 106 | img = skimage.io.imread(self.files[ri]) 107 | 108 | if len(img.shape) == 2: 109 | img = img[:,:,np.newaxis] 110 | img = np.concatenate((img, img, img), axis=2) 111 | 112 | img = img[:,:,:3].astype('float32')/255.0 113 | img = torch.from_numpy(img.transpose([2,0,1])).cuda() 114 | img = preprocess(img) 115 | with torch.no_grad(): 116 | tmp_fc, tmp_att = self.my_resnet(img) 117 | 118 | fc_batch[i] = tmp_fc.data.cpu().float().numpy() 119 | att_batch[i] = tmp_att.data.cpu().float().numpy() 120 | 121 | info_struct = {} 122 | info_struct['id'] = self.ids[ri] 123 | info_struct['file_path'] = self.files[ri] 124 | infos.append(info_struct) 125 | 126 | data = {} 127 | data['fc_feats'] = fc_batch 128 | data['att_feats'] = att_batch.reshape(batch_size, -1, 2048) 129 | data['labels'] = np.zeros([batch_size, 0]) 130 | data['masks'] = None 131 | data['att_masks'] = None 132 | data['bounds'] = {'it_pos_now': self.iterator, 'it_max': self.N, 'wrapped': wrapped} 133 | data['infos'] = infos 134 | 135 | data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor 136 | 137 | return data 138 | 139 | def reset_iterator(self, split): 140 | self.iterator = 0 141 | 142 | def get_vocab_size(self): 143 | return len(self.ix_to_word) 144 | 145 | def get_vocab(self): 146 | return self.ix_to_word 147 | -------------------------------------------------------------------------------- /captioning/data/pth_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import h5py 7 | from lmdbdict import lmdbdict 8 | from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC 9 | import os 10 | import numpy as np 11 | import numpy.random as npr 12 | import random 13 | 14 | import torch 15 | import torch.utils.data as data 16 | 17 | import multiprocessing 18 | import six 19 | 20 | class HybridLoader: 21 | """ 22 | If db_path is a director, then use normal file loading 23 | If lmdb, then load from lmdb 24 | The loading method depend on extention. 25 | 26 | in_memory: if in_memory is True, we save all the features in memory 27 | For individual np(y|z)s, we don't need to do that because the system will do this for us. 28 | Should be useful for lmdb or h5. 
29 | (Copied this idea from vilbert) 30 | """ 31 | def __init__(self, db_path, ext, in_memory=False): 32 | self.db_path = db_path 33 | self.ext = ext 34 | if self.ext == '.npy': 35 | self.loader = lambda x: np.load(six.BytesIO(x)) 36 | else: 37 | self.loader = lambda x: np.load(six.BytesIO(x))['feat'] 38 | if db_path.endswith('.lmdb'): 39 | self.db_type = 'lmdb' 40 | self.lmdb = lmdbdict(db_path, unsafe=True) 41 | self.lmdb._key_dumps = DUMPS_FUNC['ascii'] 42 | self.lmdb._value_loads = LOADS_FUNC['identity'] 43 | elif db_path.endswith('.pth'): # Assume a key,value dictionary 44 | self.db_type = 'pth' 45 | self.feat_file = torch.load(db_path) 46 | self.loader = lambda x: x 47 | print('HybridLoader: ext is ignored') 48 | elif db_path.endswith('h5'): 49 | self.db_type = 'h5' 50 | self.loader = lambda x: np.array(x).astype('float32') 51 | else: 52 | self.db_type = 'dir' 53 | 54 | self.in_memory = in_memory 55 | if self.in_memory: 56 | self.features = {} 57 | 58 | def get(self, key): 59 | 60 | if self.in_memory and key in self.features: 61 | # We save f_input because we want to save the 62 | # compressed bytes to save memory 63 | f_input = self.features[key] 64 | elif self.db_type == 'lmdb': 65 | f_input = self.lmdb[key] 66 | elif self.db_type == 'pth': 67 | f_input = self.feat_file[key] 68 | elif self.db_type == 'h5': 69 | f_input = h5py.File(self.db_path, 'r')[key] 70 | else: 71 | f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read() 72 | 73 | if self.in_memory and key not in self.features: 74 | self.features[key] = f_input 75 | 76 | # load image 77 | feat = self.loader(f_input) 78 | 79 | return feat 80 | 81 | class CaptionDataset(data.Dataset): 82 | 83 | def get_vocab_size(self): 84 | return self.vocab_size 85 | 86 | def get_vocab(self): 87 | return self.ix_to_word 88 | 89 | def get_seq_length(self): 90 | return self.seq_length 91 | 92 | def __init__(self, opt): 93 | self.opt = opt 94 | self.seq_per_img = opt.seq_per_img 95 | 96 | # feature related options 97 | self.use_fc = getattr(opt, 'use_fc', True) 98 | self.use_att = getattr(opt, 'use_att', True) 99 | self.use_box = getattr(opt, 'use_box', 0) 100 | self.norm_att_feat = getattr(opt, 'norm_att_feat', 0) 101 | self.norm_box_feat = getattr(opt, 'norm_box_feat', 0) 102 | 103 | # load the json file which contains additional information about the dataset 104 | print('DataLoader loading json file: ', opt.input_json) 105 | self.info = json.load(open(self.opt.input_json)) 106 | if 'ix_to_word' in self.info: 107 | self.ix_to_word = self.info['ix_to_word'] 108 | self.vocab_size = len(self.ix_to_word) 109 | print('vocab size is ', self.vocab_size) 110 | 111 | # open the hdf5 file 112 | print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5) 113 | """ 114 | Setting input_label_h5 to none is used when only doing generation. 115 | For example, when you need to test on coco test set. 
116 | """ 117 | if self.opt.input_label_h5 != 'none': 118 | self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core') 119 | # load in the sequence data 120 | seq_size = self.h5_label_file['labels'].shape 121 | self.label = self.h5_label_file['labels'][:] 122 | self.seq_length = seq_size[1] 123 | print('max sequence length in data is', self.seq_length) 124 | # load the pointers in full to RAM (should be small enough) 125 | self.label_start_ix = self.h5_label_file['label_start_ix'][:] 126 | self.label_end_ix = self.h5_label_file['label_end_ix'][:] 127 | else: 128 | self.seq_length = 1 129 | 130 | self.data_in_memory = getattr(opt, 'data_in_memory', False) 131 | self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory) 132 | self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory) 133 | self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory) 134 | 135 | self.num_images = len(self.info['images']) # self.label_start_ix.shape[0] 136 | print('read %d image features' %(self.num_images)) 137 | 138 | # separate out indexes for each of the provided splits 139 | self.split_ix = {'train': [], 'val': [], 'test': []} 140 | for ix in range(len(self.info['images'])): 141 | img = self.info['images'][ix] 142 | if not 'split' in img: 143 | self.split_ix['train'].append(ix) 144 | self.split_ix['val'].append(ix) 145 | self.split_ix['test'].append(ix) 146 | elif img['split'] == 'train': 147 | self.split_ix['train'].append(ix) 148 | elif img['split'] == 'val': 149 | self.split_ix['val'].append(ix) 150 | elif img['split'] == 'test': 151 | self.split_ix['test'].append(ix) 152 | elif opt.train_only == 0: # restval 153 | self.split_ix['train'].append(ix) 154 | 155 | print('assigned %d images to split train' %len(self.split_ix['train'])) 156 | print('assigned %d images to split val' %len(self.split_ix['val'])) 157 | print('assigned %d images to split test' %len(self.split_ix['test'])) 158 | 159 | def get_captions(self, ix, seq_per_img): 160 | # fetch the sequence labels 161 | ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1 162 | ix2 = self.label_end_ix[ix] - 1 163 | ncap = ix2 - ix1 + 1 # number of captions available for this image 164 | assert ncap > 0, 'an image does not have any label. 
this can be handled but right now isn\'t' 165 | 166 | if ncap < seq_per_img: 167 | # we need to subsample (with replacement) 168 | seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int') 169 | for q in range(seq_per_img): 170 | ixl = random.randint(ix1,ix2) 171 | seq[q, :] = self.label[ixl, :self.seq_length] 172 | else: 173 | ixl = random.randint(ix1, ix2 - seq_per_img + 1) 174 | seq = self.label[ixl: ixl + seq_per_img, :self.seq_length] 175 | 176 | return seq 177 | 178 | def collate_func(self, batch): 179 | seq_per_img = self.seq_per_img 180 | 181 | fc_batch = [] 182 | att_batch = [] 183 | label_batch = [] 184 | 185 | wrapped = False 186 | 187 | infos = [] 188 | gts = [] 189 | 190 | for sample in batch: 191 | # fetch image 192 | tmp_fc, tmp_att, tmp_seq, \ 193 | ix = sample 194 | 195 | fc_batch.append(tmp_fc) 196 | att_batch.append(tmp_att) 197 | 198 | tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int') 199 | if hasattr(self, 'h5_label_file'): 200 | # if there is ground truth 201 | tmp_label[:, 1 : self.seq_length + 1] = tmp_seq 202 | label_batch.append(tmp_label) 203 | 204 | # Used for reward evaluation 205 | if hasattr(self, 'h5_label_file'): 206 | # if there is ground truth 207 | gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]]) 208 | else: 209 | gts.append([]) 210 | 211 | # record associated info as well 212 | info_dict = {} 213 | info_dict['ix'] = ix 214 | info_dict['id'] = self.info['images'][ix]['id'] 215 | info_dict['file_path'] = self.info['images'][ix].get('file_path', '') 216 | infos.append(info_dict) 217 | 218 | # #sort by att_feat length 219 | # fc_batch, att_batch, label_batch, gts, infos = \ 220 | # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True)) 221 | fc_batch, att_batch, label_batch, gts, infos = \ 222 | zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True)) 223 | data = {} 224 | data['fc_feats'] = np.stack(fc_batch) 225 | # merge att_feats 226 | max_att_len = max([_.shape[0] for _ in att_batch]) 227 | data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32') 228 | for i in range(len(att_batch)): 229 | data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i] 230 | data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32') 231 | for i in range(len(att_batch)): 232 | data['att_masks'][i, :att_batch[i].shape[0]] = 1 233 | # set att_masks to None if attention features have same length 234 | if data['att_masks'].sum() == data['att_masks'].size: 235 | data['att_masks'] = None 236 | 237 | data['labels'] = np.vstack(label_batch) 238 | # generate mask 239 | nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels']))) 240 | mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32') 241 | for ix, row in enumerate(mask_batch): 242 | row[:nonzeros[ix]] = 1 243 | data['masks'] = mask_batch 244 | data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1) 245 | data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1) 246 | 247 | data['gts'] = gts # all ground truth captions of each images 248 | data['infos'] = infos 249 | 250 | data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor 251 | 252 | return data 253 | 254 | def __getitem__(self, ix): 255 | """This function returns a tuple that is further passed to collate_fn 256 | """ 257 
| if self.use_att: 258 | att_feat = self.att_loader.get(str(self.info['images'][ix]['id'])) 259 | # Reshape to K x C 260 | att_feat = att_feat.reshape(-1, att_feat.shape[-1]) 261 | if self.norm_att_feat: 262 | att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True) 263 | if self.use_box: 264 | box_feat = self.box_loader.get(str(self.info['images'][ix]['id'])) 265 | # devided by image width and height 266 | x1,y1,x2,y2 = np.hsplit(box_feat, 4) 267 | h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width'] 268 | box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1?? 269 | if self.norm_box_feat: 270 | box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True) 271 | att_feat = np.hstack([att_feat, box_feat]) 272 | # sort the features by the size of boxes 273 | att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True)) 274 | else: 275 | att_feat = np.zeros((0,0), dtype='float32') 276 | if self.use_fc: 277 | try: 278 | fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id'])) 279 | except: 280 | # Use average of attention when there is no fc provided (For bottomup feature) 281 | fc_feat = att_feat.mean(0) 282 | else: 283 | fc_feat = np.zeros((0), dtype='float32') 284 | if hasattr(self, 'h5_label_file'): 285 | seq = self.get_captions(ix, self.seq_per_img) 286 | else: 287 | seq = None 288 | return (fc_feat, 289 | att_feat, seq, 290 | ix) 291 | 292 | def __len__(self): 293 | return len(self.info['images']) -------------------------------------------------------------------------------- /captioning/models/AoAModel.py: -------------------------------------------------------------------------------- 1 | # Implementation for paper 'Attention on Attention for Image Captioning' 2 | # https://arxiv.org/abs/1908.06954 3 | 4 | # RT: Code from original author's repo: https://github.com/husthuaan/AoANet/ 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from .AttModel import pack_wrapper, AttModel, Attention 15 | from .TransformerModel import LayerNorm, attention, clones, SublayerConnection, PositionwiseFeedForward 16 | 17 | class MultiHeadedDotAttention(nn.Module): 18 | def __init__(self, h, d_model, dropout=0.1, scale=1, project_k_v=1, use_output_layer=1, do_aoa=0, norm_q=0, dropout_aoa=0.3): 19 | super(MultiHeadedDotAttention, self).__init__() 20 | assert d_model * scale % h == 0 21 | # We assume d_v always equals d_k 22 | self.d_k = d_model * scale // h 23 | self.h = h 24 | 25 | # Do we need to do linear projections on K and V? 26 | self.project_k_v = project_k_v 27 | 28 | # normalize the query? 29 | if norm_q: 30 | self.norm = LayerNorm(d_model) 31 | else: 32 | self.norm = lambda x:x 33 | self.linears = clones(nn.Linear(d_model, d_model * scale), 1 + 2 * project_k_v) 34 | 35 | # output linear layer after the multi-head attention? 36 | self.output_layer = nn.Linear(d_model * scale, d_model) 37 | 38 | # apply aoa after attention? 
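        # AoA ("attention on attention") gates the attended vector with the query:
        # aoa_layer below computes GLU(Linear([attended; query])), i.e. an information
        # vector modulated by a sigmoid attention gate.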
39 | self.use_aoa = do_aoa 40 | if self.use_aoa: 41 | self.aoa_layer = nn.Sequential(nn.Linear((1 + scale) * d_model, 2 * d_model), nn.GLU()) 42 | # dropout to the input of AoA layer 43 | if dropout_aoa > 0: 44 | self.dropout_aoa = nn.Dropout(p=dropout_aoa) 45 | else: 46 | self.dropout_aoa = lambda x:x 47 | 48 | if self.use_aoa or not use_output_layer: 49 | # AoA doesn't need the output linear layer 50 | del self.output_layer 51 | self.output_layer = lambda x:x 52 | 53 | self.attn = None 54 | self.dropout = nn.Dropout(p=dropout) 55 | 56 | def forward(self, query, value, key, mask=None): 57 | if mask is not None: 58 | if len(mask.size()) == 2: 59 | mask = mask.unsqueeze(-2) 60 | # Same mask applied to all h heads. 61 | mask = mask.unsqueeze(1) 62 | 63 | single_query = 0 64 | if len(query.size()) == 2: 65 | single_query = 1 66 | query = query.unsqueeze(1) 67 | 68 | nbatches = query.size(0) 69 | 70 | query = self.norm(query) 71 | 72 | # Do all the linear projections in batch from d_model => h x d_k 73 | if self.project_k_v == 0: 74 | query_ = self.linears[0](query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 75 | key_ = key.view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 76 | value_ = value.view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 77 | else: 78 | query_, key_, value_ = \ 79 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 80 | for l, x in zip(self.linears, (query, key, value))] 81 | 82 | # Apply attention on all the projected vectors in batch. 83 | x, self.attn = attention(query_, key_, value_, mask=mask, 84 | dropout=self.dropout) 85 | 86 | # "Concat" using a view 87 | x = x.transpose(1, 2).contiguous() \ 88 | .view(nbatches, -1, self.h * self.d_k) 89 | 90 | if self.use_aoa: 91 | # Apply AoA 92 | x = self.aoa_layer(self.dropout_aoa(torch.cat([x, query], -1))) 93 | x = self.output_layer(x) 94 | 95 | if single_query: 96 | query = query.squeeze(1) 97 | x = x.squeeze(1) 98 | return x 99 | 100 | class AoA_Refiner_Layer(nn.Module): 101 | def __init__(self, size, self_attn, feed_forward, dropout): 102 | super(AoA_Refiner_Layer, self).__init__() 103 | self.self_attn = self_attn 104 | self.feed_forward = feed_forward 105 | self.use_ff = 0 106 | if self.feed_forward is not None: 107 | self.use_ff = 1 108 | self.sublayer = clones(SublayerConnection(size, dropout), 1+self.use_ff) 109 | self.size = size 110 | 111 | def forward(self, x, mask): 112 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) 113 | return self.sublayer[-1](x, self.feed_forward) if self.use_ff else x 114 | 115 | class AoA_Refiner_Core(nn.Module): 116 | def __init__(self, opt): 117 | super(AoA_Refiner_Core, self).__init__() 118 | attn = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=1, scale=opt.multi_head_scale, do_aoa=opt.refine_aoa, norm_q=0, dropout_aoa=getattr(opt, 'dropout_aoa', 0.3)) 119 | layer = AoA_Refiner_Layer(opt.rnn_size, attn, PositionwiseFeedForward(opt.rnn_size, 2048, 0.1) if opt.use_ff else None, 0.1) 120 | self.layers = clones(layer, 6) 121 | self.norm = LayerNorm(layer.size) 122 | 123 | def forward(self, x, mask): 124 | for layer in self.layers: 125 | x = layer(x, mask) 126 | return self.norm(x) 127 | 128 | class AoA_Decoder_Core(nn.Module): 129 | def __init__(self, opt): 130 | super(AoA_Decoder_Core, self).__init__() 131 | self.drop_prob_lm = opt.drop_prob_lm 132 | self.d_model = opt.rnn_size 133 | self.use_multi_head = opt.use_multi_head 134 | self.multi_head_scale = opt.multi_head_scale 135 | self.use_ctx_drop = getattr(opt, 'ctx_drop', 0) 
136 | self.out_res = getattr(opt, 'out_res', 0) 137 | self.decoder_type = getattr(opt, 'decoder_type', 'AoA') 138 | self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size, opt.rnn_size) # we, fc, h^2_t-1 139 | self.out_drop = nn.Dropout(self.drop_prob_lm) 140 | 141 | if self.decoder_type == 'AoA': 142 | # AoA layer 143 | self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, 2 * opt.rnn_size), nn.GLU()) 144 | elif self.decoder_type == 'LSTM': 145 | # LSTM layer 146 | self.att2ctx = nn.LSTMCell(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size) 147 | else: 148 | # Base linear layer 149 | self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size), nn.ReLU()) 150 | 151 | # if opt.use_multi_head == 1: # TODO, not implemented for now 152 | # self.attention = MultiHeadedAddAttention(opt.num_heads, opt.d_model, scale=opt.multi_head_scale) 153 | if opt.use_multi_head == 2: 154 | self.attention = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=0, scale=opt.multi_head_scale, use_output_layer=0, do_aoa=0, norm_q=1) 155 | else: 156 | self.attention = Attention(opt) 157 | 158 | if self.use_ctx_drop: 159 | self.ctx_drop = nn.Dropout(self.drop_prob_lm) 160 | else: 161 | self.ctx_drop = lambda x :x 162 | 163 | def forward(self, xt, mean_feats, att_feats, p_att_feats, state, att_masks=None): 164 | # state[0][1] is the context vector at the last step 165 | h_att, c_att = self.att_lstm(torch.cat([xt, mean_feats + self.ctx_drop(state[0][1])], 1), (state[0][0], state[1][0])) 166 | 167 | if self.use_multi_head == 2: 168 | att = self.attention(h_att, p_att_feats.narrow(2, 0, self.multi_head_scale * self.d_model), p_att_feats.narrow(2, self.multi_head_scale * self.d_model, self.multi_head_scale * self.d_model), att_masks) 169 | else: 170 | att = self.attention(h_att, att_feats, p_att_feats, att_masks) 171 | 172 | ctx_input = torch.cat([att, h_att], 1) 173 | if self.decoder_type == 'LSTM': 174 | output, c_logic = self.att2ctx(ctx_input, (state[0][1], state[1][1])) 175 | state = (torch.stack((h_att, output)), torch.stack((c_att, c_logic))) 176 | else: 177 | output = self.att2ctx(ctx_input) 178 | # save the context vector to state[0][1] 179 | state = (torch.stack((h_att, output)), torch.stack((c_att, state[1][1]))) 180 | 181 | if self.out_res: 182 | # add residual connection 183 | output = output + h_att 184 | 185 | output = self.out_drop(output) 186 | return output, state 187 | 188 | class AoAModel(AttModel): 189 | def __init__(self, opt): 190 | super(AoAModel, self).__init__(opt) 191 | self.num_layers = 2 192 | # mean pooling 193 | self.use_mean_feats = getattr(opt, 'mean_feats', 1) 194 | if opt.use_multi_head == 2: 195 | del self.ctx2att 196 | self.ctx2att = nn.Linear(opt.rnn_size, 2 * opt.multi_head_scale * opt.rnn_size) 197 | 198 | if self.use_mean_feats: 199 | del self.fc_embed 200 | if opt.refine: 201 | self.refiner = AoA_Refiner_Core(opt) 202 | else: 203 | self.refiner = lambda x,y : x 204 | self.core = AoA_Decoder_Core(opt) 205 | 206 | 207 | def _prepare_feature(self, fc_feats, att_feats, att_masks): 208 | att_feats, att_masks = self.clip_att(att_feats, att_masks) 209 | 210 | # embed att feats 211 | att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) 212 | att_feats = self.refiner(att_feats, att_masks) 213 | 214 | if self.use_mean_feats: 215 | # meaning pooling 216 | if att_masks is None: 217 | mean_feats = torch.mean(att_feats, dim=1) 218 | else: 219 | mean_feats = 
(torch.sum(att_feats * att_masks.unsqueeze(-1), 1) / torch.sum(att_masks.unsqueeze(-1), 1)) 220 | else: 221 | mean_feats = self.fc_embed(fc_feats) 222 | 223 | # Project the attention feats first to reduce memory and computation. 224 | p_att_feats = self.ctx2att(att_feats) 225 | 226 | return mean_feats, att_feats, p_att_feats, att_masks -------------------------------------------------------------------------------- /captioning/models/AttEnsemble.py: -------------------------------------------------------------------------------- 1 | # This file is the implementation for ensemble evaluation. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.autograd import * 12 | 13 | from .CaptionModel import CaptionModel 14 | from .AttModel import pack_wrapper, AttModel 15 | 16 | class AttEnsemble(AttModel): 17 | def __init__(self, models, weights=None): 18 | CaptionModel.__init__(self) 19 | # super(AttEnsemble, self).__init__() 20 | 21 | self.models = nn.ModuleList(models) 22 | self.vocab_size = models[0].vocab_size 23 | self.seq_length = models[0].seq_length 24 | self.bad_endings_ix = models[0].bad_endings_ix 25 | self.ss_prob = 0 26 | weights = weights or [1.0] * len(self.models) 27 | self.register_buffer('weights', torch.tensor(weights)) 28 | 29 | def init_hidden(self, batch_size): 30 | state = [m.init_hidden(batch_size) for m in self.models] 31 | return self.pack_state(state) 32 | 33 | def pack_state(self, state): 34 | self.state_lengths = [len(_) for _ in state] 35 | return sum([list(_) for _ in state], []) 36 | 37 | def unpack_state(self, state): 38 | out = [] 39 | for l in self.state_lengths: 40 | out.append(state[:l]) 41 | state = state[l:] 42 | return out 43 | 44 | def embed(self, it): 45 | return [m.embed(it) for m in self.models] 46 | 47 | def core(self, *args): 48 | return zip(*[m.core(*_) for m, _ in zip(self.models, zip(*args))]) 49 | 50 | def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state, output_logsoftmax=1): 51 | # 'it' contains a word index 52 | xt = self.embed(it) 53 | 54 | state = self.unpack_state(state) 55 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state, tmp_att_masks) 56 | logprobs = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i,m in enumerate(self.models)], 2).mul(self.weights).div(self.weights.sum()).sum(-1).log() 57 | 58 | return logprobs, self.pack_state(state) 59 | 60 | def _prepare_feature(self, *args): 61 | return tuple(zip(*[m._prepare_feature(*args) for m in self.models])) 62 | 63 | def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): 64 | beam_size = opt.get('beam_size', 10) 65 | batch_size = fc_feats.size(0) 66 | 67 | fc_feats, att_feats, p_att_feats, att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) 68 | 69 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' 70 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 71 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1) 72 | # lets process every image independently for now, for simplicity 73 | 74 | self.done_beams = [[] for _ in range(batch_size)] 75 | for k in range(batch_size): 76 | state = self.init_hidden(beam_size) 77 | tmp_fc_feats = [fc_feats[i][k:k+1].expand(beam_size, fc_feats[i].size(1)) for i,m in enumerate(self.models)] 78 | tmp_att_feats = [att_feats[i][k:k+1].expand(*((beam_size,)+att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)] 79 | tmp_p_att_feats = [p_att_feats[i][k:k+1].expand(*((beam_size,)+p_att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)] 80 | tmp_att_masks = [att_masks[i][k:k+1].expand(*((beam_size,)+att_masks[i].size()[1:])).contiguous() if att_masks[i] is not None else att_masks[i] for i,m in enumerate(self.models)] 81 | 82 | it = fc_feats[0].data.new(beam_size).long().zero_() 83 | logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state) 84 | 85 | self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt) 86 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 87 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 88 | # return the samples and their log likelihoods 89 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 90 | # return the samples and their log likelihoods 91 | -------------------------------------------------------------------------------- /captioning/models/BertCapModel.py: -------------------------------------------------------------------------------- 1 | """ 2 | BertCapModel is using huggingface transformer bert model as seq2seq model. 3 | 4 | The result is not as goog as original transformer. 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | import copy 16 | import math 17 | import numpy as np 18 | 19 | from .CaptionModel import CaptionModel 20 | from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel 21 | try: 22 | from transformers import BertModel, BertConfig 23 | except: 24 | print('Hugginface transformers not installed; please visit https://github.com/huggingface/transformers') 25 | from .TransformerModel import subsequent_mask, TransformerModel, Generator 26 | 27 | class EncoderDecoder(nn.Module): 28 | """ 29 | A standard Encoder-Decoder architecture. Base for this and many 30 | other models. 31 | """ 32 | def __init__(self, encoder, decoder, generator): 33 | super(EncoderDecoder, self).__init__() 34 | self.encoder = encoder 35 | self.decoder = decoder 36 | self.generator = generator 37 | 38 | def forward(self, src, tgt, src_mask, tgt_mask): 39 | "Take in and process masked src and target sequences." 
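        # Encode the visual features once, then decode the target tokens
        # conditioned on the encoder output (memory) and the source mask.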
40 | return self.decode(self.encode(src, src_mask), src_mask, 41 | tgt, tgt_mask) 42 | 43 | def encode(self, src, src_mask): 44 | return self.encoder(inputs_embeds=src, 45 | attention_mask=src_mask)[0] 46 | 47 | def decode(self, memory, src_mask, tgt, tgt_mask): 48 | return self.decoder(input_ids=tgt, 49 | attention_mask=tgt_mask, 50 | encoder_hidden_states=memory, 51 | encoder_attention_mask=src_mask)[0] 52 | 53 | 54 | class BertCapModel(TransformerModel): 55 | 56 | def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6, 57 | d_model=512, d_ff=2048, h=8, dropout=0.1): 58 | "Helper: Construct a model from hyperparameters." 59 | enc_config = BertConfig(vocab_size=1, 60 | hidden_size=d_model, 61 | num_hidden_layers=N_enc, 62 | num_attention_heads=h, 63 | intermediate_size=d_ff, 64 | hidden_dropout_prob=dropout, 65 | attention_probs_dropout_prob=dropout, 66 | max_position_embeddings=1, 67 | type_vocab_size=1) 68 | dec_config = BertConfig(vocab_size=tgt_vocab, 69 | hidden_size=d_model, 70 | num_hidden_layers=N_dec, 71 | num_attention_heads=h, 72 | intermediate_size=d_ff, 73 | hidden_dropout_prob=dropout, 74 | attention_probs_dropout_prob=dropout, 75 | max_position_embeddings=17, 76 | type_vocab_size=1, 77 | is_decoder=True) 78 | encoder = BertModel(enc_config) 79 | def return_embeds(*args, **kwargs): 80 | return kwargs['inputs_embeds'] 81 | del encoder.embeddings; encoder.embeddings = return_embeds 82 | decoder = BertModel(dec_config) 83 | model = EncoderDecoder( 84 | encoder, 85 | decoder, 86 | Generator(d_model, tgt_vocab)) 87 | return model 88 | 89 | def __init__(self, opt): 90 | super(BertCapModel, self).__init__(opt) 91 | 92 | def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): 93 | """ 94 | state = [ys.unsqueeze(0)] 95 | """ 96 | if len(state) == 0: 97 | ys = it.unsqueeze(1) 98 | else: 99 | ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) 100 | out = self.model.decode(memory, mask, 101 | ys, 102 | subsequent_mask(ys.size(1)) 103 | .to(memory.device)) 104 | return out[:, -1], [ys.unsqueeze(0)] 105 | -------------------------------------------------------------------------------- /captioning/models/FCModel.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import * 9 | from . 
import utils 10 | 11 | from .CaptionModel import CaptionModel 12 | 13 | class LSTMCore(nn.Module): 14 | def __init__(self, opt): 15 | super(LSTMCore, self).__init__() 16 | self.input_encoding_size = opt.input_encoding_size 17 | self.rnn_size = opt.rnn_size 18 | self.drop_prob_lm = opt.drop_prob_lm 19 | 20 | # Build a LSTM 21 | self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size) 22 | self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size) 23 | self.dropout = nn.Dropout(self.drop_prob_lm) 24 | 25 | def forward(self, xt, state): 26 | 27 | all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) 28 | sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) 29 | sigmoid_chunk = torch.sigmoid(sigmoid_chunk) 30 | in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) 31 | forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) 32 | out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) 33 | 34 | in_transform = torch.max(\ 35 | all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size), 36 | all_input_sums.narrow(1, 4 * self.rnn_size, self.rnn_size)) 37 | next_c = forget_gate * state[1][-1] + in_gate * in_transform 38 | next_h = out_gate * torch.tanh(next_c) 39 | 40 | output = self.dropout(next_h) 41 | state = (next_h.unsqueeze(0), next_c.unsqueeze(0)) 42 | return output, state 43 | 44 | class FCModel(CaptionModel): 45 | def __init__(self, opt): 46 | super(FCModel, self).__init__() 47 | self.vocab_size = opt.vocab_size 48 | self.input_encoding_size = opt.input_encoding_size 49 | self.rnn_type = opt.rnn_type 50 | self.rnn_size = opt.rnn_size 51 | self.num_layers = opt.num_layers 52 | self.drop_prob_lm = opt.drop_prob_lm 53 | self.seq_length = opt.seq_length 54 | self.fc_feat_size = opt.fc_feat_size 55 | 56 | self.ss_prob = 0.0 # Schedule sampling probability 57 | 58 | self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size) 59 | self.core = LSTMCore(opt) 60 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) 61 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 62 | 63 | self.init_weights() 64 | 65 | def init_weights(self): 66 | initrange = 0.1 67 | self.embed.weight.data.uniform_(-initrange, initrange) 68 | self.logit.bias.data.fill_(0) 69 | self.logit.weight.data.uniform_(-initrange, initrange) 70 | 71 | def init_hidden(self, bsz): 72 | weight = self.logit.weight 73 | if self.rnn_type == 'lstm': 74 | return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), 75 | weight.new_zeros(self.num_layers, bsz, self.rnn_size)) 76 | else: 77 | return weight.new_zeros(self.num_layers, bsz, self.rnn_size) 78 | 79 | def _forward(self, fc_feats, att_feats, seq, att_masks=None): 80 | batch_size = fc_feats.size(0) 81 | seq_per_img = seq.shape[0] // batch_size 82 | state = self.init_hidden(batch_size*seq_per_img) 83 | outputs = [] 84 | 85 | if seq_per_img > 1: 86 | fc_feats = utils.repeat_tensors(seq_per_img, fc_feats) 87 | 88 | for i in range(seq.size(1) + 1): 89 | if i == 0: 90 | xt = self.img_embed(fc_feats) 91 | else: 92 | if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample 93 | sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1) 94 | sample_mask = sample_prob < self.ss_prob 95 | if sample_mask.sum() == 0: 96 | it = seq[:, i-1].clone() 97 | else: 98 | sample_ind = sample_mask.nonzero().view(-1) 99 | it = seq[:, i-1].data.clone() 100 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 101 | #it.index_copy_(0, 
sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 102 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) 103 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 104 | else: 105 | it = seq[:, i-1].clone() 106 | # break if all the sequences end 107 | if i >= 2 and seq[:, i-1].sum() == 0: 108 | break 109 | xt = self.embed(it) 110 | 111 | output, state = self.core(xt, state) 112 | output = F.log_softmax(self.logit(output), dim=1) 113 | outputs.append(output) 114 | 115 | return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() 116 | 117 | def get_logprobs_state(self, it, state): 118 | # 'it' is contains a word index 119 | xt = self.embed(it) 120 | 121 | output, state = self.core(xt, state) 122 | logprobs = F.log_softmax(self.logit(output), dim=1) 123 | 124 | return logprobs, state 125 | 126 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): 127 | beam_size = opt.get('beam_size', 10) 128 | batch_size = fc_feats.size(0) 129 | 130 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' 131 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 132 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1) 133 | # lets process every image independently for now, for simplicity 134 | 135 | self.done_beams = [[] for _ in range(batch_size)] 136 | for k in range(batch_size): 137 | state = self.init_hidden(beam_size) 138 | for t in range(2): 139 | if t == 0: 140 | xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size) 141 | elif t == 1: # input 142 | it = fc_feats.data.new(beam_size).long().zero_() 143 | xt = self.embed(it) 144 | 145 | output, state = self.core(xt, state) 146 | logprobs = F.log_softmax(self.logit(output), dim=1) 147 | 148 | self.done_beams[k] = self.beam_search(state, logprobs, opt=opt) 149 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 150 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 151 | # return the samples and their log likelihoods 152 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 153 | 154 | def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): 155 | sample_method = opt.get('sample_method', 'greedy') 156 | beam_size = opt.get('beam_size', 1) 157 | temperature = opt.get('temperature', 1.0) 158 | if beam_size > 1 and sample_method in ['greedy', 'beam_search']: 159 | return self._sample_beam(fc_feats, att_feats, opt) 160 | 161 | batch_size = fc_feats.size(0) 162 | state = self.init_hidden(batch_size) 163 | seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long) 164 | seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length, self.vocab_size + 1) 165 | for t in range(self.seq_length + 2): 166 | if t == 0: 167 | xt = self.img_embed(fc_feats) 168 | else: 169 | if t == 1: # input 170 | it = fc_feats.data.new(batch_size).long().zero_() 171 | xt = self.embed(it) 172 | 173 | output, state = self.core(xt, state) 174 | logprobs = F.log_softmax(self.logit(output), dim=1) 175 | 176 | # sample the next_word 177 | if t == self.seq_length + 1: # skip if we achieve maximum length 178 | break 179 | if sample_method == 'greedy': 180 | sampleLogprobs, it = torch.max(logprobs.data, 1) 181 | it = it.view(-1).long() 182 | else: 183 | if temperature == 1.0: 184 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev 
distribution: shape Nx(M+1) 185 | else: 186 | # scale logprobs by temperature 187 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 188 | it = torch.multinomial(prob_prev, 1).to(logprobs.device) 189 | sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions 190 | it = it.view(-1).long() # and flatten indices for downstream processing 191 | 192 | if t >= 1: 193 | # stop when all finished 194 | if t == 1: 195 | unfinished = it > 0 196 | else: 197 | unfinished = unfinished & (it > 0) 198 | it = it * unfinished.type_as(it) 199 | seq[:,t-1] = it #seq[t] the input of t+2 time step 200 | seqLogprobs[:,t-1] = sampleLogprobs.view(-1) 201 | if unfinished.sum() == 0: 202 | break 203 | 204 | return seq, seqLogprobs 205 | -------------------------------------------------------------------------------- /captioning/models/M2Transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Instruction to use meshed_memory_transformer (https://arxiv.org/abs/1912.08226) 3 | 4 | pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git 5 | 6 | Note: 7 | Currently m2transformer is not performing as well as original transformer. Not sure why? Still investigating. 8 | """ 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | import copy 19 | import math 20 | import numpy as np 21 | 22 | from .CaptionModel import CaptionModel 23 | from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel 24 | 25 | try: 26 | from m2transformer.models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory 27 | except: 28 | print('meshed-memory-transformer not installed; please run `pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`') 29 | from .TransformerModel import subsequent_mask, TransformerModel 30 | 31 | 32 | class M2TransformerModel(TransformerModel): 33 | 34 | def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6, 35 | d_model=512, d_ff=2048, h=8, dropout=0.1): 36 | "Helper: Construct a model from hyperparameters." 37 | encoder = MemoryAugmentedEncoder(N_enc, 0, attention_module=ScaledDotProductAttentionMemory, 38 | attention_module_kwargs={'m': 40}) 39 | # Another implementation is to use MultiLevelEncoder + att_embed 40 | decoder = MeshedDecoder(tgt_vocab, 54, N_dec, -1) # -1 is padding; 41 | model = Transformer(0, encoder, decoder) # 0 is bos 42 | return model 43 | 44 | def __init__(self, opt): 45 | super(M2TransformerModel, self).__init__(opt) 46 | delattr(self, 'att_embed') 47 | self.att_embed = lambda x: x # The visual embed is in the MAEncoder 48 | # Notes: The dropout in MAEncoder is different from my att_embed, mine is 0.5? 
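# (The MemoryAugmentedEncoder presumably applies its own input projection and dropout to the raw
# region features (cf. the 'MultiLevelEncoder + att_embed' remark in make_model above), which is
# why att_embed is replaced by the identity here.)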
49 | # Also the attention mask seems wrong in MAEncoder too...intersting 50 | 51 | def logit(self, x): # unsafe way 52 | return x # M2transformer always output logsoftmax 53 | 54 | def _prepare_feature(self, fc_feats, att_feats, att_masks): 55 | 56 | att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks) 57 | memory, att_masks = self.model.encoder(att_feats) 58 | 59 | return fc_feats[...,:0], att_feats[...,:0], memory, att_masks 60 | 61 | def _forward(self, fc_feats, att_feats, seq, att_masks=None): 62 | if seq.ndim == 3: # B * seq_per_img * seq_len 63 | seq = seq.reshape(-1, seq.shape[2]) 64 | att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq) 65 | 66 | seq = seq.clone() 67 | seq[~seq_mask.any(-2)] = -1 # Make padding to be -1 (my dataloader uses 0 as padding) 68 | outputs = self.model(att_feats, seq) 69 | 70 | return outputs 71 | 72 | def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): 73 | """ 74 | state = [ys.unsqueeze(0)] 75 | """ 76 | if len(state) == 0: 77 | ys = it.unsqueeze(1) 78 | else: 79 | ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) 80 | out = self.model.decoder(ys, memory, mask) 81 | return out[:, -1], [ys.unsqueeze(0)] 82 | 83 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): 84 | beam_size = opt.get('beam_size', 10) 85 | group_size = opt.get('group_size', 1) 86 | sample_n = opt.get('sample_n', 10) 87 | assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search' 88 | 89 | att_feats, _, __, ___ = self._prepare_feature_forward(att_feats, att_masks) 90 | seq, logprobs, seqLogprobs = self.model.beam_search(att_feats, self.seq_length, 0, 91 | beam_size, return_probs=True, out_size=beam_size) 92 | seq = seq.reshape(-1, *seq.shape[2:]) 93 | seqLogprobs = seqLogprobs.reshape(-1, *seqLogprobs.shape[2:]) 94 | 95 | # if not (seqLogprobs.gather(-1, seq.unsqueeze(-1)).squeeze(-1) == logprobs.reshape(-1, logprobs.shape[-1])).all(): 96 | # import pudb;pu.db 97 | # seqLogprobs = logprobs.reshape(-1, logprobs.shape[-1]).unsqueeze(-1).expand(-1,-1,seqLogprobs.shape[-1]) 98 | return seq, seqLogprobs -------------------------------------------------------------------------------- /captioning/models/ShowTellModel.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import * 9 | from . 
import utils 10 | 11 | from .CaptionModel import CaptionModel 12 | 13 | class ShowTellModel(CaptionModel): 14 | def __init__(self, opt): 15 | super(ShowTellModel, self).__init__() 16 | self.vocab_size = opt.vocab_size 17 | self.input_encoding_size = opt.input_encoding_size 18 | self.rnn_type = opt.rnn_type 19 | self.rnn_size = opt.rnn_size 20 | self.num_layers = opt.num_layers 21 | self.drop_prob_lm = opt.drop_prob_lm 22 | self.seq_length = opt.seq_length 23 | self.fc_feat_size = opt.fc_feat_size 24 | 25 | self.ss_prob = 0.0 # Schedule sampling probability 26 | 27 | self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size) 28 | self.core = getattr(nn, self.rnn_type.upper())(self.input_encoding_size, self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm) 29 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) 30 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 31 | self.dropout = nn.Dropout(self.drop_prob_lm) 32 | 33 | self.init_weights() 34 | 35 | def init_weights(self): 36 | initrange = 0.1 37 | self.embed.weight.data.uniform_(-initrange, initrange) 38 | self.logit.bias.data.fill_(0) 39 | self.logit.weight.data.uniform_(-initrange, initrange) 40 | 41 | def init_hidden(self, bsz): 42 | weight = self.logit.weight 43 | if self.rnn_type == 'lstm': 44 | return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), 45 | weight.new_zeros(self.num_layers, bsz, self.rnn_size)) 46 | else: 47 | return weight.new_zeros(self.num_layers, bsz, self.rnn_size) 48 | 49 | def _forward(self, fc_feats, att_feats, seq, att_masks=None): 50 | batch_size = fc_feats.size(0) 51 | seq_per_img = seq.shape[0] // batch_size 52 | state = self.init_hidden(batch_size*seq_per_img) 53 | outputs = [] 54 | 55 | if seq_per_img > 1: 56 | fc_feats = utils.repeat_tensors(seq_per_img, fc_feats) 57 | 58 | for i in range(seq.size(1) + 1): 59 | if i == 0: 60 | xt = self.img_embed(fc_feats) 61 | else: 62 | if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample 63 | sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1) 64 | sample_mask = sample_prob < self.ss_prob 65 | if sample_mask.sum() == 0: 66 | it = seq[:, i-1].clone() 67 | else: 68 | sample_ind = sample_mask.nonzero().view(-1) 69 | it = seq[:, i-1].data.clone() 70 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 71 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 72 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) 73 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 74 | else: 75 | it = seq[:, i-1].clone() 76 | # break if all the sequences end 77 | if i >= 2 and seq[:, i-1].data.sum() == 0: 78 | break 79 | xt = self.embed(it) 80 | 81 | output, state = self.core(xt.unsqueeze(0), state) 82 | output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) 83 | outputs.append(output) 84 | 85 | return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() 86 | 87 | def get_logprobs_state(self, it, state): 88 | # 'it' contains a word index 89 | xt = self.embed(it) 90 | 91 | output, state = self.core(xt.unsqueeze(0), state) 92 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) 93 | 94 | return logprobs, state 95 | 96 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): 97 | beam_size = opt.get('beam_size', 10) 98 | batch_size = 
fc_feats.size(0) 99 | 100 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' 101 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 102 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 103 | # lets process every image independently for now, for simplicity 104 | 105 | self.done_beams = [[] for _ in range(batch_size)] 106 | for k in range(batch_size): 107 | state = self.init_hidden(beam_size) 108 | for t in range(2): 109 | if t == 0: 110 | xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size) 111 | elif t == 1: # input 112 | it = fc_feats.data.new(beam_size).long().zero_() 113 | xt = self.embed(it) 114 | 115 | output, state = self.core(xt.unsqueeze(0), state) 116 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) 117 | 118 | self.done_beams[k] = self.beam_search(state, logprobs, opt=opt) 119 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 120 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 121 | # return the samples and their log likelihoods 122 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 123 | 124 | def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): 125 | sample_method = opt.get('sample_method', 'greedy') 126 | beam_size = opt.get('beam_size', 1) 127 | temperature = opt.get('temperature', 1.0) 128 | if beam_size > 1 and sample_method in ['greedy', 'beam_search']: 129 | return self._sample_beam(fc_feats, att_feats, opt) 130 | 131 | batch_size = fc_feats.size(0) 132 | state = self.init_hidden(batch_size) 133 | seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long) 134 | seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length) 135 | for t in range(self.seq_length + 2): 136 | if t == 0: 137 | xt = self.img_embed(fc_feats) 138 | else: 139 | if t == 1: # input 140 | it = fc_feats.data.new(batch_size).long().zero_() 141 | xt = self.embed(it) 142 | 143 | output, state = self.core(xt.unsqueeze(0), state) 144 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) 145 | 146 | # sample the next word 147 | if t == self.seq_length + 1: # skip if we achieve maximum length 148 | break 149 | if sample_method == 'greedy': 150 | sampleLogprobs, it = torch.max(logprobs.data, 1) 151 | it = it.view(-1).long() 152 | else: 153 | if temperature == 1.0: 154 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 155 | else: 156 | # scale logprobs by temperature 157 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 158 | it = torch.multinomial(prob_prev, 1).to(logprobs.device) 159 | sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions 160 | it = it.view(-1).long() # and flatten indices for downstream processing 161 | 162 | if t >= 1: 163 | # stop when all finished 164 | if t == 1: 165 | unfinished = it > 0 166 | else: 167 | unfinished = unfinished & (it > 0) 168 | it = it * unfinished.type_as(it) 169 | seq[:,t-1] = it #seq[t] the input of t+2 time step 170 | seqLogprobs[:,t-1] = sampleLogprobs.view(-1) 171 | if unfinished.sum() == 0: 172 | break 173 | 174 | return seq, seqLogprobs -------------------------------------------------------------------------------- /captioning/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from 
__future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import copy 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from .ShowTellModel import ShowTellModel 12 | from .FCModel import FCModel 13 | from .AttModel import * 14 | from .TransformerModel import TransformerModel 15 | from .cachedTransformer import TransformerModel as cachedTransformer 16 | from .BertCapModel import BertCapModel 17 | from .M2Transformer import M2TransformerModel 18 | from .AoAModel import AoAModel 19 | 20 | def setup(opt): 21 | if opt.caption_model in ['fc', 'show_tell']: 22 | print('Warning: %s model is mostly deprecated; many new features are not supported.' %opt.caption_model) 23 | if opt.caption_model == 'fc': 24 | print('Use newfc instead of fc') 25 | if opt.caption_model == 'fc': 26 | model = FCModel(opt) 27 | elif opt.caption_model == 'language_model': 28 | model = LMModel(opt) 29 | elif opt.caption_model == 'newfc': 30 | model = NewFCModel(opt) 31 | elif opt.caption_model == 'show_tell': 32 | model = ShowTellModel(opt) 33 | # Att2in model in self-critical 34 | elif opt.caption_model == 'att2in': 35 | model = Att2inModel(opt) 36 | # Att2in model with two-layer MLP img embedding and word embedding 37 | elif opt.caption_model == 'att2in2': 38 | model = Att2in2Model(opt) 39 | elif opt.caption_model == 'att2all2': 40 | print('Warning: this is not a correct implementation of the att2all model in the original paper.') 41 | model = Att2all2Model(opt) 42 | # Adaptive Attention model from Knowing when to look 43 | elif opt.caption_model == 'adaatt': 44 | model = AdaAttModel(opt) 45 | # Adaptive Attention with maxout lstm 46 | elif opt.caption_model == 'adaattmo': 47 | model = AdaAttMOModel(opt) 48 | # Top-down attention model 49 | elif opt.caption_model in ['topdown', 'updown']: 50 | model = UpDownModel(opt) 51 | # StackAtt 52 | elif opt.caption_model == 'stackatt': 53 | model = StackAttModel(opt) 54 | # DenseAtt 55 | elif opt.caption_model == 'denseatt': 56 | model = DenseAttModel(opt) 57 | # Transformer 58 | elif opt.caption_model == 'transformer': 59 | if getattr(opt, 'cached_transformer', False): 60 | model = cachedTransformer(opt) 61 | else: 62 | model = TransformerModel(opt) 63 | # AoANet 64 | elif opt.caption_model == 'aoa': 65 | model = AoAModel(opt) 66 | elif opt.caption_model == 'bert': 67 | model = BertCapModel(opt) 68 | elif opt.caption_model == 'm2transformer': 69 | model = M2TransformerModel(opt) 70 | else: 71 | raise Exception("Caption model not supported: {}".format(opt.caption_model)) 72 | 73 | return model 74 | -------------------------------------------------------------------------------- /captioning/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def repeat_tensors(n, x): 4 | """ 5 | For a tensor of size Bx..., we repeat it n times, and make it Bnx... 6 | For collections, do nested repeat 7 | """ 8 | if torch.is_tensor(x): 9 | x = x.unsqueeze(1) # Bx1x... 10 | x = x.expand(-1, n, *([-1]*len(x.shape[2:]))) # Bxnx... 11 | x = x.reshape(x.shape[0]*n, *x.shape[2:]) # Bnx... 
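# e.g. repeat_tensors(3, x) turns a tensor whose rows are [a, b] into [a, a, a, b, b, b]:
# each row is repeated consecutively, which is how _forward in FCModel/ShowTellModel expands
# fc_feats when seq_per_img > 1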
12 | elif type(x) is list or type(x) is tuple: 13 | x = [repeat_tensors(n, _) for _ in x] 14 | return x 15 | 16 | 17 | def split_tensors(n, x): 18 | if torch.is_tensor(x): 19 | assert x.shape[0] % n == 0 20 | x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1) 21 | elif type(x) is list or type(x) is tuple: 22 | x = [split_tensors(n, _) for _ in x] 23 | elif x is None: 24 | x = [None] * n 25 | return x -------------------------------------------------------------------------------- /captioning/modules/loss_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . import losses 3 | from ..utils.rewards import init_scorer, get_self_critical_reward 4 | 5 | class LossWrapper(torch.nn.Module): 6 | def __init__(self, model, opt): 7 | super(LossWrapper, self).__init__() 8 | self.opt = opt 9 | self.model = model 10 | if opt.label_smoothing > 0: 11 | self.crit = losses.LabelSmoothing(smoothing=opt.label_smoothing) 12 | else: 13 | self.crit = losses.LanguageModelCriterion() 14 | self.rl_crit = losses.RewardCriterion() 15 | self.struc_crit = losses.StructureLosses(opt) 16 | self.ppo_crit = losses.PPOLoss(opt, model) 17 | 18 | def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices, 19 | sc_flag, struc_flag, drop_worst_flag): 20 | opt = self.opt 21 | 22 | out = {} 23 | 24 | reduction = 'none' if drop_worst_flag else 'mean' 25 | if struc_flag: 26 | if opt.structure_loss_weight < 1: 27 | lm_loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:], reduction=reduction) 28 | else: 29 | lm_loss = torch.tensor(0).type_as(fc_feats) 30 | if opt.structure_loss_weight > 0: 31 | gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, 32 | opt={'sample_method':opt.train_sample_method, 33 | 'beam_size':opt.train_beam_size, 34 | 'output_logsoftmax': opt.struc_use_logsoftmax or opt.structure_loss_type == 'softmax_margin'\ 35 | or not 'margin' in opt.structure_loss_type, 36 | 'sample_n': opt.train_sample_n}, 37 | mode='sample') 38 | gts = [gts[_] for _ in gt_indices.tolist()] 39 | if opt.use_ppo: 40 | struc_loss = self.ppo_crit(sample_logprobs, gen_result, gts, fc_feats, att_feats, att_masks, reduction=reduction) 41 | else: 42 | struc_loss = self.struc_crit(sample_logprobs, gen_result, gts, reduction=reduction) 43 | else: 44 | struc_loss = {'loss': torch.tensor(0).type_as(fc_feats), 45 | 'reward': torch.tensor(0).type_as(fc_feats)} 46 | loss = (1-opt.structure_loss_weight) * lm_loss + opt.structure_loss_weight * struc_loss['loss'] 47 | out['lm_loss'] = lm_loss 48 | out['struc_loss'] = struc_loss['loss'] 49 | out['reward'] = struc_loss['reward'] 50 | if opt.use_ppo: 51 | out['pg_loss'] = struc_loss['pg_loss'] 52 | out['kl_loss'] = struc_loss['kl_loss'] 53 | out['clipfrac'] = struc_loss['clipfrac'] 54 | elif not sc_flag: 55 | loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:], reduction=reduction) 56 | else: 57 | self.model.eval() 58 | with torch.no_grad(): 59 | greedy_res, _ = self.model(fc_feats, att_feats, att_masks, 60 | mode='sample', 61 | opt={'sample_method': opt.sc_sample_method, 62 | 'beam_size': opt.sc_beam_size}) 63 | self.model.train() 64 | gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, 65 | opt={'sample_method':opt.train_sample_method, 66 | 'beam_size':opt.train_beam_size, 67 | 'sample_n': opt.train_sample_n}, 68 | mode='sample') 69 | gts = [gts[_] for _ in 
gt_indices.tolist()] 70 | reward = get_self_critical_reward(greedy_res, gts, gen_result, self.opt) 71 | reward = torch.from_numpy(reward).to(sample_logprobs) 72 | loss = self.rl_crit(sample_logprobs, gen_result.data, reward, reduction=reduction) 73 | out['reward'] = reward[:,0].mean() 74 | out['loss'] = loss 75 | return out 76 | -------------------------------------------------------------------------------- /captioning/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruotianluo/ImageCaptioning.pytorch/4c48a3304932d58c5349434e7b0085f48dcb4be4/captioning/utils/__init__.py -------------------------------------------------------------------------------- /captioning/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # Copy from fvcore 3 | 4 | import logging 5 | import os 6 | from typing import Any 7 | import yaml 8 | from yacs.config import CfgNode as _CfgNode 9 | 10 | import io as PathManager 11 | 12 | BASE_KEY = "_BASE_" 13 | 14 | 15 | class CfgNode(_CfgNode): 16 | """ 17 | Our own extended version of :class:`yacs.config.CfgNode`. 18 | It contains the following extra features: 19 | 20 | 1. The :meth:`merge_from_file` method supports the "_BASE_" key, 21 | which allows the new CfgNode to inherit all the attributes from the 22 | base configuration file. 23 | 2. Keys that start with "COMPUTED_" are treated as insertion-only 24 | "computed" attributes. They can be inserted regardless of whether 25 | the CfgNode is frozen or not. 26 | 3. With "allow_unsafe=True", it supports pyyaml tags that evaluate 27 | expressions in config. See examples in 28 | https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types 29 | Note that this may lead to arbitrary code execution: you must not 30 | load a config file from untrusted sources before manually inspecting 31 | the content of the file. 32 | """ 33 | 34 | @staticmethod 35 | def load_yaml_with_base(filename, allow_unsafe = False): 36 | """ 37 | Just like `yaml.load(open(filename))`, but inherit attributes from its 38 | `_BASE_`. 39 | 40 | Args: 41 | filename (str): the file name of the current config. Will be used to 42 | find the base config file. 43 | allow_unsafe (bool): whether to allow loading the config file with 44 | `yaml.unsafe_load`. 45 | 46 | Returns: 47 | (dict): the loaded yaml 48 | """ 49 | with PathManager.open(filename, "r") as f: 50 | try: 51 | cfg = yaml.safe_load(f) 52 | except yaml.constructor.ConstructorError: 53 | if not allow_unsafe: 54 | raise 55 | logger = logging.getLogger(__name__) 56 | logger.warning( 57 | "Loading config {} with yaml.unsafe_load. Your machine may " 58 | "be at risk if the file contains malicious content.".format( 59 | filename 60 | ) 61 | ) 62 | f.close() 63 | with open(filename, "r") as f: 64 | cfg = yaml.unsafe_load(f) 65 | 66 | def merge_a_into_b(a, b): 67 | # merge dict a into dict b. values in a will overwrite b. 
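# e.g. merging a = {'k': {'x': 1}} into b = {'k': {'x': 0, 'y': 2}} leaves b as {'k': {'x': 1, 'y': 2}}
# (the recursion only descends into keys that are dicts on both sides)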
68 | for k, v in a.items(): 69 | if isinstance(v, dict) and k in b: 70 | assert isinstance( 71 | b[k], dict 72 | ), "Cannot inherit key '{}' from base!".format(k) 73 | merge_a_into_b(v, b[k]) 74 | else: 75 | b[k] = v 76 | 77 | if BASE_KEY in cfg: 78 | base_cfg_file = cfg[BASE_KEY] 79 | if base_cfg_file.startswith("~"): 80 | base_cfg_file = os.path.expanduser(base_cfg_file) 81 | if not any( 82 | map(base_cfg_file.startswith, ["/", "https://", "http://"]) 83 | ): 84 | # the path to base cfg is relative to the config file itself. 85 | base_cfg_file = os.path.join( 86 | os.path.dirname(filename), base_cfg_file 87 | ) 88 | base_cfg = CfgNode.load_yaml_with_base( 89 | base_cfg_file, allow_unsafe=allow_unsafe 90 | ) 91 | del cfg[BASE_KEY] 92 | 93 | merge_a_into_b(cfg, base_cfg) 94 | return base_cfg 95 | return cfg 96 | 97 | def merge_from_file(self, cfg_filename, allow_unsafe = False): 98 | """ 99 | Merge configs from a given yaml file. 100 | 101 | Args: 102 | cfg_filename: the file name of the yaml config. 103 | allow_unsafe: whether to allow loading the config file with 104 | `yaml.unsafe_load`. 105 | """ 106 | loaded_cfg = CfgNode.load_yaml_with_base( 107 | cfg_filename, allow_unsafe=allow_unsafe 108 | ) 109 | loaded_cfg = type(self)(loaded_cfg) 110 | self.merge_from_other_cfg(loaded_cfg) 111 | 112 | # Forward the following calls to base, but with a check on the BASE_KEY. 113 | def merge_from_other_cfg(self, cfg_other): 114 | """ 115 | Args: 116 | cfg_other (CfgNode): configs to merge from. 117 | """ 118 | assert ( 119 | BASE_KEY not in cfg_other 120 | ), "The reserved key '{}' can only be used in files!".format(BASE_KEY) 121 | return super().merge_from_other_cfg(cfg_other) 122 | 123 | def merge_from_list(self, cfg_list): 124 | """ 125 | Args: 126 | cfg_list (list): list of configs to merge from. 127 | """ 128 | keys = set(cfg_list[0::2]) 129 | assert ( 130 | BASE_KEY not in keys 131 | ), "The reserved key '{}' can only be used in files!".format(BASE_KEY) 132 | return super().merge_from_list(cfg_list) 133 | 134 | def __setattr__(self, name, val): 135 | if name.startswith("COMPUTED_"): 136 | if name in self: 137 | old_val = self[name] 138 | if old_val == val: 139 | return 140 | raise KeyError( 141 | "Computed attributed '{}' already exists " 142 | "with a different value! old={}, new={}.".format( 143 | name, old_val, val 144 | ) 145 | ) 146 | self[name] = val 147 | else: 148 | super().__setattr__(name, val) 149 | 150 | 151 | if __name__ == '__main__': 152 | cfg = CfgNode.load_yaml_with_base('configs/updown_long.yml') 153 | print(cfg) -------------------------------------------------------------------------------- /captioning/utils/div_utils.py: -------------------------------------------------------------------------------- 1 | from random import uniform 2 | import numpy as np 3 | from collections import OrderedDict, defaultdict 4 | from itertools import tee 5 | import time 6 | 7 | # ----------------------------------------------- 8 | def find_ngrams(input_list, n): 9 | return zip(*[input_list[i:] for i in range(n)]) 10 | 11 | def compute_div_n(caps,n=1): 12 | aggr_div = [] 13 | for k in caps: 14 | all_ngrams = set() 15 | lenT = 0. 16 | for c in caps[k]: 17 | tkns = c.split() 18 | lenT += len(tkns) 19 | ng = find_ngrams(tkns, n) 20 | all_ngrams.update(ng) 21 | aggr_div.append(float(len(all_ngrams))/ (1e-6 + float(lenT))) 22 | return np.array(aggr_div).mean(), np.array(aggr_div) 23 | 24 | def compute_global_div_n(caps,n=1): 25 | aggr_div = [] 26 | all_ngrams = set() 27 | lenT = 0. 
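# global distinct-n over the whole caption set: count the distinct n-grams across all captions of
# all images; for n == 1 the raw count is reported, otherwise the count is normalized by the total
# number of tokens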
28 | for k in caps: 29 | for c in caps[k]: 30 | tkns = c.split() 31 | lenT += len(tkns) 32 | ng = find_ngrams(tkns, n) 33 | all_ngrams.update(ng) 34 | if n == 1: 35 | aggr_div.append(float(len(all_ngrams))) 36 | else: 37 | aggr_div.append(float(len(all_ngrams))/ (1e-6 + float(lenT))) 38 | return aggr_div[0], np.repeat(np.array(aggr_div),len(caps)) -------------------------------------------------------------------------------- /captioning/utils/eval_multi.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | import numpy as np 9 | import json 10 | from json import encoder 11 | import random 12 | import string 13 | import time 14 | import os 15 | import sys 16 | from . import misc as utils 17 | from eval_utils import getCOCO 18 | 19 | from .div_utils import compute_div_n, compute_global_div_n 20 | 21 | import sys 22 | try: 23 | sys.path.append("coco-caption") 24 | annFile = 'coco-caption/annotations/captions_val2014.json' 25 | from pycocotools.coco import COCO 26 | from pycocoevalcap.eval import COCOEvalCap 27 | from pycocoevalcap.eval_spice import COCOEvalCapSpice 28 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer 29 | from pycocoevalcap.bleu.bleu import Bleu 30 | sys.path.append("cider") 31 | from pyciderevalcap.cider.cider import Cider 32 | except: 33 | print('Warning: requirements for eval_multi not satisfied') 34 | 35 | 36 | def eval_allspice(dataset, preds_n, model_id, split): 37 | coco = getCOCO(dataset) 38 | valids = coco.getImgIds() 39 | 40 | capsById = {} 41 | for d in preds_n: 42 | capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d] 43 | 44 | # filter results to only those in MSCOCO validation set (will be about a third) 45 | preds_filt_n = [p for p in preds_n if p['image_id'] in valids] 46 | print('using %d/%d predictions_n' % (len(preds_filt_n), len(preds_n))) 47 | cache_path_n = os.path.join('eval_results/', model_id + '_' + split + '_n.json') 48 | json.dump(preds_filt_n, open(cache_path_n, 'w')) # serialize to temporary json file. Sigh, COCO API... 
49 | 50 | # Eval AllSPICE 51 | cocoRes_n = coco.loadRes(cache_path_n) 52 | cocoEvalAllSPICE = COCOEvalCapSpice(coco, cocoRes_n) 53 | cocoEvalAllSPICE.params['image_id'] = cocoRes_n.getImgIds() 54 | cocoEvalAllSPICE.evaluate() 55 | 56 | out = {} 57 | for metric, score in cocoEvalAllSPICE.eval.items(): 58 | out['All'+metric] = score 59 | 60 | imgToEvalAllSPICE = cocoEvalAllSPICE.imgToEval 61 | # collect SPICE_sub_score 62 | for k in list(imgToEvalAllSPICE.values())[0]['SPICE'].keys(): 63 | if k != 'All': 64 | out['AllSPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEvalAllSPICE.values()]) 65 | out['AllSPICE_'+k] = (out['AllSPICE_'+k][out['AllSPICE_'+k]==out['AllSPICE_'+k]]).mean() 66 | for p in preds_filt_n: 67 | image_id, caption = p['image_id'], p['caption'] 68 | imgToEvalAllSPICE[image_id]['caption'] = capsById[image_id] 69 | return {'overall': out, 'imgToEvalAllSPICE': imgToEvalAllSPICE} 70 | 71 | def eval_oracle(dataset, preds_n, model_id, split): 72 | cache_path = os.path.join('eval_results/', model_id + '_' + split + '_n.json') 73 | 74 | coco = getCOCO(dataset) 75 | valids = coco.getImgIds() 76 | 77 | capsById = {} 78 | for d in preds_n: 79 | capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d] 80 | 81 | sample_n = capsById[list(capsById.keys())[0]] 82 | for i in range(len(capsById[list(capsById.keys())[0]])): 83 | preds = [_[i] for _ in capsById.values()] 84 | 85 | json.dump(preds, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API... 86 | 87 | cocoRes = coco.loadRes(cache_path) 88 | cocoEval = COCOEvalCap(coco, cocoRes) 89 | cocoEval.params['image_id'] = cocoRes.getImgIds() 90 | cocoEval.evaluate() 91 | 92 | imgToEval = cocoEval.imgToEval 93 | for img_id in capsById.keys(): 94 | tmp = imgToEval[img_id] 95 | for k in tmp['SPICE'].keys(): 96 | if k != 'All': 97 | tmp['SPICE_'+k] = tmp['SPICE'][k]['f'] 98 | if tmp['SPICE_'+k] != tmp['SPICE_'+k]: # nan 99 | tmp['SPICE_'+k] = -100 100 | tmp['SPICE'] = tmp['SPICE']['All']['f'] 101 | if tmp['SPICE'] != tmp['SPICE']: tmp['SPICE'] = -100 102 | capsById[img_id][i]['scores'] = imgToEval[img_id] 103 | 104 | out = {'overall': {}, 'ImgToEval': {}} 105 | for img_id in capsById.keys(): 106 | out['ImgToEval'][img_id] = {} 107 | for metric in capsById[img_id][0]['scores'].keys(): 108 | if metric == 'image_id': continue 109 | out['ImgToEval'][img_id]['oracle_'+metric] = max([_['scores'][metric] for _ in capsById[img_id]]) 110 | out['ImgToEval'][img_id]['avg_'+metric] = sum([_['scores'][metric] for _ in capsById[img_id]]) / len(capsById[img_id]) 111 | out['ImgToEval'][img_id]['captions'] = capsById[img_id] 112 | for metric in list(out['ImgToEval'].values())[0].keys(): 113 | if metric == 'captions': 114 | continue 115 | tmp = np.array([_[metric] for _ in out['ImgToEval'].values()]) 116 | tmp = tmp[tmp!=-100] 117 | out['overall'][metric] = tmp.mean() 118 | 119 | return out 120 | 121 | def eval_div_stats(dataset, preds_n, model_id, split): 122 | tokenizer = PTBTokenizer() 123 | 124 | capsById = {} 125 | for i, d in enumerate(preds_n): 126 | d['id'] = i 127 | capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d] 128 | 129 | n_caps_perimg = len(capsById[list(capsById.keys())[0]]) 130 | print(n_caps_perimg) 131 | _capsById = capsById # save the untokenized version 132 | capsById = tokenizer.tokenize(capsById) 133 | 134 | div_1, adiv_1 = compute_div_n(capsById,1) 135 | div_2, adiv_2 = compute_div_n(capsById,2) 136 | 137 | globdiv_1, _= compute_global_div_n(capsById,1) 138 | 139 | print('Diversity Statistics 
are as follows: \n Div1: %.2f, Div2: %.2f, gDiv1: %d\n'%(div_1,div_2, globdiv_1)) 140 | 141 | # compute mbleu 142 | scorer = Bleu(4) 143 | all_scrs = [] 144 | scrperimg = np.zeros((n_caps_perimg, len(capsById))) 145 | 146 | for i in range(n_caps_perimg): 147 | tempRefsById = {} 148 | candsById = {} 149 | for k in capsById: 150 | tempRefsById[k] = capsById[k][:i] + capsById[k][i+1:] 151 | candsById[k] = [capsById[k][i]] 152 | 153 | score, scores = scorer.compute_score(tempRefsById, candsById) 154 | all_scrs.append(score) 155 | scrperimg[i,:] = scores[1] 156 | 157 | all_scrs = np.array(all_scrs) 158 | 159 | out = {} 160 | out['overall'] = {'Div1': div_1, 'Div2': div_2, 'gDiv1': globdiv_1} 161 | for k, score in zip(range(4), all_scrs.mean(axis=0).tolist()): 162 | out['overall'].update({'mBLeu_%d'%(k+1): score}) 163 | imgToEval = {} 164 | for i,imgid in enumerate(capsById.keys()): 165 | imgToEval[imgid] = {'mBleu_2' : scrperimg[:,i].mean()} 166 | imgToEval[imgid]['individuals'] = [] 167 | for j, d in enumerate(_capsById[imgid]): 168 | imgToEval[imgid]['individuals'].append(preds_n[d['id']]) 169 | imgToEval[imgid]['individuals'][-1]['mBleu_2'] = scrperimg[j,i] 170 | out['ImgToEval'] = imgToEval 171 | 172 | print('Mean mutual Bleu scores on this set is:\nmBLeu_1, mBLeu_2, mBLeu_3, mBLeu_4') 173 | print(all_scrs.mean(axis=0)) 174 | 175 | return out 176 | 177 | def eval_self_cider(dataset, preds_n, model_id, split): 178 | cache_path = os.path.join('eval_results/', model_id + '_' + split + '_n.json') 179 | 180 | coco = getCOCO(dataset) 181 | valids = coco.getImgIds() 182 | 183 | # Get Cider_scorer 184 | Cider_scorer = Cider(df='corpus') 185 | 186 | tokenizer = PTBTokenizer() 187 | gts = {} 188 | for imgId in valids: 189 | gts[imgId] = coco.imgToAnns[imgId] 190 | gts = tokenizer.tokenize(gts) 191 | 192 | for imgId in valids: 193 | Cider_scorer.cider_scorer += (None, gts[imgId]) 194 | Cider_scorer.cider_scorer.compute_doc_freq() 195 | Cider_scorer.cider_scorer.ref_len = np.log(float(len(Cider_scorer.cider_scorer.crefs))) 196 | 197 | # Prepare captions 198 | capsById = {} 199 | for d in preds_n: 200 | capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d] 201 | 202 | capsById = tokenizer.tokenize(capsById) 203 | imgIds = list(capsById.keys()) 204 | scores = Cider_scorer.my_self_cider([capsById[_] for _ in imgIds]) 205 | 206 | def get_div(eigvals): 207 | eigvals = np.clip(eigvals, 0, None) 208 | return -np.log(np.sqrt(eigvals[-1]) / (np.sqrt(eigvals).sum())) / np.log(len(eigvals)) 209 | sc_scores = [get_div(np.linalg.eigvalsh(_/10)) for _ in scores] 210 | score = np.mean(np.array(sc_scores)) 211 | 212 | imgToEval = {} 213 | for i, image_id in enumerate(imgIds): 214 | imgToEval[image_id] = {'self_cider': sc_scores[i], 'self_cider_mat': scores[i].tolist()} 215 | return {'overall': {'self_cider': score}, 'imgToEval': imgToEval} 216 | 217 | 218 | return score 219 | -------------------------------------------------------------------------------- /captioning/utils/misc.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import collections 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | import torch.optim as optim 10 | import os 11 | 12 | import torch.nn.functional as F 13 | 14 | import six 15 | from six.moves import cPickle 16 | 17 | bad_endings = 
['with','in','on','of','a','at','to','for','an','this','his','her','that'] 18 | bad_endings += ['the'] 19 | 20 | 21 | def pickle_load(f): 22 | """ Load a pickle. 23 | Parameters 24 | ---------- 25 | f: file-like object 26 | """ 27 | if six.PY3: 28 | return cPickle.load(f, encoding='latin-1') 29 | else: 30 | return cPickle.load(f) 31 | 32 | 33 | def pickle_dump(obj, f): 34 | """ Dump a pickle. 35 | Parameters 36 | ---------- 37 | obj: pickled object 38 | f: file-like object 39 | """ 40 | if six.PY3: 41 | return cPickle.dump(obj, f, protocol=2) 42 | else: 43 | return cPickle.dump(obj, f) 44 | 45 | 46 | # modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/comm.py 47 | def serialize_to_tensor(data): 48 | device = torch.device("cpu") 49 | 50 | buffer = cPickle.dumps(data) 51 | storage = torch.ByteStorage.from_buffer(buffer) 52 | tensor = torch.ByteTensor(storage).to(device=device) 53 | return tensor 54 | 55 | 56 | def deserialize(tensor): 57 | buffer = tensor.cpu().numpy().tobytes() 58 | return cPickle.loads(buffer) 59 | 60 | 61 | # Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token. 62 | def decode_sequence(ix_to_word, seq): 63 | N, D = seq.size() 64 | out = [] 65 | for i in range(N): 66 | txt = '' 67 | for j in range(D): 68 | ix = seq[i,j] 69 | if ix > 0 : 70 | if j >= 1: 71 | txt = txt + ' ' 72 | txt = txt + ix_to_word[str(ix.item())] 73 | else: 74 | break 75 | if int(os.getenv('REMOVE_BAD_ENDINGS', '0')): 76 | flag = 0 77 | words = txt.split(' ') 78 | for j in range(len(words)): 79 | if words[-j-1] not in bad_endings: 80 | flag = -j 81 | break 82 | txt = ' '.join(words[0:len(words)+flag]) 83 | out.append(txt.replace('@@ ', '')) 84 | return out 85 | 86 | 87 | def save_checkpoint(opt, model, infos, optimizer, histories=None, append=''): 88 | if len(append) > 0: 89 | append = '-' + append 90 | # if checkpoint_path doesn't exist 91 | if not os.path.isdir(opt.checkpoint_path): 92 | os.makedirs(opt.checkpoint_path) 93 | checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append)) 94 | torch.save(model.state_dict(), checkpoint_path) 95 | print("model saved to {}".format(checkpoint_path)) 96 | optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append)) 97 | torch.save(optimizer.state_dict(), optimizer_path) 98 | with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'%s.pkl' %(append)), 'wb') as f: 99 | pickle_dump(infos, f) 100 | if histories: 101 | with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'%s.pkl' %(append)), 'wb') as f: 102 | pickle_dump(histories, f) 103 | 104 | 105 | def set_lr(optimizer, lr): 106 | for group in optimizer.param_groups: 107 | group['lr'] = lr 108 | 109 | def get_lr(optimizer): 110 | for group in optimizer.param_groups: 111 | return group['lr'] 112 | 113 | 114 | def build_optimizer(params, opt): 115 | if opt.optim == 'rmsprop': 116 | return optim.RMSprop(params, opt.learning_rate, opt.optim_alpha, opt.optim_epsilon, weight_decay=opt.weight_decay) 117 | elif opt.optim == 'adagrad': 118 | return optim.Adagrad(params, opt.learning_rate, weight_decay=opt.weight_decay) 119 | elif opt.optim == 'sgd': 120 | return optim.SGD(params, opt.learning_rate, weight_decay=opt.weight_decay) 121 | elif opt.optim == 'sgdm': 122 | return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay) 123 | elif opt.optim == 'sgdmom': 124 | return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay, nesterov=True) 125 | 
elif opt.optim == 'adam': 126 | return optim.Adam(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay) 127 | elif opt.optim == 'adamw': 128 | return optim.AdamW(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay) 129 | else: 130 | raise Exception("bad option opt.optim: {}".format(opt.optim)) 131 | 132 | 133 | def penalty_builder(penalty_config): 134 | if penalty_config == '': 135 | return lambda x,y: y 136 | pen_type, alpha = penalty_config.split('_') 137 | alpha = float(alpha) 138 | if pen_type == 'wu': 139 | return lambda x,y: length_wu(x,y,alpha) 140 | if pen_type == 'avg': 141 | return lambda x,y: length_average(x,y,alpha) 142 | 143 | def length_wu(length, logprobs, alpha=0.): 144 | """ 145 | NMT length re-ranking score from 146 | "Google's Neural Machine Translation System" :cite:`wu2016google`. 147 | """ 148 | 149 | modifier = (((5 + length) ** alpha) / 150 | ((5 + 1) ** alpha)) 151 | return (logprobs / modifier) 152 | 153 | def length_average(length, logprobs, alpha=0.): 154 | """ 155 | Returns the average probability of tokens in a sequence. 156 | """ 157 | return logprobs / length 158 | 159 | 160 | class NoamOpt(torch.optim.Optimizer): 161 | "Optim wrapper that implements rate." 162 | def __init__(self, model_size, factor, warmup, optimizer): 163 | self.optimizer = optimizer 164 | self._step = 0 165 | self.warmup = warmup 166 | self.factor = factor 167 | self.model_size = model_size 168 | self._rate = 0 169 | 170 | def step(self, *args, **kwargs): 171 | "Update parameters and rate" 172 | self._step += 1 173 | rate = self.rate() 174 | for p in self.optimizer.param_groups: 175 | p['lr'] = rate 176 | self._rate = rate 177 | self.optimizer.step(*args, **kwargs) 178 | 179 | def rate(self, step = None): 180 | "Implement `lrate` above" 181 | if step is None: 182 | step = self._step 183 | return self.factor * \ 184 | (self.model_size ** (-0.5) * 185 | min(step ** (-0.5), step * self.warmup ** (-1.5))) 186 | 187 | def __getattr__(self, name): 188 | return getattr(self.optimizer, name) 189 | 190 | def state_dict(self): 191 | state_dict = self.optimizer.state_dict() 192 | state_dict['_step'] = self._step 193 | return state_dict 194 | 195 | def load_state_dict(self, state_dict): 196 | if '_step' in state_dict: 197 | self._step = state_dict['_step'] 198 | del state_dict['_step'] 199 | self.optimizer.load_state_dict(state_dict) 200 | 201 | class ReduceLROnPlateau(torch.optim.Optimizer): 202 | "Optim wrapper that implements rate." 
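# Unlike NoamOpt above (which sets lr = factor * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)
# on every step), this wrapper simply bundles an optimizer with torch's ReduceLROnPlateau scheduler:
# step() updates the parameters, and scheduler_step(val) lowers the learning rate when the tracked
# validation metric stops improving.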
203 | 204 | def __init__(self, 205 | optimizer, 206 | mode='min', 207 | factor=0.1, 208 | patience=10, 209 | threshold=0.0001, 210 | threshold_mode='rel', 211 | cooldown=0, 212 | min_lr=0, 213 | eps=1e-08, 214 | verbose=False): 215 | self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( 216 | optimizer, mode, factor, patience, threshold, threshold_mode, 217 | cooldown, min_lr, eps, verbose) 218 | self.optimizer = optimizer 219 | self.current_lr = get_lr(optimizer) 220 | 221 | def step(self, *args, **kwargs): 222 | "Update parameters and rate" 223 | self.optimizer.step(*args, **kwargs) 224 | 225 | def scheduler_step(self, val): 226 | self.scheduler.step(val) 227 | self.current_lr = get_lr(self.optimizer) 228 | 229 | def state_dict(self): 230 | return {'current_lr':self.current_lr, 231 | 'scheduler_state_dict': self.scheduler.state_dict(), 232 | 'optimizer_state_dict': self.optimizer.state_dict()} 233 | 234 | def load_state_dict(self, state_dict): 235 | if 'current_lr' not in state_dict: 236 | # it's a plain optimizer state_dict 237 | self.optimizer.load_state_dict(state_dict) 238 | set_lr(self.optimizer, self.current_lr) # use the lr from the option 239 | else: 240 | # it's a scheduler state_dict 241 | self.current_lr = state_dict['current_lr'] 242 | self.scheduler.load_state_dict(state_dict['scheduler_state_dict']) 243 | self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) 244 | # current_lr is actually useless in this case 245 | 246 | def rate(self, step = None): 247 | "Implement `lrate` above" # note: relies on NoamOpt-style attributes (_step, factor, model_size) that this wrapper does not define 248 | if step is None: 249 | step = self._step 250 | return self.factor * \ 251 | (self.model_size ** (-0.5) * 252 | min(step ** (-0.5), step * self.warmup ** (-1.5))) 253 | 254 | def __getattr__(self, name): 255 | return getattr(self.optimizer, name) 256 | 257 | def get_std_opt(model, optim_func='adam', factor=1, warmup=2000): 258 | # return NoamOpt(model.tgt_embed[0].d_model, 2, 4000, 259 | # torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 260 | optim_func = dict(adam=torch.optim.Adam, 261 | adamw=torch.optim.AdamW)[optim_func] 262 | return NoamOpt(model.d_model, factor, warmup, 263 | optim_func(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) -------------------------------------------------------------------------------- /captioning/utils/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models.resnet 4 | from torchvision.models.resnet import BasicBlock, Bottleneck 5 | import torch.utils.model_zoo as model_zoo # needed by the pretrained branches below; they also expect `model_urls` (a dict of weight URLs, available as torchvision.models.resnet.model_urls in older torchvision releases) 6 | class ResNet(torchvision.models.resnet.ResNet): 7 | def __init__(self, block, layers, num_classes=1000): 8 | super(ResNet, self).__init__(block, layers, num_classes) 9 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # changed vs. torchvision default (padding 1 -> 0, ceil_mode True) 10 | for i in range(2, 5): 11 | getattr(self, 'layer%d'%i)[0].conv1.stride = (2,2) 12 | getattr(self, 'layer%d'%i)[0].conv2.stride = (1,1) 13 | 14 | def resnet18(pretrained=False): 15 | """Constructs a ResNet-18 model. 16 | 17 | Args: 18 | pretrained (bool): If True, returns a model pre-trained on ImageNet 19 | """ 20 | model = ResNet(BasicBlock, [2, 2, 2, 2]) 21 | if pretrained: 22 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 23 | return model 24 | 25 | 26 | def resnet34(pretrained=False): 27 | """Constructs a ResNet-34 model. 
28 | 29 | Args: 30 | pretrained (bool): If True, returns a model pre-trained on ImageNet 31 | """ 32 | model = ResNet(BasicBlock, [3, 4, 6, 3]) 33 | if pretrained: 34 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 35 | return model 36 | 37 | 38 | def resnet50(pretrained=False): 39 | """Constructs a ResNet-50 model. 40 | 41 | Args: 42 | pretrained (bool): If True, returns a model pre-trained on ImageNet 43 | """ 44 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 45 | if pretrained: 46 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 47 | return model 48 | 49 | 50 | def resnet101(pretrained=False): 51 | """Constructs a ResNet-101 model. 52 | 53 | Args: 54 | pretrained (bool): If True, returns a model pre-trained on ImageNet 55 | """ 56 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 57 | if pretrained: 58 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 59 | return model 60 | 61 | 62 | def resnet152(pretrained=False): 63 | """Constructs a ResNet-152 model. 64 | 65 | Args: 66 | pretrained (bool): If True, returns a model pre-trained on ImageNet 67 | """ 68 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 69 | if pretrained: 70 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 71 | return model -------------------------------------------------------------------------------- /captioning/utils/resnet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class myResnet(nn.Module): 6 | def __init__(self, resnet): 7 | super(myResnet, self).__init__() 8 | self.resnet = resnet 9 | 10 | def forward(self, img, att_size=14): 11 | x = img.unsqueeze(0) 12 | 13 | x = self.resnet.conv1(x) 14 | x = self.resnet.bn1(x) 15 | x = self.resnet.relu(x) 16 | x = self.resnet.maxpool(x) 17 | 18 | x = self.resnet.layer1(x) 19 | x = self.resnet.layer2(x) 20 | x = self.resnet.layer3(x) 21 | x = self.resnet.layer4(x) 22 | 23 | fc = x.mean(3).mean(2).squeeze() 24 | att = F.adaptive_avg_pool2d(x,[att_size,att_size]).squeeze().permute(1, 2, 0) 25 | 26 | return fc, att 27 | 28 | -------------------------------------------------------------------------------- /captioning/utils/rewards.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import time 7 | from collections import OrderedDict 8 | import torch 9 | 10 | import sys 11 | try: 12 | sys.path.append("cider") 13 | from pyciderevalcap.ciderD.ciderD import CiderD 14 | from pyciderevalcap.cider.cider import Cider 15 | sys.path.append("coco-caption") 16 | from pycocoevalcap.bleu.bleu import Bleu 17 | except: 18 | print('cider or coco-caption missing') 19 | 20 | CiderD_scorer = None 21 | Cider_scorer = None 22 | Bleu_scorer = None 23 | #CiderD_scorer = CiderD(df='corpus') 24 | 25 | def init_scorer(cached_tokens): 26 | global CiderD_scorer 27 | CiderD_scorer = CiderD_scorer or CiderD(df=cached_tokens) 28 | global Cider_scorer 29 | Cider_scorer = Cider_scorer or Cider(df=cached_tokens) 30 | global Bleu_scorer 31 | Bleu_scorer = Bleu_scorer or Bleu(4) 32 | 33 | def array_to_str(arr): 34 | out = '' 35 | for i in range(len(arr)): 36 | out += str(arr[i]) + ' ' 37 | if arr[i] == 0: 38 | break 39 | return out.strip() 40 | 41 | def get_self_critical_reward(greedy_res, data_gts, gen_result, opt): 42 | batch_size = 
len(data_gts) 43 | gen_result_size = gen_result.shape[0] 44 | seq_per_img = gen_result_size // len(data_gts) # gen_result_size = batch_size * seq_per_img 45 | assert greedy_res.shape[0] == batch_size 46 | 47 | res = OrderedDict() 48 | gen_result = gen_result.data.cpu().numpy() 49 | greedy_res = greedy_res.data.cpu().numpy() 50 | for i in range(gen_result_size): 51 | res[i] = [array_to_str(gen_result[i])] 52 | for i in range(batch_size): 53 | res[gen_result_size + i] = [array_to_str(greedy_res[i])] 54 | 55 | gts = OrderedDict() 56 | for i in range(len(data_gts)): 57 | gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))] 58 | 59 | res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))] 60 | res__ = {i: res[i] for i in range(len(res_))} 61 | gts_ = {i: gts[i // seq_per_img] for i in range(gen_result_size)} 62 | gts_.update({i+gen_result_size: gts[i] for i in range(batch_size)}) 63 | if opt.cider_reward_weight > 0: 64 | _, cider_scores = CiderD_scorer.compute_score(gts_, res_) 65 | print('Cider scores:', _) 66 | else: 67 | cider_scores = 0 68 | if opt.bleu_reward_weight > 0: 69 | _, bleu_scores = Bleu_scorer.compute_score(gts_, res__) 70 | bleu_scores = np.array(bleu_scores[3]) 71 | print('Bleu scores:', _[3]) 72 | else: 73 | bleu_scores = 0 74 | scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores 75 | 76 | scores = scores[:gen_result_size].reshape(batch_size, seq_per_img) - scores[-batch_size:][:, np.newaxis] 77 | scores = scores.reshape(gen_result_size) 78 | 79 | rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1) 80 | 81 | return rewards 82 | 83 | def get_scores(data_gts, gen_result, opt): 84 | batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img 85 | seq_per_img = batch_size // len(data_gts) 86 | 87 | res = OrderedDict() 88 | 89 | gen_result = gen_result.data.cpu().numpy() 90 | for i in range(batch_size): 91 | res[i] = [array_to_str(gen_result[i])] 92 | 93 | gts = OrderedDict() 94 | for i in range(len(data_gts)): 95 | gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))] 96 | 97 | res_ = [{'image_id':i, 'caption': res[i]} for i in range(batch_size)] 98 | res__ = {i: res[i] for i in range(batch_size)} 99 | gts = {i: gts[i // seq_per_img] for i in range(batch_size)} 100 | if opt.cider_reward_weight > 0: 101 | _, cider_scores = CiderD_scorer.compute_score(gts, res_) 102 | print('Cider scores:', _) 103 | else: 104 | cider_scores = 0 105 | if opt.bleu_reward_weight > 0: 106 | _, bleu_scores = Bleu_scorer.compute_score(gts, res__) 107 | bleu_scores = np.array(bleu_scores[3]) 108 | print('Bleu scores:', _[3]) 109 | else: 110 | bleu_scores = 0 111 | 112 | scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores 113 | 114 | return scores 115 | 116 | def get_self_cider_scores(data_gts, gen_result, opt): 117 | batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img 118 | seq_per_img = batch_size // len(data_gts) 119 | 120 | res = [] 121 | 122 | gen_result = gen_result.data.cpu().numpy() 123 | for i in range(batch_size): 124 | res.append(array_to_str(gen_result[i])) 125 | 126 | scores = [] 127 | for i in range(len(data_gts)): 128 | tmp = Cider_scorer.my_self_cider([res[i*seq_per_img:(i+1)*seq_per_img]]) 129 | def get_div(eigvals): 130 | eigvals = np.clip(eigvals, 0, None) 131 | return -np.log(np.sqrt(eigvals[-1]) / (np.sqrt(eigvals).sum())) / np.log(len(eigvals)) 132 | scores.append(get_div(np.linalg.eigvalsh(tmp[0]/10))) 133 | 134 | 
scores = np.array(scores) 135 | 136 | return scores -------------------------------------------------------------------------------- /configs/a2i2.yml: -------------------------------------------------------------------------------- 1 | # base 2 | caption_model: att2in2 3 | input_json: data/cocotalk.json 4 | input_att_dir: data/cocobu_att 5 | input_label_h5: data/cocotalk_label.h5 6 | learning_rate: 0.0005 7 | learning_rate_decay_start: 0 8 | scheduled_sampling_start: 0 9 | # checkpoint_path: $ckpt_path 10 | # $start_from 11 | language_eval: 1 12 | save_checkpoint_every: 3000 13 | val_images_use: 5000 14 | 15 | train_sample_n: 5 16 | self_critical_after: 30 17 | batch_size: 10 18 | learning_rate_decay_start: 0 19 | max_epochs: 30 20 | -------------------------------------------------------------------------------- /configs/a2i2_nsc.yml: -------------------------------------------------------------------------------- 1 | _BASE_: a2i2.yml 2 | learning_rate: 0.00005 3 | learning_rate_decay_start: -1 4 | self_critical_after: -1 5 | structure_after: 30 6 | train_sample_n: 5 7 | structure_loss_weight: 1 8 | structure_loss_type: new_self_critical 9 | max_epochs: 50 10 | -------------------------------------------------------------------------------- /configs/a2i2_sc.yml: -------------------------------------------------------------------------------- 1 | _BASE_: a2i2.yml 2 | learning_rate: 0.00005 3 | learning_rate_decay_start: -1 4 | s: 50 -------------------------------------------------------------------------------- /configs/aoa.yml: -------------------------------------------------------------------------------- 1 | # id: aoanet 2 | caption_model: aoa 3 | 4 | # AOA config 5 | refine: 1 6 | refine_aoa: 1 7 | use_ff: 0 8 | decoder_type: AoA 9 | use_multi_head: 2 10 | num_heads: 8 11 | multi_head_scale: 1 12 | mean_feats: 1 13 | ctx_drop: 1 14 | dropout_aoa: 0.3 15 | 16 | label_smoothing: 0.2 17 | input_json: data/cocotalk.json 18 | input_label_h5: data/cocotalk_label.h5 19 | input_fc_dir: data/cocobu_fc 20 | input_att_dir: data/cocobu_att 21 | input_box_dir: data/cocobu_box 22 | 23 | seq_per_img: 5 24 | batch_size: 10 25 | beam_size: 1 26 | learning_rate: 0.0002 27 | num_layers: 2 28 | input_encoding_size: 1024 29 | rnn_size: 1024 30 | learning_rate_decay_start: 0 31 | scheduled_sampling_start: 0 32 | save_checkpoint_every: 6000 33 | language_eval: 1 34 | val_images_use: -1 35 | max_epochs: 25 36 | scheduled_sampling_increase_every: 5 37 | scheduled_sampling_max_prob: 0.5 38 | learning_rate_decay_every: 3 -------------------------------------------------------------------------------- /configs/aoa_nsc.yml: -------------------------------------------------------------------------------- 1 | _BASE_: aoa_sc.yml 2 | 3 | structure_after: -1 4 | structure_loss_weight: 1 5 | structure_loss_type: new_self_critical -------------------------------------------------------------------------------- /configs/aoa_sc.yml: -------------------------------------------------------------------------------- 1 | _BASE_: aoa.yml 2 | 3 | save_checkpoint_every: 3000 4 | learning_rate: 0.00002 5 | max_epochs: 40 6 | learning_rate_decay_start: -1 7 | scheduled_sampling_start: -1 8 | reduce_on_plateau: true 9 | 10 | 11 | train_sample_n: 5 12 | self_critical_after: 0 -------------------------------------------------------------------------------- /configs/fc.yml: -------------------------------------------------------------------------------- 1 | caption_model: newfc 2 | input_json: data/cocotalk.json 3 | input_fc_dir: 
data/cocotalk_fc 4 | input_att_dir: data/cocotalk_att 5 | input_label_h5: data/cocotalk_label.h5 6 | learning_rate: 0.0005 7 | learning_rate_decay_start: 0 8 | scheduled_sampling_start: 0 9 | # checkpoint_path: $ckpt_path 10 | # $start_from 11 | language_eval: 1 12 | save_checkpoint_every: 3000 13 | val_images_use: 5000 14 | 15 | batch_size: 10 16 | max_epochs: 30 -------------------------------------------------------------------------------- /configs/fc_nsc.yml: -------------------------------------------------------------------------------- 1 | _BASE_: fc.yml 2 | learning_rate: 0.00005 3 | learning_rate_decay_start: -1 4 | scheduled_sampling_start: -1 5 | 6 | language_eval: 1 7 | save_checkpoint_every: 3000 8 | val_images_use: 5000 9 | 10 | batch_size: 10 11 | max_epochs: 50 12 | cached_tokens: coco-train-idxs 13 | 14 | 15 | self_critical_after: -1 16 | structure_after: 30 17 | train_sample_n: 5 18 | structure_loss_weight: 1 19 | structure_loss_type: new_self_critical 20 | -------------------------------------------------------------------------------- /configs/fc_rl.yml: -------------------------------------------------------------------------------- 1 | _BASE_: fc.yml 2 | learning_rate: 0.00005 3 | learning_rate_decay_start: -1 4 | scheduled_sampling_start: -1 5 | 6 | language_eval: 1 7 | save_checkpoint_every: 3000 8 | val_images_use: 5000 9 | 10 | batch_size: 10 11 | max_epochs: 50 12 | self_critical_after: 30 13 | cached_tokens: coco-train-idxs 14 | 15 | train_sample_n: 5 -------------------------------------------------------------------------------- /configs/transformer/transformer.yml: -------------------------------------------------------------------------------- 1 | caption_model: transformer 2 | noamopt: true 3 | noamopt_warmup: 20000 4 | label_smoothing: 0.0 5 | input_json: data/cocotalk.json 6 | input_label_h5: data/cocotalk_label.h5 7 | input_att_dir: data/cocobu_att 8 | seq_per_img: 5 9 | batch_size: 10 10 | learning_rate: 0.0005 11 | 12 | # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: 13 | # N=num_layers 14 | # d_model=input_encoding_size 15 | # d_ff=rnn_size 16 | 17 | # will be ignored 18 | num_layers: 6 19 | input_encoding_size: 512 20 | rnn_size: 2048 21 | 22 | # Transformer config 23 | N_enc: 6 24 | N_dec: 6 25 | d_model: 512 26 | d_ff: 2048 27 | num_att_heads: 8 28 | dropout: 0.1 29 | 30 | 31 | learning_rate_decay_start: 0 32 | scheduled_sampling_start: -1 33 | save_checkpoint_every: 3000 34 | language_eval: 1 35 | val_images_use: 5000 36 | max_epochs: 15 37 | train_sample_n: 5 38 | 39 | REFORWARD: false -------------------------------------------------------------------------------- /configs/transformer/transformer_nsc.yml: -------------------------------------------------------------------------------- 1 | _BASE_: transformer.yml 2 | reduce_on_plateau: true 3 | noamopt: false 4 | learning_rate: 0.00001 5 | learning_rate_decay_start: -1 6 | 7 | self_critical_after: -1 8 | structure_after: 15 9 | train_sample_n: 5 10 | structure_loss_weight: 1 11 | structure_loss_type: new_self_critical 12 | 13 | max_epochs: 30 -------------------------------------------------------------------------------- /configs/transformer/transformer_nscl.yml: -------------------------------------------------------------------------------- 1 | _BASE_: transformer.yml 2 | reduce_on_plateau: false 3 | noamopt: false 4 | learning_rate: 0.000005 5 | learning_rate_decay_start: -1 6 | 7 | self_critical_after: -1 8 | structure_after: 15 9 
| train_sample_n: 5 10 | structure_loss_weight: 1 11 | structure_loss_type: new_self_critical 12 | 13 | max_epochs: 40 14 | -------------------------------------------------------------------------------- /configs/transformer/transformer_sc.yml: -------------------------------------------------------------------------------- 1 | _BASE_: transformer.yml 2 | reduce_on_plateau: true 3 | noamopt: false 4 | learning_rate: 0.00001 5 | learning_rate_decay_start: -1 6 | 7 | self_critical_after: 15 8 | max_epochs: 50 -------------------------------------------------------------------------------- /configs/transformer/transformer_scl.yml: -------------------------------------------------------------------------------- 1 | _BASE_: transformer.yml 2 | reduce_on_plateau: false 3 | noamopt: false 4 | learning_rate: 0.000005 5 | learning_rate_decay_start: -1 6 | 7 | self_critical_after: 15 8 | max_epochs: 40 9 | -------------------------------------------------------------------------------- /configs/transformer/transformer_step.yml: -------------------------------------------------------------------------------- 1 | _BASE_: transformer.yml 2 | # from https://arxiv.org/pdf/2003.08897.pdf 3 | noamopt: false 4 | use_warmup: true 5 | noamopt_warmup: 33000 6 | learning_rate: 0.0003 7 | learning_rate_decay_rate: 0.5 8 | learning_rate_decay_start: 3 -------------------------------------------------------------------------------- /configs/updown/ud_long_nsc.yml: -------------------------------------------------------------------------------- 1 | _BASE_: updown_long.yml 2 | self_critical_after: -1 3 | structure_after: 40 4 | train_sample_n: 5 5 | structure_loss_weight: 1 6 | structure_loss_type: new_self_critical 7 | -------------------------------------------------------------------------------- /configs/updown/ud_long_sc.yml: -------------------------------------------------------------------------------- 1 | _BASE_: updown_long.yml 2 | -------------------------------------------------------------------------------- /configs/updown/updown.yml: -------------------------------------------------------------------------------- 1 | # base 2 | caption_model: updown 3 | input_json: data/cocotalk.json 4 | input_att_dir: data/cocobu_att 5 | input_label_h5: data/cocotalk_label.h5 6 | learning_rate: 0.0005 7 | scheduled_sampling_start: 0 8 | # checkpoint_path: $ckpt_path 9 | # $start_from 10 | language_eval: 1 11 | save_checkpoint_every: 3000 12 | val_images_use: 5000 13 | rnn_size: 1000 14 | input_encoding_size: 1000 15 | att_feat_size: 2048 16 | att_hid_size: 512 17 | 18 | train_sample_n: 5 19 | self_critical_after: 30 20 | batch_size: 10 21 | learning_rate_decay_start: 0 22 | max_epochs: 30 23 | -------------------------------------------------------------------------------- /configs/updown/updown_long.yml: -------------------------------------------------------------------------------- 1 | _BASE_: updown.yml 2 | # This training schedule is provided by yangxuntu 3 | self_critical_after: 40 4 | train_sample_n: 5 5 | batch_size: 100 6 | learning_rate_decay_start: 0 7 | learning_rate_decay_every: 5 8 | max_epochs: 150 9 | -------------------------------------------------------------------------------- /configs/updown/updown_nsc.yml: -------------------------------------------------------------------------------- 1 | _BASE_: updown.yml 2 | learning_rate: 0.00005 3 | learning_rate_decay_start: -1 4 | self_critical_after: -1 5 | structure_after: 30 6 | train_sample_n: 5 7 | structure_loss_weight: 1 8 | structure_loss_type: 
new_self_critical 9 | max_epochs: 50 10 | -------------------------------------------------------------------------------- /configs/updown/updown_sc.yml: -------------------------------------------------------------------------------- 1 | _BASE_: updown.yml 2 | learning_rate: 0.00005 3 | learning_rate_decay_start: -1 4 | 5 | max_epochs: 50 -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Prepare data 2 | 3 | Note: every preprocessed file and all pre-extracted features can be found at [link](https://drive.google.com/open?id=1eCdz62FAVCGogOuNhy87Nmlo5_I0sH2J). 4 | 5 | ## COCO 6 | 7 | ### Download COCO captions and preprocess them 8 | 9 | Download the preprocessed COCO captions from [link](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip) on Karpathy's homepage. Extract `dataset_coco.json` from the zip file and copy it into `data/`. This file provides preprocessed captions and also the standard train-val-test splits. 10 | 11 | Then do: 12 | 13 | ```bash 14 | $ python scripts/prepro_labels.py --input_json data/dataset_coco.json --output_json data/cocotalk.json --output_h5 data/cocotalk 15 | ``` 16 | 17 | `prepro_labels.py` will map all words that occur <= 5 times to a special `UNK` token, and create a vocabulary for all the remaining words. The image information and vocabulary are dumped into `data/cocotalk.json`, and the discretized caption data are dumped into `data/cocotalk_label.h5`. 18 | 19 | ### Image features option 1: Resnet features 20 | 21 | #### Download the COCO dataset and pre-extract the image features (if you want to extract them yourself) 22 | 23 | Download the pretrained resnet models. The models can be downloaded from [here](https://drive.google.com/open?id=0B7fNdx_jAqhtbVYzOURMdDNHSGM), and should be placed in `data/imagenet_weights`. 24 | 25 | Download the coco images from [link](http://mscoco.org/dataset/#download). We need the 2014 training images and 2014 val. images. You should put the `train2014/` and `val2014/` folders in the same directory, denoted as `$IMAGE_ROOT`. 26 | 27 | Then: 28 | 29 | ``` 30 | $ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk --images_root $IMAGE_ROOT 31 | ``` 32 | 33 | 34 | `prepro_feats.py` extracts the resnet101 features (both the fc feature and the last conv feature) of each image. The features are saved in `data/cocotalk_fc` and `data/cocotalk_att`, and the resulting files are about 200GB. 35 | 36 | (Check the prepro scripts for more options, like other resnet models or other attention sizes.) 37 | 38 | **Warning**: the prepro script will fail with the default MSCOCO data because one of their images is corrupted. See [this issue](https://github.com/karpathy/neuraltalk2/issues/4) for the fix; it involves manually replacing one image in the dataset. 39 | 40 | #### Download pre-extracted features 41 | 42 | To skip the preprocessing, you can download and decompress `cocotalk_att.tar` and `cocotalk_fc.tar` from the link provided at the beginning. 43 | 44 | ### Image features option 2: Bottom-up features (current standard) 45 | 46 | #### Convert from peteanderson80's original file 47 | Download the pre-extracted features from [link](https://github.com/peteanderson80/bottom-up-attention). You can download either the adaptive one or the fixed one. 
48 | 49 | For example: 50 | ``` 51 | mkdir data/bu_data; cd data/bu_data 52 | wget https://imagecaption.blob.core.windows.net/imagecaption/trainval.zip 53 | unzip trainval.zip 54 | 55 | ``` 56 | 57 | Then: 58 | 59 | ```bash 60 | python scripts/make_bu_data.py --output_dir data/cocobu 61 | ``` 62 | 63 | This will create `data/cocobu_fc` (not strictly necessary), `data/cocobu_att` and `data/cocobu_box`. If you want to use the bottom-up features, just replace all `"cocotalk"` with `"cocobu"` in the training/test scripts. 64 | 65 | #### Download converted files 66 | 67 | bottomup-att: [link](https://drive.google.com/file/d/1hun0tsel34aXO4CYyTRIvHJkcbZHwjrD/view?usp=sharing) 68 | 69 | ### Image features option 3: Vilbert 12 in 1 features 70 | In vilbert-12-in-1, the image features used are similar to the original bottom-up features, but they are extracted with a ResNeXt-152 backbone. 71 | 72 | Here is the link to the converted lmdb (more compressed than the original one provided by jiasen): 73 | 74 | [https://drive.google.com/file/d/1Gjo9Xs7qrjah2TQs0-joEWi8HabCkuQp/view?usp=sharing](https://drive.google.com/file/d/1Gjo9Xs7qrjah2TQs0-joEWi8HabCkuQp/view?usp=sharing) 75 | 76 | ## Flickr30k 77 | 78 | The process is similar to COCO. 79 | 80 | ``` 81 | python scripts/prepro_labels.py --input_json data/dataset_flickr30k.json --output_json data/f30ktalk.json --output_h5 data/f30ktalk 82 | 83 | python scripts/prepro_ngrams.py --input_json data/dataset_flickr30k.json --dict_json data/f30ktalk.json --output_pkl data/f30k-train --split train 84 | ``` 85 | 86 | The following command generates the coco-like annotation file used for evaluation with coco-caption: 87 | 88 | ``` 89 | python scripts/prepro_reference_json.py --input_json data/dataset_flickr30k.json --output_json data/f30k_captions4eval.json 90 | ``` 91 | 92 | ### Feature extraction 93 | 94 | For resnet features, you can do the same thing as for COCO. 95 | 96 | For bottom-up features, you can download the data from [link](https://github.com/kuanghuei/SCAN) 97 | 98 | `wget https://scanproject.blob.core.windows.net/scan-data/data.zip` 99 | 100 | and then convert it to a pth file using the following script: 101 | 102 | ``` 103 | import numpy as np 104 | import os 105 | import torch 106 | from tqdm import tqdm 107 | 108 | out = {} 109 | def transform(id_file, feat_file): 110 | ids = open(id_file, 'r').readlines() 111 | ids = [_.strip('\n') for _ in ids] 112 | feats = np.load(feat_file) 113 | assert feats.shape[0] == len(ids) 114 | for _id, _feat in tqdm(zip(ids, feats)): 115 | out[str(_id)] = _feat 116 | 117 | transform('dev_ids.txt', 'dev_ims.npy') 118 | transform('train_ids.txt', 'train_ims.npy') 119 | transform('test_ids.txt', 'test_ims.npy') 120 | 121 | torch.save(out, 'f30kbu_att.pth') 122 | ``` -------------------------------------------------------------------------------- /projects/Diversity/README.md: -------------------------------------------------------------------------------- 1 | # Analysis of diversity-accuracy tradeoff in image captioning [[arxiv]](https://arxiv.org/abs/2002.11848) 2 | 3 | ## Abstract 4 | 5 | We investigate the effect of different model architectures, training objectives, hyperparameter settings and decoding procedures on the diversity of automatically generated image captions. 
Our results show that 1) simple decoding by naive sampling, coupled with low temperature, is a competitive and fast method to produce diverse and accurate caption sets; 2) training with CIDEr-based reward using reinforcement learning harms the diversity properties of the resulting generator, which cannot be mitigated by manipulating decoding parameters. In addition, we propose a new metric AllSPICE for evaluating both accuracy and diversity of a set of captions by a single value. 6 | 7 | ## AllSPICE 8 | 9 | Instructions for AllSPICE have been added to [ruotianluo/coco-caption](https://github.com/ruotianluo/coco-caption). Read the paper for more details on this metric. 10 | 11 | ## Reproduce the figures 12 | 13 | In [drive](https://drive.google.com/open?id=1TILv8GXM0dIcjWnrM5V2D7tJvmqupf49), we provide the original evaluation results. 14 | To get the figures in the paper (with the exact same scores), you can run `plot.ipynb`. 15 | 16 | ## Training and evaluation scripts 17 | 18 | ### Training 19 | To train the models used in the main paper, run `run_a2i2.sh` first, then run `run_a2i2_npg.sh 1` or `run_a2i2_sf_npg.sh 0` (RL) and `run_a2i2_npg.sh 0` (XE). To get XE+RL, run `run_a2i2_npg.sh x`, where x is the weighting factor. 20 | 21 | Do the same for a2i2l, fc and transf if you want to reproduce the results in the appendix. 22 | 23 | #### Pretrained models 24 | I also provide pretrained models ([link](https://drive.google.com/open?id=1HdEzL-3Bl-uwALlwonLBxyd1zuen2DCO)). Note that even with the same model, you are not guaranteed to get the same diversity scores because there is randomness in sampling. However, in my experience the numbers are usually close. 25 | 26 | ### Evaluation 27 | 28 | `only_gen_test_n_*.sh` generates the caption sets and `only_eval_test_n_*.sh` evaluates the results. 29 | 30 | `*` corresponds to different sampling methods. Check each script to see what arguments can be specified. Here is an example: 31 | 32 | In `only_gen_test_n_dbst.sh a2i2_npg_0 0.3 1 5`, `a2i2_npg_0` is the model id, `0.3` is the diversity lambda, `1` is the sampling temperature, and `5` is the sample size. 33 | 34 | ## Reference 35 | If you find this work helpful, please cite this paper: 36 | 37 | ``` 38 | @article{luo2020analysis, 39 | title={Analysis of diversity-accuracy tradeoff in image captioning}, 40 | author={Luo, Ruotian and Shakhnarovich, Gregory}, 41 | journal={arXiv preprint arXiv:2002.11848}, 42 | year={2020} 43 | } 44 | ``` -------------------------------------------------------------------------------- /projects/Diversity/scripts/eval_scripts/only_eval_test_n_bs.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | set -m 3 | id=$1 4 | 5 | python eval.py --batch_size 10 --image_root /share/data/vision-greg/coco/ --dump_images 0 --num_images -1 --split test --model log_$id/model-best.pth --language_eval 1 --only_lang_eval 1 --beam_size 5 --sample_n $3 --temperature $2 --sample_method greedy --sample_n_method bs --infos_path log_$id/infos_$id-best.pkl --id $id"_bs_"$2_$3 6 | 7 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/eval_scripts/only_eval_test_n_dbst.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/sh 2 | set -m 3 | id=$1 4 | 5 | python eval.py --batch_size 1 --diversity_lambda $2 --image_root /share/data/vision-greg/coco/ --dump_images 0 --num_images -1 --split test --model log_$id/model-best.pth --only_lang_eval 1 --language_eval 1 --beam_size 5 --sample_n $4 --temperature $3 --sample_method greedy --sample_n_method dbs --infos_path log_$id/infos_$id-best.pkl --id $id"_dbst_"$2_$3_$4 6 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/eval_scripts/only_eval_test_n_sp.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | id=$1 3 | 4 | python eval.py --image_root /share/data/vision-greg/coco/ --batch_size 100 --dump_images 0 --num_images -1 --split test --model log_$id/model-best.pth --only_lang_eval 1 --language_eval 1 --beam_size 5 --sample_n $3 --temperature $2 --sample_method greedy --sample_n_method sample --infos_path log_$id/infos_$id-best.pkl --id $4$id"_sp_"$2_$3 5 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/eval_scripts/only_eval_test_n_topk.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | set -m 3 | 4 | id=$1 5 | 6 | 7 | 8 | python eval.py --image_root /share/data/vision-greg/coco/ --batch_size 100 --dump_images 0 --num_images -1 --split test --model log_$id/model-best.pth --only_lang_eval 1 --language_eval 1 --beam_size 5 --sample_n $4 --temperature $2 --sample_method greedy --sample_n_method top$3 --infos_path log_$id/infos_$id-best.pkl --id $5$id"_tk_"$2_$3_$4 9 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/eval_scripts/only_eval_test_n_topp.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | set -m 3 | id=$1 4 | 5 | 6 | 7 | python eval.py --image_root /share/data/vision-greg/coco/ --batch_size 100 --dump_images 0 --num_images -1 --split test --model log_$id/model-best.pth --only_lang_eval 1 --language_eval 1 --beam_size 5 --sample_n $4 --temperature $2 --sample_method greedy --sample_n_method top$3 --infos_path log_$id/infos_$id-best.pkl --id $5$id"_tp_"$2_$3_$4 8 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/eval_scripts/only_gen_test_n_bs.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | set -m 3 | id=$1 4 | 5 | python eval.py --batch_size 10 --image_root /share/data/vision-greg/coco/ --dump_images 0 --num_images -1 --split test --model log_$id/model-best.pth --language_eval 0 --temperature $2 --beam_size 5 --sample_n $3 --sample_method greedy --sample_n_method bs --infos_path log_$id/infos_$id-best.pkl --id $id"_bs_"$2_$3 6 | 7 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/eval_scripts/only_gen_test_n_dbst.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/sh 2 | set -m 3 | id=$1 4 | 5 | python eval.py --batch_size 1 --diversity_lambda $2 --image_root /share/data/vision-greg/coco/ --dump_images 0 --num_images -1 --split test --model log_$id/model-best.pth --language_eval 0 --beam_size 5 --sample_n $4 --temperature $3 --sample_method greedy --sample_n_method dbs --infos_path log_$id/infos_$id-best.pkl --id $id"_dbst_"$2_$3_$4 6 | 7 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/eval_scripts/only_gen_test_n_sp.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | id=$1 3 | 4 | bs=$[500/$3] 5 | python eval.py --image_root /share/data/vision-greg/coco/ --batch_size $bs --dump_images 0 --num_images -1 --split test --model log_$id/model-best.pth --language_eval 0 --beam_size 5 --sample_n $3 --temperature $2 --sample_method greedy --sample_n_method sample --infos_path log_$id/infos_$id-best.pkl --id $4$id"_sp_"$2_$3 6 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/eval_scripts/only_gen_test_n_topk.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | set -m 3 | 4 | id=$1 5 | 6 | 7 | 8 | python eval.py --image_root /share/data/vision-greg/coco/ --batch_size 100 --dump_images 0 --num_images -1 --split test --model log_$id/model-best.pth --language_eval 0 --beam_size 5 --sample_n $4 --temperature $2 --sample_method greedy --sample_n_method top$3 --infos_path log_$id/infos_$id-best.pkl --id $5$id"_tk_"$2_$3_$4 9 | 10 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/eval_scripts/only_gen_test_n_topp.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | set -m 3 | id=$1 4 | 5 | 6 | 7 | python eval.py --image_root /share/data/vision-greg/coco/ --batch_size 100 --dump_images 0 --num_images -1 --split test --model log_$id/model-best.pth --language_eval 0 --beam_size 5 --sample_n $4 --temperature $2 --sample_method greedy --sample_n_method top$3 --infos_path log_$id/infos_$id-best.pkl --id $5$id"_tp_"$2_$3_$4 8 | 9 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/train_scripts/run_a2i2.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | id="a2i2" 4 | ckpt_path="log_"$id 5 | if [ ! -d $ckpt_path ]; then 6 | mkdir $ckpt_path 7 | fi 8 | if [ ! -f $ckpt_path"/infos_"$id".pkl" ]; then 9 | start_from="" 10 | else 11 | start_from="--start_from "$ckpt_path 12 | fi 13 | 14 | python train.py --id $id --caption_model att2in2 --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --seq_per_img 5 --batch_size 50 --beam_size 1 --learning_rate 5e-4 --learning_rate_decay_start 0 --scheduled_sampling_start 0 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --max_epochs 30 15 | 16 | #if [ ! 
-d xe/$ckpt_path ]; then 17 | #cp -r $ckpt_path xe/ 18 | #fi 19 | 20 | #python train.py --id $id --caption_model att2in2 --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --input_box_dir data/cocobu_box --seq_per_img 5 --batch_size 10 --beam_size 1 --learning_rate 5e-5 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 500 --language_eval 1 --val_images_use 5000 --self_critical_after 29 21 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/train_scripts/run_a2i2_npg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | id="a2i2_npg_"$1 4 | ckpt_path="log_"$id 5 | 6 | if [ ! -d $ckpt_path ]; then 7 | bash scripts/copy_model.sh a2i2 $id 8 | fi 9 | 10 | start_from="--start_from "$ckpt_path 11 | 12 | python train.py --id $id --caption_model att2in2 --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --seq_per_img 5 --batch_size 100 --beam_size 1 --learning_rate 4.3e-5 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --max_epochs 100 --structure_after 28 --structure_sample_n 1 --structure_loss_weight $1 --structure_loss_type new_policy_gradient --eval_oracle 0 --sample_n 5 --sample_n_method sample 13 | 14 | #pid=$! 15 | #echo $pid 16 | #sleep 30 17 | #kill -2 ${pid} 18 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/train_scripts/run_a2i2_pgg.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | id="a2i2_pgg_"$1 4 | ckpt_path="log_"$id 5 | 6 | if [ ! -d $ckpt_path ]; then 7 | bash scripts/copy_model.sh a2i2 $id 8 | fi 9 | 10 | start_from="--start_from "$ckpt_path 11 | 12 | python train.py --id $id --caption_model att2in2 --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --seq_per_img 5 --batch_size 10 --learning_rate 4.294967296000003e-05 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --max_epochs 50 --structure_after 28 --structure_sample_n 3 --structure_loss_weight $1 --structure_loss_type policy_gradient 13 | 14 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/train_scripts/run_a2i2_sf_npg.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | id="a2i2_sf_npg_"$1 4 | ckpt_path="log_"$id 5 | 6 | if [ ! 
-d $ckpt_path ]; then 7 | bash scripts/copy_model.sh a2i2 $id 8 | fi 9 | 10 | start_from="--start_from "$ckpt_path 11 | 12 | python train.py --id $id --caption_model att2in2 --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --seq_per_img 5 --batch_size 100 --beam_size 1 --learning_rate 4.3e-5 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --max_epochs 100 --structure_after 28 --structure_sample_n 1 --structure_loss_weight 1 --structure_loss_type new_policy_gradient --self_critical_reward_weight $1 --eval_oracle 0 --sample_n 5 --sample_n_method sample 13 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/train_scripts/run_a2i2l.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | id="a2i2l" 4 | ckpt_path="log_"$id 5 | if [ ! -d $ckpt_path ]; then 6 | mkdir $ckpt_path 7 | fi 8 | if [ ! -f $ckpt_path"/infos_"$id".pkl" ]; then 9 | start_from="" 10 | else 11 | start_from="--start_from "$ckpt_path 12 | fi 13 | 14 | python train.py --id $id --caption_model att2in2 --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --seq_per_img 5 --batch_size 50 --rnn_size 2048 --beam_size 1 --learning_rate 5e-4 --learning_rate_decay_start 0 --scheduled_sampling_start 0 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --max_epochs 30 15 | 16 | #if [ ! -d xe/$ckpt_path ]; then 17 | #cp -r $ckpt_path xe/ 18 | #fi 19 | 20 | #python train.py --id $id --caption_model att2in2 --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --input_box_dir data/cocobu_box --seq_per_img 5 --batch_size 10 --beam_size 1 --learning_rate 5e-5 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 500 --language_eval 1 --val_images_use 5000 --self_critical_after 29 21 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/train_scripts/run_a2i2l_npg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | id="a2i2l_npg_"$1 4 | ckpt_path="log_"$id 5 | 6 | if [ ! -d $ckpt_path ]; then 7 | bash scripts/copy_model.sh a2i2l $id 8 | fi 9 | 10 | start_from="--start_from "$ckpt_path 11 | 12 | python train.py --id $id --caption_model att2in2 --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --seq_per_img 5 --batch_size 50 --rnn_size 2048 --beam_size 1 --learning_rate 4.3e-5 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --max_epochs 100 --structure_after 28 --structure_sample_n 1 --structure_loss_weight $1 --structure_loss_type new_policy_gradient --eval_oracle 0 --sample_n 5 --sample_n_method sample 13 | 14 | #pid=$! 15 | #echo $pid 16 | #sleep 30 17 | #kill -2 ${pid} 18 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/train_scripts/run_a2i2l_sf_npg.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | id="a2i2l_sf_npg_"$1 4 | ckpt_path="log_"$id 5 | 6 | if [ ! 
-d $ckpt_path ]; then 7 | bash scripts/copy_model.sh a2i2l $id 8 | fi 9 | 10 | start_from="--start_from "$ckpt_path 11 | 12 | python train.py --id $id --caption_model att2in2 --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --seq_per_img 5 --batch_size 50 --rnn_size 2048 --beam_size 1 --learning_rate 4.3e-5 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --max_epochs 100 --structure_after 28 --structure_sample_n 1 --structure_loss_weight 1 --structure_loss_type new_policy_gradient --self_critical_reward_weight $1 --eval_oracle 0 --sample_n 5 --sample_n_method sample 13 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/train_scripts/run_fc.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | id="fc" 4 | ckpt_path="log_"$id 5 | if [ ! -d $ckpt_path ]; then 6 | mkdir $ckpt_path 7 | fi 8 | if [ ! -f $ckpt_path"/infos_"$id".pkl" ]; then 9 | start_from="" 10 | else 11 | start_from="--start_from "$ckpt_path 12 | fi 13 | 14 | python train.py --id $id --caption_model newfc --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --seq_per_img 5 --batch_size 50 --beam_size 1 --learning_rate 5e-4 --learning_rate_decay_start 0 --scheduled_sampling_start 0 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --max_epochs 30 15 | 16 | #if [ ! -d xe/$ckpt_path ]; then 17 | #cp -r $ckpt_path xe/ 18 | #fi 19 | 20 | #python train.py --id $id --caption_model att2in2 --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --input_box_dir data/cocobu_box --seq_per_img 5 --batch_size 10 --beam_size 1 --learning_rate 5e-5 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 500 --language_eval 1 --val_images_use 5000 --self_critical_after 29 21 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/train_scripts/run_fc_npg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | id="fc_npg_"$1 4 | ckpt_path="log_"$id 5 | 6 | if [ ! -d $ckpt_path ]; then 7 | bash scripts/copy_model.sh fc $id 8 | fi 9 | 10 | start_from="--start_from "$ckpt_path 11 | 12 | python train.py --id $id --caption_model newfc --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --seq_per_img 5 --batch_size 100 --beam_size 1 --learning_rate 4.3e-5 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --max_epochs 100 --structure_after 28 --structure_sample_n 1 --structure_loss_weight $1 --structure_loss_type new_policy_gradient --eval_oracle 0 --sample_n 5 --sample_n_method sample 13 | 14 | #pid=$! 15 | #echo $pid 16 | #sleep 30 17 | #kill -2 ${pid} 18 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/train_scripts/run_td.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | id="td" 4 | ckpt_path="log_"$id 5 | if [ ! -d $ckpt_path ]; then 6 | mkdir $ckpt_path 7 | fi 8 | if [ ! 
-f $ckpt_path"/infos_"$id".pkl" ]; then 9 | start_from="" 10 | else 11 | start_from="--start_from "$ckpt_path 12 | fi 13 | 14 | python train.py --id $id --caption_model topdown --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --seq_per_img 5 --batch_size 50 --beam_size 1 --learning_rate 5e-4 --rnn_size 1000 --input_encoding_size 1000 --learning_rate_decay_start 0 --scheduled_sampling_start 0 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --max_epochs 30 15 | 16 | #if [ ! -d xe/$ckpt_path ]; then 17 | #cp -r $ckpt_path xe/ 18 | #fi 19 | 20 | #python train.py --id $id --caption_model att2in2 --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --input_box_dir data/cocobu_box --seq_per_img 5 --batch_size 10 --beam_size 1 --learning_rate 5e-5 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 500 --language_eval 1 --val_images_use 5000 --self_critical_after 29 21 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/train_scripts/run_td_npg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | id="td_npg_"$1 4 | ckpt_path="log_"$id 5 | 6 | if [ ! -d $ckpt_path ]; then 7 | bash scripts/copy_model.sh td $id 8 | fi 9 | 10 | start_from="--start_from "$ckpt_path 11 | 12 | python train.py --id $id --caption_model topdown --rnn_size 1000 --input_encoding_size 1000 --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --seq_per_img 5 --batch_size 100 --beam_size 1 --learning_rate 4.3e-5 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --max_epochs 100 --structure_after 28 --structure_sample_n 1 --structure_loss_weight $1 --structure_loss_type new_self_critical --eval_oracle 0 --sample_n 5 --sample_n_method sample 13 | 14 | #pid=$! 15 | #echo $pid 16 | #sleep 30 17 | #kill -2 ${pid} 18 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/train_scripts/run_transf.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | # fix warmup 4 | 5 | id="transf" 6 | ckpt_path="log_"$id 7 | if [ ! -d $ckpt_path ]; then 8 | mkdir $ckpt_path 9 | fi 10 | if [ ! -f $ckpt_path"/infos_"$id".pkl" ]; then 11 | start_from="" 12 | else 13 | start_from="--start_from "$ckpt_path 14 | fi 15 | 16 | python train.py --id $id --caption_model transformer --noamopt --noamopt_warmup 10000 --label_smoothing 0.0 --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --seq_per_img 5 --batch_size 20 --beam_size 1 --learning_rate 5e-4 --num_layers 6 --input_encoding_size 512 --rnn_size 2048 --learning_rate_decay_start 0 --scheduled_sampling_start 0 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --max_epochs 15 17 | 18 | #if [ ! 
-d xe/$ckpt_path ]; then 19 | #cp -r $ckpt_path xe/ 20 | #fi 21 | 22 | #python train.py --id $id --caption_model transformer --reduce_on_plateau --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att.lmdb --input_box_dir data/cocobu_box --seq_per_img 5 --batch_size 10 --beam_size 1 --learning_rate 1e-5 --num_layers 6 --input_encoding_size 512 --rnn_size 2048 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --self_critical_after 10 23 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/train_scripts/run_transf_npg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | id="transf_npg_"$1 4 | ckpt_path="log_"$id 5 | 6 | if [ ! -d $ckpt_path ]; then 7 | bash scripts/copy_model.sh transf $id 8 | fi 9 | 10 | start_from="--start_from "$ckpt_path 11 | 12 | python train.py --id $id --caption_model transformer --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --seq_per_img 5 --batch_size 10 --beam_size 1 --learning_rate 1e-5 --num_layers 6 --input_encoding_size 512 --rnn_size 2048 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --max_epochs 30 --structure_after 14 --structure_sample_n 1 --structure_loss_weight $1 --structure_loss_type new_policy_gradient --eval_oracle 0 --sample_n 5 --sample_n_method sample 13 | 14 | #pid=$! 15 | #echo $pid 16 | #sleep 30 17 | #kill -2 ${pid} 18 | -------------------------------------------------------------------------------- /projects/Diversity/scripts/train_scripts/run_transf_sf_npg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | id="transf_sf_npg_"$1 4 | ckpt_path="log_"$id 5 | 6 | if [ ! -d $ckpt_path ]; then 7 | bash scripts/copy_model.sh transf $id 8 | fi 9 | 10 | start_from="--start_from "$ckpt_path 11 | 12 | python train.py --id $id --caption_model transformer --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --seq_per_img 5 --batch_size 10 --beam_size 1 --learning_rate 1e-5 --num_layers 6 --input_encoding_size 512 --rnn_size 2048 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --max_epochs 30 --structure_after 14 --structure_sample_n 1 --structure_loss_weight 1 --structure_loss_type new_policy_gradient --self_critical_reward_weight $1 --eval_oracle 0 --sample_n 5 --sample_n_method sample 13 | 14 | #pid=$! 15 | #echo $pid 16 | #sleep 30 17 | #kill -2 ${pid} 18 | -------------------------------------------------------------------------------- /projects/NewSelfCritical/README.md: -------------------------------------------------------------------------------- 1 | # A Better Variant of Self-Critical Sequence Training [[arxiv]](http://arxiv.org/abs/2003.09971) 2 | 3 | ## Abstract 4 | 5 | In this work, we present a simple yet better variant of Self-Critical Sequence Training. We make a simple change in the choice of baseline function in REINFORCE algorithm. The new baseline can bring better performance with no extra cost, compared to the greedy decoding baseline. 
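For concreteness, here is a minimal sketch of the baseline change described above (illustrative only; the function and variable names are made up and this is not the exact code used in this repository):

```python
import numpy as np

def new_self_critical_advantage(scores, sample_n):
    """scores: flat array with the reward of every sampled caption,
    stored as sample_n consecutive samples per image (sample_n >= 2)."""
    scores = scores.reshape(-1, sample_n)
    # baseline of each sample = mean reward of the *other* samples of the same image
    baseline = (scores.sum(1, keepdims=True) - scores) / (sample_n - 1)
    return (scores - baseline).reshape(-1)
```

No greedy decoding pass is needed, and the advantages of the samples belonging to one image sum to zero.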
6 | 7 | ## Intro 8 | 9 | This "new self critical" is borrowed from "Variational inference for monte carlo objectives". The only difference from the original self critical is the definition of the baseline. 10 | 11 | In the original self critical, the baseline is the score of the greedy decoding output. In new self critical, the baseline is the average score of the other samples (this requires the model to generate multiple samples for each image). 12 | 13 | To try "new self critical" on the updown model, you can run 14 | 15 | `python train.py --cfg configs/updown_nsc.yml` 16 | 17 | This yml file also gives you some hints about what to change in order to use new self critical. 18 | 19 | ## My 2 cents 20 | 21 | From my experience, this new self critical always works better than SCST. So don't hesitate to use it. 22 | 23 | The recent meshed-memory-transformer paper also uses such a baseline (their formulation is slightly different from mine, but mathematically they are equivalent). The difference is that they use beam search during training instead of sampling, following the Topdown Bottomup paper. However, based on my experiments on both their codebase and my codebase, sampling is better than beam search during training. 24 | 25 | (Also, by the way, if using beam search, the average reward is not a valid baseline anymore because it is dependent on the samples.) 26 | 27 | ## Reference 28 | If you find this work helpful, please cite this paper: 29 | 30 | ``` 31 | @article{luo2020better, 32 | title={A Better Variant of Self-Critical Sequence Training}, 33 | author={Luo, Ruotian}, 34 | journal={arXiv preprint arXiv:2003.09971}, 35 | year={2020} 36 | } 37 | ``` 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /scripts/build_bpe_subword_nmt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 15 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 16 | first and last indices (in range 1..M) of labels for each image 17 | /label_length stores the length of the sequence for each of the M sequences 18 | 19 | The json file has a dict that contains: 20 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 21 | - an 'images' field that is a list holding auxiliary information for each image, 22 | such as in particular the 'split' it was assigned to. 
23 | """ 24 | 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import os 30 | import json 31 | import argparse 32 | from random import shuffle, seed 33 | import string 34 | # non-standard dependencies: 35 | import h5py 36 | import numpy as np 37 | import torch 38 | import torchvision.models as models 39 | import skimage.io 40 | from PIL import Image 41 | 42 | import codecs 43 | import tempfile 44 | from subword_nmt import learn_bpe, apply_bpe 45 | 46 | # python scripts/build_bpe_subword_nmt.py --input_json data/dataset_coco.json --output_json data/cocotalkbpe.json --output_h5 data/cocotalkbpe 47 | 48 | def build_vocab(imgs, params): 49 | # count up the number of words 50 | captions = [] 51 | for img in imgs: 52 | for sent in img['sentences']: 53 | captions.append(' '.join(sent['tokens'])) 54 | captions='\n'.join(captions) 55 | all_captions = tempfile.NamedTemporaryFile(delete=False) 56 | all_captions.close() 57 | with open(all_captions.name, 'w') as txt_file: 58 | txt_file.write(captions) 59 | 60 | # 61 | codecs_output = tempfile.NamedTemporaryFile(delete=False) 62 | codecs_output.close() 63 | with codecs.open(codecs_output.name, 'w', encoding='UTF-8') as output: 64 | learn_bpe.learn_bpe(codecs.open(all_captions.name, encoding='UTF-8'), output, params['symbol_count']) 65 | 66 | with codecs.open(codecs_output.name, encoding='UTF-8') as codes: 67 | bpe = apply_bpe.BPE(codes) 68 | 69 | tmp = tempfile.NamedTemporaryFile(delete=False) 70 | tmp.close() 71 | 72 | tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8') 73 | 74 | for _, img in enumerate(imgs): 75 | img['final_captions'] = [] 76 | for sent in img['sentences']: 77 | txt = ' '.join(sent['tokens']) 78 | txt = bpe.segment(txt).strip() 79 | img['final_captions'].append(txt.split(' ')) 80 | tmpout.write(txt) 81 | tmpout.write('\n') 82 | if _ < 20: 83 | print(txt) 84 | 85 | tmpout.close() 86 | tmpin = codecs.open(tmp.name, encoding='UTF-8') 87 | 88 | vocab = learn_bpe.get_vocabulary(tmpin) 89 | vocab = sorted(vocab.keys(), key=lambda x: vocab[x], reverse=True) 90 | 91 | # Always insert UNK 92 | print('inserting the special UNK token') 93 | vocab.append('UNK') 94 | 95 | print('Vocab size:', len(vocab)) 96 | 97 | os.remove(all_captions.name) 98 | with open(codecs_output.name, 'r') as codes: 99 | bpe = codes.read() 100 | os.remove(codecs_output.name) 101 | os.remove(tmp.name) 102 | 103 | return vocab, bpe 104 | 105 | def encode_captions(imgs, params, wtoi): 106 | """ 107 | encode all captions into one large array, which will be 1-indexed. 108 | also produces label_start_ix and label_end_ix which store 1-indexed 109 | and inclusive (Lua-style) pointers to the first and last caption for 110 | each image in the dataset. 
111 | """ 112 | 113 | max_length = params['max_length'] 114 | N = len(imgs) 115 | M = sum(len(img['final_captions']) for img in imgs) # total number of captions 116 | 117 | label_arrays = [] 118 | label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed 119 | label_end_ix = np.zeros(N, dtype='uint32') 120 | label_length = np.zeros(M, dtype='uint32') 121 | caption_counter = 0 122 | counter = 1 123 | for i,img in enumerate(imgs): 124 | n = len(img['final_captions']) 125 | assert n > 0, 'error: some image has no captions' 126 | 127 | Li = np.zeros((n, max_length), dtype='uint32') 128 | for j,s in enumerate(img['final_captions']): 129 | label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence 130 | caption_counter += 1 131 | for k,w in enumerate(s): 132 | if k < max_length: 133 | Li[j,k] = wtoi[w] 134 | 135 | # note: word indices are 1-indexed, and captions are padded with zeros 136 | label_arrays.append(Li) 137 | label_start_ix[i] = counter 138 | label_end_ix[i] = counter + n - 1 139 | 140 | counter += n 141 | 142 | L = np.concatenate(label_arrays, axis=0) # put all the labels together 143 | assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' 144 | assert np.all(label_length > 0), 'error: some caption had no words?' 145 | 146 | print('encoded captions to array of size ', L.shape) 147 | return L, label_start_ix, label_end_ix, label_length 148 | 149 | def main(params): 150 | 151 | imgs = json.load(open(params['input_json'], 'r')) 152 | imgs = imgs['images'] 153 | 154 | seed(123) # make reproducible 155 | 156 | # create the vocab 157 | vocab, bpe = build_vocab(imgs, params) 158 | itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table 159 | wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table 160 | 161 | # encode captions in large arrays, ready to ship to hdf5 file 162 | L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi) 163 | 164 | # create output h5 file 165 | N = len(imgs) 166 | f_lb = h5py.File(params['output_h5']+'_label.h5', "w") 167 | f_lb.create_dataset("labels", dtype='uint32', data=L) 168 | f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix) 169 | f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix) 170 | f_lb.create_dataset("label_length", dtype='uint32', data=label_length) 171 | f_lb.close() 172 | 173 | # create output json file 174 | out = {} 175 | out['ix_to_word'] = itow # encode the (1-indexed) vocab 176 | out['images'] = [] 177 | out['bpe'] = bpe 178 | for i,img in enumerate(imgs): 179 | 180 | jimg = {} 181 | jimg['split'] = img['split'] 182 | if 'filename' in img: jimg['file_path'] = os.path.join(img['filepath'], img['filename']) # copy it over, might need 183 | if 'cocoid' in img: jimg['id'] = img['cocoid'] # copy over & mantain an id, if present (e.g. 
coco ids, useful) 184 | 185 | if params['images_root'] != '': 186 | with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img: 187 | jimg['width'], jimg['height'] = _img.size 188 | 189 | out['images'].append(jimg) 190 | 191 | json.dump(out, open(params['output_json'], 'w')) 192 | print('wrote ', params['output_json']) 193 | 194 | if __name__ == "__main__": 195 | 196 | parser = argparse.ArgumentParser() 197 | 198 | # input json 199 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 200 | parser.add_argument('--output_json', default='data.json', help='output json file') 201 | parser.add_argument('--output_h5', default='data', help='output h5 file') 202 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') 203 | 204 | # options 205 | parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') 206 | parser.add_argument('--symbol_count', default=10000, type=int, help='number of BPE symbols (merge operations) to learn with subword-nmt') 207 | 208 | args = parser.parse_args() 209 | params = vars(args) # convert to ordinary dict 210 | print('parsed input parameters:') 211 | print(json.dumps(params, indent = 2)) 212 | main(params) 213 | 214 | 215 | -------------------------------------------------------------------------------- /scripts/copy_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ ! -d log_$2 ]; then 4 | cp -r log_$1 log_$2 5 | cd log_$2 6 | mv infos_$1-best.pkl infos_$2-best.pkl 7 | mv infos_$1.pkl infos_$2.pkl 8 | cd ../ 9 | fi 10 | -------------------------------------------------------------------------------- /scripts/dump_to_h5df.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import h5py 3 | import os 4 | import numpy as np 5 | import json 6 | from tqdm import tqdm 7 | 8 | 9 | def main(params): 10 | 11 | imgs = json.load(open(params['input_json'], 'r')) 12 | imgs = imgs['images'] 13 | N = len(imgs) 14 | 15 | if params['fc_input_dir'] is not None: 16 | print('processing fc') 17 | with h5py.File(params['fc_output'], 'a') as file_fc: # explicit mode: recent h5py no longer defaults to append 18 | for i, img in enumerate(tqdm(imgs)): 19 | npy_fc_path = os.path.join( 20 | params['fc_input_dir'], 21 | str(img['cocoid']) + '.npy') 22 | 23 | d_set_fc = file_fc.create_dataset( 24 | str(img['cocoid']), data=np.load(npy_fc_path)) 25 | file_fc.close() 26 | 27 | if params['att_input_dir'] is not None: 28 | print('processing att') 29 | with h5py.File(params['att_output'], 'a') as file_att: # explicit mode, as above 30 | for i, img in enumerate(tqdm(imgs)): 31 | npy_att_path = os.path.join( 32 | params['att_input_dir'], 33 | str(img['cocoid']) + '.npz') 34 | 35 | d_set_att = file_att.create_dataset( 36 | str(img['cocoid']), 37 | data=np.load(npy_att_path)['feat']) 38 | file_att.close() 39 | 40 | 41 | if __name__ == "__main__": 42 | 43 | parser = argparse.ArgumentParser() 44 | 45 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 46 | parser.add_argument('--fc_output', default='data', help='output h5 filename for fc') 47 | parser.add_argument('--att_output', default='data', help='output h5 file for att') 48 | parser.add_argument('--fc_input_dir', default=None, help='input directory for numpy fc files') 49 | parser.add_argument('--att_input_dir', 
default=None, help='input directory for numpy att files') 50 | 51 | args = parser.parse_args() 52 | params = vars(args) # convert to ordinary dict 53 | print('parsed input parameters:') 54 | print(json.dumps(params, indent=2)) 55 | 56 | main(params) -------------------------------------------------------------------------------- /scripts/dump_to_lmdb.py: -------------------------------------------------------------------------------- 1 | # copy from https://github.com/Lyken17/Efficient-PyTorch/tools 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import os.path as osp 9 | import os, sys 10 | import os.path as osp 11 | from PIL import Image 12 | import six 13 | import string 14 | 15 | from lmdbdict import lmdbdict 16 | from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC 17 | import pickle 18 | import tqdm 19 | import numpy as np 20 | import argparse 21 | import json 22 | 23 | import torch 24 | import torch.utils.data as data 25 | from torch.utils.data import DataLoader 26 | 27 | import csv 28 | csv.field_size_limit(sys.maxsize) 29 | FIELDNAMES = ['image_id', 'status'] 30 | 31 | class FolderLMDB(data.Dataset): 32 | def __init__(self, db_path, fn_list=None): 33 | self.db_path = db_path 34 | self.lmdb = lmdbdict(db_path, unsafe=True) 35 | self.lmdb._key_dumps = DUMPS_FUNC['ascii'] 36 | self.lmdb._value_loads = LOADS_FUNC['identity'] 37 | if fn_list is not None: 38 | self.length = len(fn_list) 39 | self.keys = fn_list 40 | else: 41 | raise Error 42 | 43 | def __getitem__(self, index): 44 | byteflow = self.lmdb[self.keys[index]] 45 | 46 | # load image 47 | imgbuf = byteflow 48 | buf = six.BytesIO() 49 | buf.write(imgbuf) 50 | buf.seek(0) 51 | try: 52 | if args.extension == '.npz': 53 | feat = np.load(buf)['feat'] 54 | else: 55 | feat = np.load(buf) 56 | except Exception as e: 57 | print(self.keys[index], e) 58 | return None 59 | 60 | return feat 61 | 62 | def __len__(self): 63 | return self.length 64 | 65 | def __repr__(self): 66 | return self.__class__.__name__ + ' (' + self.db_path + ')' 67 | 68 | 69 | def make_dataset(dir, extension): 70 | images = [] 71 | dir = os.path.expanduser(dir) 72 | for root, _, fnames in sorted(os.walk(dir)): 73 | for fname in sorted(fnames): 74 | if has_file_allowed_extension(fname, [extension]): 75 | path = os.path.join(root, fname) 76 | images.append(path) 77 | 78 | return images 79 | 80 | 81 | def raw_reader(path): 82 | with open(path, 'rb') as f: 83 | bin_data = f.read() 84 | return bin_data 85 | 86 | 87 | def raw_npz_reader(path): 88 | with open(path, 'rb') as f: 89 | bin_data = f.read() 90 | try: 91 | npz_data = np.load(six.BytesIO(bin_data))['feat'] 92 | except Exception as e: 93 | print(path) 94 | npz_data = None 95 | print(e) 96 | return bin_data, npz_data 97 | 98 | 99 | def raw_npy_reader(path): 100 | with open(path, 'rb') as f: 101 | bin_data = f.read() 102 | try: 103 | npy_data = np.load(six.BytesIO(bin_data)) 104 | except Exception as e: 105 | print(path) 106 | npy_data = None 107 | print(e) 108 | return bin_data, npy_data 109 | 110 | 111 | class Folder(data.Dataset): 112 | 113 | def __init__(self, root, loader, extension, fn_list=None): 114 | super(Folder, self).__init__() 115 | self.root = root 116 | if fn_list: 117 | samples = [os.path.join(root, str(_)+extension) for _ in fn_list] 118 | else: 119 | samples = make_dataset(self.root, extension) 120 | 121 | self.loader = loader 122 | self.extension = extension 123 | self.samples = samples 124 | 125 | def 
__getitem__(self, index): 126 | """ 127 | Args: 128 | index (int): Index 129 | Returns: 130 | tuple: (sample, target) where target is class_index of the target class. 131 | """ 132 | path = self.samples[index] 133 | sample = self.loader(path) 134 | 135 | return (path.split('/')[-1].split('.')[0],) + sample 136 | 137 | def __len__(self): 138 | return len(self.samples) 139 | 140 | 141 | def folder2lmdb(dpath, fn_list, write_frequency=5000): 142 | directory = osp.expanduser(osp.join(dpath)) 143 | print("Loading dataset from %s" % directory) 144 | if args.extension == '.npz': 145 | dataset = Folder(directory, loader=raw_npz_reader, extension='.npz', 146 | fn_list=fn_list) 147 | else: 148 | dataset = Folder(directory, loader=raw_npy_reader, extension='.npy', 149 | fn_list=fn_list) 150 | data_loader = DataLoader(dataset, num_workers=16, collate_fn=lambda x: x) 151 | 152 | # lmdb_path = osp.join(dpath, "%s.lmdb" % (directory.split('/')[-1])) 153 | lmdb_path = osp.join("%s.lmdb" % (directory)) 154 | isdir = os.path.isdir(lmdb_path) 155 | 156 | print("Generate LMDB to %s" % lmdb_path) 157 | db = lmdbdict(lmdb_path, mode='w', key_method='ascii', value_method='identity') 158 | 159 | tsvfile = open(args.output_file, 'a') 160 | writer = csv.DictWriter(tsvfile, delimiter='\t', fieldnames=FIELDNAMES) 161 | names = [] 162 | all_keys = [] 163 | for idx, data in enumerate(tqdm.tqdm(data_loader)): 164 | # print(type(data), data) 165 | name, byte, npz = data[0] 166 | if npz is not None: 167 | db[name] = byte 168 | all_keys.append(name) 169 | names.append({'image_id': name, 'status': str(npz is not None)}) 170 | if idx % write_frequency == 0: 171 | print("[%d/%d]" % (idx, len(data_loader))) 172 | print('writing') 173 | db.flush() 174 | # write in tsv 175 | for name in names: 176 | writer.writerow(name) 177 | names = [] 178 | tsvfile.flush() 179 | print('writing finished') 180 | # write all keys 181 | # txn.put("keys".encode(), pickle.dumps(all_keys)) 182 | # # finish iterating through dataset 183 | # txn.commit() 184 | for name in names: 185 | writer.writerow(name) 186 | tsvfile.flush() 187 | tsvfile.close() 188 | 189 | print("Flushing database ...") 190 | db.flush() 191 | del db 192 | 193 | def parse_args(): 194 | """ 195 | Parse input arguments 196 | """ 197 | parser = argparse.ArgumentParser(description='Generate bbox output from a Fast R-CNN network') 198 | # parser.add_argument('--json) 199 | parser.add_argument('--input_json', default='./data/dataset_coco.json', type=str) 200 | parser.add_argument('--output_file', default='.dump_cache.tsv', type=str) 201 | parser.add_argument('--folder', default='./data/cocobu_att', type=str) 202 | parser.add_argument('--extension', default='.npz', type=str) 203 | 204 | args = parser.parse_args() 205 | return args 206 | 207 | if __name__ == "__main__": 208 | global args 209 | args = parse_args() 210 | 211 | args.output_file += args.folder.split('/')[-1] 212 | if args.folder.find('/') > 0: 213 | args.output_file = args.folder[:args.folder.rfind('/')+1]+args.output_file 214 | print(args.output_file) 215 | 216 | img_list = json.load(open(args.input_json, 'r'))['images'] 217 | fn_list = [str(_['cocoid']) for _ in img_list] 218 | found_ids = set() 219 | try: 220 | with open(args.output_file, 'r') as tsvfile: 221 | reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES) 222 | for item in reader: 223 | if item['status'] == 'True': 224 | found_ids.add(item['image_id']) 225 | except: 226 | pass 227 | fn_list = [_ for _ in fn_list if _ not in found_ids] 228 | 
folder2lmdb(args.folder, fn_list) 229 | 230 | # Test existing. 231 | found_ids = set() 232 | with open(args.output_file, 'r') as tsvfile: 233 | reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES) 234 | for item in reader: 235 | if item['status'] == 'True': 236 | found_ids.add(item['image_id']) 237 | 238 | folder_dataset = FolderLMDB(args.folder+'.lmdb', list(found_ids)) 239 | data_loader = DataLoader(folder_dataset, num_workers=16, collate_fn=lambda x: x) 240 | for data in tqdm.tqdm(data_loader): 241 | assert data[0] is not None -------------------------------------------------------------------------------- /scripts/make_bu_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import base64 7 | import numpy as np 8 | import csv 9 | import sys 10 | import zlib 11 | import time 12 | import mmap 13 | import argparse 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | # output_dir 18 | parser.add_argument('--downloaded_feats', default='data/bu_data', help='downloaded feature directory') 19 | parser.add_argument('--output_dir', default='data/cocobu', help='output feature files') 20 | 21 | args = parser.parse_args() 22 | 23 | csv.field_size_limit(sys.maxsize) 24 | 25 | 26 | FIELDNAMES = ['image_id', 'image_w','image_h','num_boxes', 'boxes', 'features'] 27 | infiles = ['trainval/karpathy_test_resnet101_faster_rcnn_genome.tsv', 28 | 'trainval/karpathy_val_resnet101_faster_rcnn_genome.tsv',\ 29 | 'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.0', \ 30 | 'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.1'] 31 | 32 | os.makedirs(args.output_dir+'_att') 33 | os.makedirs(args.output_dir+'_fc') 34 | os.makedirs(args.output_dir+'_box') 35 | 36 | for infile in infiles: 37 | print('Reading ' + infile) 38 | with open(os.path.join(args.downloaded_feats, infile), "r") as tsv_in_file: 39 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES) 40 | for item in reader: 41 | item['image_id'] = int(item['image_id']) 42 | item['num_boxes'] = int(item['num_boxes']) 43 | for field in ['boxes', 'features']: 44 | item[field] = np.frombuffer(base64.decodestring(item[field].encode('ascii')), 45 | dtype=np.float32).reshape((item['num_boxes'],-1)) 46 | np.savez_compressed(os.path.join(args.output_dir+'_att', str(item['image_id'])), feat=item['features']) 47 | np.save(os.path.join(args.output_dir+'_fc', str(item['image_id'])), item['features'].mean(0)) 48 | np.save(os.path.join(args.output_dir+'_box', str(item['image_id'])), item['boxes']) 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /scripts/prepro_feats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into features files for use in data_loader.py 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. 
', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: two folders of features 13 | """ 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import json 21 | import argparse 22 | from random import shuffle, seed 23 | import string 24 | # non-standard dependencies: 25 | import h5py 26 | from six.moves import cPickle 27 | import numpy as np 28 | import torch 29 | import torchvision.models as models 30 | import skimage.io 31 | 32 | from torchvision import transforms as trn 33 | preprocess = trn.Compose([ 34 | #trn.ToTensor(), 35 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 36 | ]) 37 | 38 | from captioning.utils.resnet_utils import myResnet 39 | import captioning.utils.resnet as resnet 40 | 41 | 42 | def main(params): 43 | net = getattr(resnet, params['model'])() 44 | net.load_state_dict(torch.load(os.path.join(params['model_root'],params['model']+'.pth'))) 45 | my_resnet = myResnet(net) 46 | my_resnet.cuda() 47 | my_resnet.eval() 48 | 49 | imgs = json.load(open(params['input_json'], 'r')) 50 | imgs = imgs['images'] 51 | N = len(imgs) 52 | 53 | seed(123) # make reproducible 54 | 55 | dir_fc = params['output_dir']+'_fc' 56 | dir_att = params['output_dir']+'_att' 57 | if not os.path.isdir(dir_fc): 58 | os.mkdir(dir_fc) 59 | if not os.path.isdir(dir_att): 60 | os.mkdir(dir_att) 61 | 62 | for i,img in enumerate(imgs): 63 | # load the image 64 | I = skimage.io.imread(os.path.join(params['images_root'], img['filepath'], img['filename'])) 65 | # handle grayscale input images 66 | if len(I.shape) == 2: 67 | I = I[:,:,np.newaxis] 68 | I = np.concatenate((I,I,I), axis=2) 69 | 70 | I = I.astype('float32')/255.0 71 | I = torch.from_numpy(I.transpose([2,0,1])).cuda() 72 | I = preprocess(I) 73 | with torch.no_grad(): 74 | tmp_fc, tmp_att = my_resnet(I, params['att_size']) 75 | # write to pkl 76 | np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) 77 | np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) 78 | 79 | if i % 1000 == 0: 80 | print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) 81 | print('wrote ', params['output_dir']) 82 | 83 | if __name__ == "__main__": 84 | 85 | parser = argparse.ArgumentParser() 86 | 87 | # input json 88 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 89 | parser.add_argument('--output_dir', default='data', help='output h5 file') 90 | 91 | # options 92 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') 93 | parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') 94 | parser.add_argument('--model', default='resnet101', type=str, help='resnet101, resnet152') 95 | parser.add_argument('--model_root', default='./data/imagenet_weights', type=str, help='model root') 96 | 97 | args = parser.parse_args() 98 | params = vars(args) # convert to ordinary dict 99 | print('parsed input parameters:') 100 | print(json.dumps(params, indent = 2)) 101 | main(params) 102 | 
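# Example invocation (a sketch; the dataset json path, the image root and the
# location of the ImageNet ResNet weights under --model_root are assumptions that
# depend on your local setup):
#
#   python scripts/prepro_feats.py --input_json data/dataset_coco.json \
#       --output_dir data/cocotalk --images_root $IMAGE_ROOT
#
# This writes one .npy fc feature per image into data/cocotalk_fc and one
# compressed .npz attention feature per image into data/cocotalk_att.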
-------------------------------------------------------------------------------- /scripts/prepro_labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.py 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 15 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 16 | first and last indices (in range 1..M) of labels for each image 17 | /label_length stores the length of the sequence for each of the M sequences 18 | 19 | The json file has a dict that contains: 20 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 21 | - an 'images' field that is a list holding auxiliary information for each image, 22 | such as in particular the 'split' it was assigned to. 
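Example invocation (a sketch; the json paths are assumptions and should point at
your copy of the Karpathy split file):

    python scripts/prepro_labels.py --input_json data/dataset_coco.json \
        --output_json data/cocotalk.json --output_h5 data/cocotalk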
23 | """ 24 | 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import os 30 | import json 31 | import argparse 32 | from random import shuffle, seed 33 | import string 34 | # non-standard dependencies: 35 | import h5py 36 | import numpy as np 37 | import torch 38 | import torchvision.models as models 39 | import skimage.io 40 | from PIL import Image 41 | 42 | 43 | def build_vocab(imgs, params): 44 | count_thr = params['word_count_threshold'] 45 | 46 | # count up the number of words 47 | counts = {} 48 | for img in imgs: 49 | for sent in img['sentences']: 50 | for w in sent['tokens']: 51 | counts[w] = counts.get(w, 0) + 1 52 | cw = sorted([(count,w) for w,count in counts.items()], reverse=True) 53 | print('top words and their counts:') 54 | print('\n'.join(map(str,cw[:20]))) 55 | 56 | # print some stats 57 | total_words = sum(counts.values()) 58 | print('total words:', total_words) 59 | bad_words = [w for w,n in counts.items() if n <= count_thr] 60 | vocab = [w for w,n in counts.items() if n > count_thr] 61 | bad_count = sum(counts[w] for w in bad_words) 62 | print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts))) 63 | print('number of words in vocab would be %d' % (len(vocab), )) 64 | print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words)) 65 | 66 | # lets look at the distribution of lengths as well 67 | sent_lengths = {} 68 | for img in imgs: 69 | for sent in img['sentences']: 70 | txt = sent['tokens'] 71 | nw = len(txt) 72 | sent_lengths[nw] = sent_lengths.get(nw, 0) + 1 73 | max_len = max(sent_lengths.keys()) 74 | print('max length sentence in raw data: ', max_len) 75 | print('sentence length distribution (count, number of words):') 76 | sum_len = sum(sent_lengths.values()) 77 | for i in range(max_len+1): 78 | print('%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len)) 79 | 80 | # lets now produce the final annotations 81 | if bad_count > 0: 82 | # additional special UNK token we will use below to map infrequent words to 83 | print('inserting the special UNK token') 84 | vocab.append('UNK') 85 | 86 | for img in imgs: 87 | img['final_captions'] = [] 88 | for sent in img['sentences']: 89 | txt = sent['tokens'] 90 | caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt] 91 | img['final_captions'].append(caption) 92 | 93 | return vocab 94 | 95 | 96 | def encode_captions(imgs, params, wtoi): 97 | """ 98 | encode all captions into one large array, which will be 1-indexed. 99 | also produces label_start_ix and label_end_ix which store 1-indexed 100 | and inclusive (Lua-style) pointers to the first and last caption for 101 | each image in the dataset. 
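    For example (hypothetical numbers): if image i contributes 5 captions stored at
    rows 11..15 of the label array, then label_start_ix[i] == 11 and
    label_end_ix[i] == 15, while label_length holds the (clipped) token count of
    each of those 5 rows.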
102 | """ 103 | 104 | max_length = params['max_length'] 105 | N = len(imgs) 106 | M = sum(len(img['final_captions']) for img in imgs) # total number of captions 107 | 108 | label_arrays = [] 109 | label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed 110 | label_end_ix = np.zeros(N, dtype='uint32') 111 | label_length = np.zeros(M, dtype='uint32') 112 | caption_counter = 0 113 | counter = 1 114 | for i,img in enumerate(imgs): 115 | n = len(img['final_captions']) 116 | assert n > 0, 'error: some image has no captions' 117 | 118 | Li = np.zeros((n, max_length), dtype='uint32') 119 | for j,s in enumerate(img['final_captions']): 120 | label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence 121 | caption_counter += 1 122 | for k,w in enumerate(s): 123 | if k < max_length: 124 | Li[j,k] = wtoi[w] 125 | 126 | # note: word indices are 1-indexed, and captions are padded with zeros 127 | label_arrays.append(Li) 128 | label_start_ix[i] = counter 129 | label_end_ix[i] = counter + n - 1 130 | 131 | counter += n 132 | 133 | L = np.concatenate(label_arrays, axis=0) # put all the labels together 134 | assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' 135 | assert np.all(label_length > 0), 'error: some caption had no words?' 136 | 137 | print('encoded captions to array of size ', L.shape) 138 | return L, label_start_ix, label_end_ix, label_length 139 | 140 | 141 | def main(params): 142 | 143 | imgs = json.load(open(params['input_json'], 'r')) 144 | imgs = imgs['images'] 145 | 146 | seed(123) # make reproducible 147 | 148 | # create the vocab 149 | vocab = build_vocab(imgs, params) 150 | itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table 151 | wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table 152 | 153 | # encode captions in large arrays, ready to ship to hdf5 file 154 | L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi) 155 | 156 | # create output h5 file 157 | N = len(imgs) 158 | f_lb = h5py.File(params['output_h5']+'_label.h5', "w") 159 | f_lb.create_dataset("labels", dtype='uint32', data=L) 160 | f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix) 161 | f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix) 162 | f_lb.create_dataset("label_length", dtype='uint32', data=label_length) 163 | f_lb.close() 164 | 165 | # create output json file 166 | out = {} 167 | out['ix_to_word'] = itow # encode the (1-indexed) vocab 168 | out['images'] = [] 169 | for i,img in enumerate(imgs): 170 | 171 | jimg = {} 172 | jimg['split'] = img['split'] 173 | if 'filename' in img: jimg['file_path'] = os.path.join(img.get('filepath', ''), img['filename']) # copy it over, might need 174 | if 'cocoid' in img: 175 | jimg['id'] = img['cocoid'] # copy over & mantain an id, if present (e.g. 
coco ids, useful) 176 | elif 'imgid' in img: 177 | jimg['id'] = img['imgid'] 178 | 179 | if params['images_root'] != '': 180 | with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img: 181 | jimg['width'], jimg['height'] = _img.size 182 | 183 | out['images'].append(jimg) 184 | 185 | json.dump(out, open(params['output_json'], 'w')) 186 | print('wrote ', params['output_json']) 187 | 188 | if __name__ == "__main__": 189 | 190 | parser = argparse.ArgumentParser() 191 | 192 | # input json 193 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 194 | parser.add_argument('--output_json', default='data.json', help='output json file') 195 | parser.add_argument('--output_h5', default='data', help='output h5 file') 196 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') 197 | 198 | # options 199 | parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') 200 | parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab') 201 | 202 | args = parser.parse_args() 203 | params = vars(args) # convert to ordinary dict 204 | print('parsed input parameters:') 205 | print(json.dumps(params, indent = 2)) 206 | main(params) 207 | -------------------------------------------------------------------------------- /scripts/prepro_ngrams.py: -------------------------------------------------------------------------------- 1 | """ 2 | Precompute ngram counts of captions, to accelerate cider computation during training time. 3 | """ 4 | 5 | import os 6 | import json 7 | import argparse 8 | from six.moves import cPickle 9 | import captioning.utils.misc as utils 10 | from collections import defaultdict 11 | 12 | import sys 13 | sys.path.append("cider") 14 | from pyciderevalcap.ciderD.ciderD_scorer import CiderScorer 15 | 16 | 17 | def get_doc_freq(refs, params): 18 | tmp = CiderScorer(df_mode="corpus") 19 | for ref in refs: 20 | tmp.cook_append(None, ref) 21 | tmp.compute_doc_freq() 22 | return tmp.document_frequency, len(tmp.crefs) 23 | 24 | 25 | def build_dict(imgs, wtoi, params): 26 | wtoi[''] = 0 27 | 28 | count_imgs = 0 29 | 30 | refs_words = [] 31 | refs_idxs = [] 32 | for img in imgs: 33 | if (params['split'] == img['split']) or \ 34 | (params['split'] == 'train' and img['split'] == 'restval') or \ 35 | (params['split'] == 'all'): 36 | #(params['split'] == 'val' and img['split'] == 'restval') or \ 37 | ref_words = [] 38 | ref_idxs = [] 39 | for sent in img['sentences']: 40 | if hasattr(params, 'bpe'): 41 | sent['tokens'] = params.bpe.segment(' '.join(sent['tokens'])).strip().split(' ') 42 | tmp_tokens = sent['tokens'] + [''] 43 | tmp_tokens = [_ if _ in wtoi else 'UNK' for _ in tmp_tokens] 44 | ref_words.append(' '.join(tmp_tokens)) 45 | ref_idxs.append(' '.join([str(wtoi[_]) for _ in tmp_tokens])) 46 | refs_words.append(ref_words) 47 | refs_idxs.append(ref_idxs) 48 | count_imgs += 1 49 | print('total imgs:', count_imgs) 50 | 51 | ngram_words, count_refs = get_doc_freq(refs_words, params) 52 | ngram_idxs, count_refs = get_doc_freq(refs_idxs, params) 53 | print('count_refs:', count_refs) 54 | return ngram_words, ngram_idxs, count_refs 55 | 56 | def main(params): 57 | 58 | imgs = json.load(open(params['input_json'], 'r')) 59 | dict_json = 
json.load(open(params['dict_json'], 'r')) 60 | itow = dict_json['ix_to_word'] 61 | wtoi = {w:i for i,w in itow.items()} 62 | 63 | # Load bpe 64 | if 'bpe' in dict_json: 65 | import tempfile 66 | import codecs 67 | codes_f = tempfile.NamedTemporaryFile(delete=False) 68 | codes_f.close() 69 | with open(codes_f.name, 'w') as f: 70 | f.write(dict_json['bpe']) 71 | with codecs.open(codes_f.name, encoding='UTF-8') as codes: 72 | bpe = apply_bpe.BPE(codes) 73 | params.bpe = bpe 74 | 75 | imgs = imgs['images'] 76 | 77 | ngram_words, ngram_idxs, ref_len = build_dict(imgs, wtoi, params) 78 | 79 | utils.pickle_dump({'document_frequency': ngram_words, 'ref_len': ref_len}, open(params['output_pkl']+'-words.p','wb')) 80 | utils.pickle_dump({'document_frequency': ngram_idxs, 'ref_len': ref_len}, open(params['output_pkl']+'-idxs.p','wb')) 81 | 82 | if __name__ == "__main__": 83 | 84 | parser = argparse.ArgumentParser() 85 | 86 | # input json 87 | parser.add_argument('--input_json', default='data/dataset_coco.json', help='input json file to process into hdf5') 88 | parser.add_argument('--dict_json', default='data/cocotalk.json', help='output json file') 89 | parser.add_argument('--output_pkl', default='data/coco-all', help='output pickle file') 90 | parser.add_argument('--split', default='all', help='test, val, train, all') 91 | args = parser.parse_args() 92 | params = vars(args) # convert to ordinary dict 93 | 94 | main(params) 95 | -------------------------------------------------------------------------------- /scripts/prepro_reference_json.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Create a reference json file used for evaluation with `coco-caption` repo. 4 | Used when reference json is not provided, (e.g., flickr30k, or you have your own split of train/val/test) 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import json 13 | import argparse 14 | import sys 15 | import hashlib 16 | from random import shuffle, seed 17 | 18 | 19 | def main(params): 20 | 21 | imgs = json.load(open(params['input_json'][0], 'r'))['images'] 22 | # tmp = [] 23 | # for k in imgs.keys(): 24 | # for img in imgs[k]: 25 | # img['filename'] = img['image_id'] # k+'/'+img['image_id'] 26 | # img['image_id'] = int( 27 | # int(hashlib.sha256(img['image_id']).hexdigest(), 16) % sys.maxint) 28 | # tmp.append(img) 29 | # imgs = tmp 30 | 31 | # create output json file 32 | out = {'info': {'description': 'This is stable 1.0 version of the 2014 MS COCO dataset.', 'url': 'http://mscoco.org', 'version': '1.0', 'year': 2014, 'contributor': 'Microsoft COCO group', 'date_created': '2015-01-27 09:11:52.357475'}, 'licenses': [{'url': 'http://creativecommons.org/licenses/by-nc-sa/2.0/', 'id': 1, 'name': 'Attribution-NonCommercial-ShareAlike License'}, {'url': 'http://creativecommons.org/licenses/by-nc/2.0/', 'id': 2, 'name': 'Attribution-NonCommercial License'}, {'url': 'http://creativecommons.org/licenses/by-nc-nd/2.0/', 'id': 3, 'name': 'Attribution-NonCommercial-NoDerivs License'}, {'url': 'http://creativecommons.org/licenses/by/2.0/', 'id': 4, 'name': 'Attribution License'}, {'url': 'http://creativecommons.org/licenses/by-sa/2.0/', 'id': 5, 'name': 'Attribution-ShareAlike License'}, {'url': 'http://creativecommons.org/licenses/by-nd/2.0/', 'id': 6, 'name': 'Attribution-NoDerivs License'}, {'url': 'http://flickr.com/commons/usage/', 'id': 7, 'name': 'No known 
copyright restrictions'}, {'url': 'http://www.usa.gov/copyright.shtml', 'id': 8, 'name': 'United States Government Work'}], 'type': 'captions'} 33 | out.update({'images': [], 'annotations': []}) 34 | 35 | cnt = 0 36 | empty_cnt = 0 37 | for i, img in enumerate(imgs): 38 | if img['split'] == 'train': 39 | continue 40 | out['images'].append( 41 | {'id': img.get('cocoid', img['imgid'])}) 42 | for j, s in enumerate(img['sentences']): 43 | if len(s) == 0: 44 | continue 45 | s = ' '.join(s['tokens']) 46 | out['annotations'].append( 47 | {'image_id': out['images'][-1]['id'], 'caption': s, 'id': cnt}) 48 | cnt += 1 49 | 50 | json.dump(out, open(params['output_json'], 'w')) 51 | print('wrote ', params['output_json']) 52 | 53 | 54 | if __name__ == "__main__": 55 | 56 | parser = argparse.ArgumentParser() 57 | 58 | # input json 59 | parser.add_argument('--input_json', nargs='+', required=True, 60 | help='input json file to process into hdf5') 61 | parser.add_argument('--output_json', default='data.json', 62 | help='output json file') 63 | 64 | args = parser.parse_args() 65 | params = vars(args) # convert to ordinary dict 66 | print('parsed input parameters:') 67 | print(json.dumps(params, indent=2)) 68 | main(params) 69 | 70 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="captioning", 5 | version="0.0.1", 6 | author="Ruotian Luo", 7 | author_email="rluo@ttic.edu", 8 | packages=setuptools.find_packages(), 9 | ) -------------------------------------------------------------------------------- /test/test_pth_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from captioning.data.pth_loader import CaptionDataset 3 | from captioning.utils.misc import pickle_load 4 | 5 | def test_folder(): 6 | x = pickle_load(open('log_trans/infos_trans.pkl', 'rb')) 7 | dataset = CaptionDataset(x['opt']) 8 | ds = torch.utils.data.Subset(dataset, dataset.split_ix['train']) 9 | ds[0] 10 | 11 | def test_lmdb(): 12 | x = pickle_load(open('log_trans/infos_trans.pkl', 'rb')) 13 | x['opt'].input_att_dir = 'data/vilbert_att.lmdb' 14 | dataset = CaptionDataset(x['opt']) 15 | ds = torch.utils.data.Subset(dataset, dataset.split_ix['train']) 16 | ds[0] -------------------------------------------------------------------------------- /tools/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import numpy as np 7 | 8 | import time 9 | import os 10 | from six.moves import cPickle 11 | 12 | import captioning.utils.opts as opts 13 | import captioning.models as models 14 | from captioning.data.dataloader import * 15 | from captioning.data.dataloaderraw import * 16 | import captioning.utils.eval_utils as eval_utils 17 | import argparse 18 | import captioning.utils.misc as utils 19 | import captioning.modules.losses as losses 20 | import torch 21 | 22 | # Input arguments and options 23 | parser = argparse.ArgumentParser() 24 | # Input paths 25 | parser.add_argument('--model', type=str, default='', 26 | help='path to model to evaluate') 27 | parser.add_argument('--cnn_model', type=str, default='resnet101', 28 | help='resnet101, resnet152') 29 | parser.add_argument('--infos_path', type=str, default='', 30 | help='path to 
infos to evaluate') 31 | parser.add_argument('--only_lang_eval', type=int, default=0, 32 | help='lang eval on saved results') 33 | parser.add_argument('--force', type=int, default=0, 34 | help='force to evaluate no matter if there are results available') 35 | parser.add_argument('--device', type=str, default='cuda', 36 | help='cpu or cuda') 37 | opts.add_eval_options(parser) 38 | opts.add_diversity_opts(parser) 39 | opt = parser.parse_args() 40 | 41 | # Load infos 42 | with open(opt.infos_path, 'rb') as f: 43 | infos = utils.pickle_load(f) 44 | 45 | # override and collect parameters 46 | replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir', 'input_label_h5', 'input_json', 'batch_size', 'id'] 47 | ignore = ['start_from'] 48 | 49 | for k in vars(infos['opt']).keys(): 50 | if k in replace: 51 | setattr(opt, k, getattr(opt, k) or getattr(infos['opt'], k, '')) 52 | elif k not in ignore: 53 | if not k in vars(opt): 54 | vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model 55 | 56 | vocab = infos['vocab'] # ix -> word mapping 57 | 58 | pred_fn = os.path.join('eval_results/', '.saved_pred_'+ opt.id + '_' + opt.split + '.pth') 59 | result_fn = os.path.join('eval_results/', opt.id + '_' + opt.split + '.json') 60 | 61 | if opt.only_lang_eval == 1 or (not opt.force and os.path.isfile(pred_fn)): 62 | # if results existed, then skip, unless force is on 63 | if not opt.force: 64 | try: 65 | if os.path.isfile(result_fn): 66 | print(result_fn) 67 | json.load(open(result_fn, 'r')) 68 | print('already evaluated') 69 | os._exit(0) 70 | except: 71 | pass 72 | 73 | predictions, n_predictions = torch.load(pred_fn) 74 | lang_stats = eval_utils.language_eval(opt.input_json, predictions, n_predictions, vars(opt), opt.split) 75 | print(lang_stats) 76 | os._exit(0) 77 | 78 | # At this point only_lang_eval if 0 79 | if not opt.force: 80 | # Check out if 81 | try: 82 | # if no pred exists, then continue 83 | tmp = torch.load(pred_fn) 84 | # if language_eval == 1, and no pred exists, then continue 85 | if opt.language_eval == 1: 86 | json.load(open(result_fn, 'r')) 87 | print('Result is already there') 88 | os._exit(0) 89 | except: 90 | pass 91 | 92 | # Setup the model 93 | opt.vocab = vocab 94 | model = models.setup(opt) 95 | del opt.vocab 96 | model.load_state_dict(torch.load(opt.model, map_location='cpu')) 97 | model.to(opt.device) 98 | model.eval() 99 | crit = losses.LanguageModelCriterion() 100 | 101 | # Create the Data Loader instance 102 | if len(opt.image_folder) == 0: 103 | loader = DataLoader(opt) 104 | else: 105 | loader = DataLoaderRaw({'folder_path': opt.image_folder, 106 | 'coco_json': opt.coco_json, 107 | 'batch_size': opt.batch_size, 108 | 'cnn_model': opt.cnn_model}) 109 | # When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json 110 | # So make sure to use the vocab in infos file. 
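# (infos['vocab'] should be the same 1-indexed {ix: 'word'} table that
#  prepro_labels.py writes into the dataset json; overriding the loader's table
#  here keeps decoding consistent with the checkpoint being evaluated.)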
111 | loader.dataset.ix_to_word = infos['vocab'] 112 | 113 | 114 | # Set sample options 115 | opt.dataset = opt.input_json 116 | loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader, 117 | vars(opt)) 118 | 119 | print('loss: ', loss) 120 | if lang_stats: 121 | print(lang_stats) 122 | 123 | if opt.dump_json == 1: 124 | # dump the json 125 | json.dump(split_predictions, open('vis/vis.json', 'w')) 126 | -------------------------------------------------------------------------------- /tools/eval_ensemble.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import numpy as np 7 | 8 | import time 9 | import os 10 | from six.moves import cPickle 11 | 12 | import captioning.utils.opts as opts 13 | import captioning.models as models 14 | from captioning.data.dataloader import * 15 | from captioning.data.dataloaderraw import * 16 | import captioning.utils.eval_utils as eval_utils 17 | import argparse 18 | import captioning.utils.misc as utils 19 | import captioning.modules.losses as losses 20 | import torch 21 | 22 | # Input arguments and options 23 | parser = argparse.ArgumentParser() 24 | # Input paths 25 | parser.add_argument('--ids', nargs='+', required=True, help='id of the models to ensemble') 26 | parser.add_argument('--weights', nargs='+', required=False, default=None, help='id of the models to ensemble') 27 | # parser.add_argument('--models', nargs='+', required=True 28 | # help='path to model to evaluate') 29 | # parser.add_argument('--infos_paths', nargs='+', required=True, help='path to infos to evaluate') 30 | opts.add_eval_options(parser) 31 | opts.add_diversity_opts(parser) 32 | 33 | opt = parser.parse_args() 34 | 35 | model_infos = [] 36 | model_paths = [] 37 | for id in opt.ids: 38 | if '-' in id: 39 | id, app = id.split('-') 40 | app = '-'+app 41 | else: 42 | app = '' 43 | model_infos.append(utils.pickle_load(open('log_%s/infos_%s%s.pkl' %(id, id, app), 'rb'))) 44 | model_paths.append('log_%s/model%s.pth' %(id,app)) 45 | 46 | # Load one infos 47 | infos = model_infos[0] 48 | 49 | # override and collect parameters 50 | replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir', 'input_label_h5', 'input_json', 'batch_size', 'id'] 51 | for k in replace: 52 | setattr(opt, k, getattr(opt, k) or getattr(infos['opt'], k, '')) 53 | 54 | vars(opt).update({k: vars(infos['opt'])[k] for k in vars(infos['opt']).keys() if k not in vars(opt)}) # copy over options from model 55 | 56 | 57 | opt.use_box = max([getattr(infos['opt'], 'use_box', 0) for infos in model_infos]) 58 | assert max([getattr(infos['opt'], 'norm_att_feat', 0) for infos in model_infos]) == max([getattr(infos['opt'], 'norm_att_feat', 0) for infos in model_infos]), 'Not support different norm_att_feat' 59 | assert max([getattr(infos['opt'], 'norm_box_feat', 0) for infos in model_infos]) == max([getattr(infos['opt'], 'norm_box_feat', 0) for infos in model_infos]), 'Not support different norm_box_feat' 60 | 61 | vocab = infos['vocab'] # ix -> word mapping 62 | 63 | # Setup the model 64 | from models.AttEnsemble import AttEnsemble 65 | 66 | _models = [] 67 | for i in range(len(model_infos)): 68 | model_infos[i]['opt'].start_from = None 69 | model_infos[i]['opt'].vocab = vocab 70 | tmp = models.setup(model_infos[i]['opt']) 71 | tmp.load_state_dict(torch.load(model_paths[i])) 72 | _models.append(tmp) 73 | 74 | if opt.weights is not 
None: 75 | opt.weights = [float(_) for _ in opt.weights] 76 | model = AttEnsemble(_models, weights=opt.weights) 77 | model.seq_length = opt.max_length 78 | model.cuda() 79 | model.eval() 80 | crit = losses.LanguageModelCriterion() 81 | 82 | # Create the Data Loader instance 83 | if len(opt.image_folder) == 0: 84 | loader = DataLoader(opt) 85 | else: 86 | loader = DataLoaderRaw({'folder_path': opt.image_folder, 87 | 'coco_json': opt.coco_json, 88 | 'batch_size': opt.batch_size, 89 | 'cnn_model': opt.cnn_model}) 90 | # When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json 91 | # So make sure to use the vocab in infos file. 92 | loader.ix_to_word = infos['vocab'] 93 | 94 | opt.id = '+'.join([_+str(__) for _,__ in zip(opt.ids, opt.weights)]) 95 | # Set sample options 96 | loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader, 97 | vars(opt)) 98 | 99 | print('loss: ', loss) 100 | if lang_stats: 101 | print(lang_stats) 102 | 103 | if opt.dump_json == 1: 104 | # dump the json 105 | json.dump(split_predictions, open('vis/vis.json', 'w')) 106 | -------------------------------------------------------------------------------- /vis/imgs/dummy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruotianluo/ImageCaptioning.pytorch/4c48a3304932d58c5349434e7b0085f48dcb4be4/vis/imgs/dummy -------------------------------------------------------------------------------- /vis/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | neuraltalk2 results visualization 7 | 8 | 42 | 43 | 44 |
--------------------------------------------------------------------------------