├── .gitignore ├── README.md ├── data ├── __init__.py ├── coco_dataloader.py ├── coco_dataset.py └── transparent_data_loader.py ├── data_generator.py ├── demo.py ├── demo_material ├── __init__.py ├── cat_girl.jpg ├── demo_coco_tokens.pickle ├── micheal.jpg ├── napoleon.jpg └── tatin.jpg ├── demo_results.png ├── eval ├── __init__.py ├── bleu │ ├── LICENSE │ ├── __init__.py │ ├── bleu.py │ └── bleu_scorer.py ├── cider │ ├── __init__.py │ ├── cider.py │ ├── cider_scorer.py │ ├── reinforce_cider.py │ └── reinforce_cider_scorer.py ├── eval.py ├── get_stanford_models.sh ├── meteor │ ├── __init__.py │ ├── data │ │ └── paraphrase-en.gz │ ├── meteor-1.5.jar │ └── meteor.py ├── rouge │ ├── __init__.py │ └── rouge.py ├── spice │ ├── __init__.py │ ├── cache │ │ └── .gitkeep │ ├── lib │ │ ├── Meteor-1.5.jar │ │ ├── SceneGraphParser-1.0.jar │ │ ├── ejml-0.23.jar │ │ ├── fst-2.47.jar │ │ ├── guava-19.0.jar │ │ ├── hamcrest-core-1.3.jar │ │ ├── jackson-core-2.5.3.jar │ │ ├── javassist-3.19.0-GA.jar │ │ ├── json-simple-1.1.1.jar │ │ ├── junit-4.12.jar │ │ ├── lmdbjni-0.4.6.jar │ │ ├── lmdbjni-linux64-0.4.6.jar │ │ ├── lmdbjni-osx64-0.4.6.jar │ │ ├── lmdbjni-win64-0.4.6.jar │ │ ├── objenesis-2.4.jar │ │ ├── slf4j-api-1.7.12.jar │ │ └── slf4j-simple-1.7.21.jar │ ├── spice-1.0.jar │ ├── spice.py │ └── tmp │ │ └── .gitkeep └── tokenizer │ ├── __init__.py │ ├── ptbtokenizer.py │ └── stanford-corenlp-3.4.1.jar ├── github_ignore_material ├── raw_data │ └── .gitkeep └── saves │ └── .gitkeep ├── license.txt ├── losses ├── __init__.py ├── loss.py └── reward.py ├── models ├── End_ExpansionNet_v2.py ├── ExpansionNet_v2.py ├── __init__.py ├── captioning_model.py ├── ensemble_captioning_model.py ├── layers.py └── swin_transformer_mod.py ├── onnx4tensorrt ├── End_ExpansionNet_v2_onnx_tensorrt.py ├── __init__.py ├── convert2onnx.py ├── onnx2tensorrt.py └── swin_transformer_onnx_tensorrt.py ├── optims └── radam.py ├── requirements.txt ├── requirements_wTensorRT.txt ├── results_image.png ├── test.py ├── train.py └── utils ├── args_utils.py ├── image_utils.py ├── language_utils.py ├── masking.py └── saving_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | /github_ignore_material/raw_data/* 2 | !/github_ignore_material/raw_data/.gitkeep 3 | /github_ignore_material/saves/* 4 | !/github_ignore_material/saves/.gitkeep 5 | eval/spice/lib/stanford-corenlp-3.6.0.jar 6 | eval/spice/lib/stanford-corenlp-3.6.0-models.jar 7 | .idea/* 8 | **/__pycache__/* 9 | **/tmp/* 10 | !/eval/spice/tmp/.gitkeep 11 | **/cache/* 12 | !/eval/spice/cache/.gitkeep -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## ExpansionNet v2: Exploiting Multiple Sequence Lengths in Fast End to End Training for Image Captioning 2 | 3 | Implementation code for "[Exploiting Multiple Sequence Lengths in Fast End to End Training for Image Captioning](https://www.computer.org/csdl/proceedings-article/bigdata/2023/10386812/1TUPyooQsnu)" [ [BigData2023](https://www.computer.org/csdl/proceedings-article/bigdata/2023/10386812/1TUPyooQsnu) ] 4 | [ [Arxiv](https://arxiv.org/abs/2208.06551) ], previously entitled as "ExpansionNet v2: Block Static Expansion 5 | in fast end to end training for Image Captioning".
## Demo

You can test the model on generic images (not included in COCO) by downloading
the checkpoint [here](https://drive.google.com/drive/folders/1bBMH4-Fw1LcQZmSzkMCqpEl0piIP88Y3?usp=sharing)
and launching the script `demo.py`:
```
python demo.py \
    --load_path your_download_folder/rf_model.pth \
    --image_paths your_image_path/image_1 your_image_path/image_2 ...
```
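If you prefer to invoke the model from Python instead of the command line, the following condensed sketch mirrors what `demo.py` does internally (the checkpoint path `./rf_model.pth` and the chosen demo image are placeholders to adapt to your setup):
```
# Condensed programmatic equivalent of demo.py (see demo.py for the full script).
import pickle
import torch
from argparse import Namespace

from models.End_ExpansionNet_v2 import End_ExpansionNet_v2
from utils.image_utils import preprocess_image
from utils.language_utils import tokens2description

with open('./demo_material/demo_coco_tokens.pickle', 'rb') as f:
    coco_tokens = pickle.load(f)
sos_idx = coco_tokens['word2idx_dict'][coco_tokens['sos_str']]
eos_idx = coco_tokens['word2idx_dict'][coco_tokens['eos_str']]

drop_args = Namespace(enc=0.0, dec=0.0, enc_input=0.0, dec_input=0.0, other=0.0)
model = End_ExpansionNet_v2(
    swin_img_size=384, swin_patch_size=4, swin_in_chans=3,
    swin_embed_dim=192, swin_depths=[2, 2, 18, 2], swin_num_heads=[6, 12, 24, 48],
    swin_window_size=12, swin_mlp_ratio=4., swin_qkv_bias=True, swin_qk_scale=None,
    swin_drop_rate=0.0, swin_attn_drop_rate=0.0, swin_drop_path_rate=0.0,
    swin_norm_layer=torch.nn.LayerNorm, swin_ape=False, swin_patch_norm=True,
    swin_use_checkpoint=False, final_swin_dim=1536,
    d_model=512, N_enc=3, N_dec=3, num_heads=8, ff=2048,
    num_exp_enc_list=[32, 64, 128, 256, 512], num_exp_dec=16,
    output_word2idx=coco_tokens['word2idx_dict'],
    output_idx2word=coco_tokens['idx2word_list'],
    max_seq_len=74, drop_args=drop_args, rank='cpu')
checkpoint = torch.load('./rf_model.pth', map_location='cpu')  # placeholder path
model.load_state_dict(checkpoint['model_state_dict'])

image = preprocess_image('./demo_material/napoleon.jpg', 384)  # any image path works
beam_kwargs = {'beam_size': 5, 'beam_max_seq_len': 74, 'sample_or_max': 'max',
               'how_many_outputs': 1, 'sos_idx': sos_idx, 'eos_idx': eos_idx}
with torch.no_grad():
    pred, _ = model(enc_x=image, enc_x_num_pads=[0], mode='beam_search', **beam_kwargs)
print(tokens2description(pred[0][0], coco_tokens['idx2word_list'], sos_idx, eos_idx))
```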
Some examples:

*(Captioning examples are rendered in `demo_results.png`; the input images are available in `demo_material`.)*

## Results

[SacreEOS](https://github.com/jchenghu/sacreeos) Signature: `STANDARDwInit+Cider-D[n4,s6.0]+average[nspi5]+1.0.0`.<br>
Results are artifacts-free.

Online evaluation server results:

| Captions | B1 | B2 | B3 | B4 | Meteor | Rouge-L | CIDEr-D |
|----------|------|------|------|------|--------|---------|---------|
| c40      | 96.9 | 92.6 | 85.0 | 75.3 | 40.1   | 76.4    | 140.8   |
| c5       | 83.3 | 68.8 | 54.4 | 42.1 | 30.4   | 60.8    | 138.5   |

Results on the Karpathy test split:
| Model    | B@1  | B@4  | Meteor | Rouge-L | CIDEr-D | Spice |
|----------|------|------|--------|---------|---------|-------|
| Ensemble | 83.5 | 42.7 | 30.6   | 61.1    | 143.7   | 24.7  |
| Single   | 82.8 | 41.5 | 30.3   | 60.5    | 140.4   | 24.5  |

Prediction examples:
*(Prediction examples are rendered in `results_image.png`.)*
## ONNX & TensorRT

The model now supports ONNX conversion and deployment with TensorRT.
The graph can be generated using `onnx4tensorrt/convert2onnx.py`.
The conversion mainly requires the `onnx` package; the `onnx_runtime` and `onnx_tensorrt` packages are
optionally used for testing purposes (see the `convert2onnx.py` arguments).

Generic conversion commands:
```
python onnx4tensorrt/convert2onnx.py --onnx_simplify true --load_model_path <model path> &> output_onnx.txt &
python onnx4tensorrt/onnx2tensorrt.py &> output_tensorrt.txt &
```
Currently only FP32 is supported.


## Training

In this guide we cover all the training steps reported in the paper and
provide the commands to reproduce our work.

#### Requirements

* python >= 3.7
* numpy
* Java 1.8.0
* torch
* torchvision
* h5py

Installing whatever versions of `torch`, `torchvision`, `h5py` and `Pillow` fit your machine should
work in most cases.

One working requirements file can be found in `requirements.txt`; in case TensorRT is also needed,
use `requirements_wTensorRT.txt`. These files describe one working configuration, so the exact versions of each package
are not necessarily required.

#### Data preparation

MS-COCO 2014 images can be downloaded [here](https://cocodataset.org/#download),
the respective captions are available in our online [drive](https://drive.google.com/drive/folders/1bBMH4-Fw1LcQZmSzkMCqpEl0piIP88Y3?usp=sharing)
and the backbone can be found [here](https://github.com/microsoft/Swin-Transformer). We suggest moving all files, in particular
the `dataset_coco.json` file and the backbone, into `github_ignore_material/raw_data/`, since the commands provided
in the following steps assume these files are placed in that directory.


#### Premises

For the sake of transparency (at the cost of possibly being overly verbose)
the complete commands are shown below, but only a few arguments require care
when reproducing our work; most of them are handled automatically.

Logs are stored in `output_file.txt`, which is continuously updated
until the process is complete (on Linux, the command `watch -n 1 tail -n 30 output_file.txt` can be handy). The file is overwritten in each training phase, so,
before moving to the next one, make sure to save a copy if needed.

Lastly, in some configurations the batch size may look different from the one
reported in the paper when the argument `num_accum` is specified (default is 1). This is only
a visual subtlety: it means gradient accumulation is performed in order to satisfy
the 40GB memory constraint of a single GPU, while the effective batch size is unchanged.
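As a concrete example, the end-to-end phases below use `--batch_size 16 --num_accum 3` (or `--batch_size 24 --num_accum 2`), i.e. an effective batch size of 48, matching the partial-training phases. The sketch below only illustrates the general gradient-accumulation pattern; it is not the repository's `train.py` loop:
```
# Illustrative gradient-accumulation pattern (not the actual train.py code):
# gradients of `num_accum` mini-batches are summed before a single optimizer step,
# so the effective batch size is batch_size * num_accum.
import torch


def train_with_accumulation(model, loader, optimizer, loss_fn, num_accum=3):
    optimizer.zero_grad()
    for it, (inputs, targets) in enumerate(loader):
        loss = loss_fn(model(inputs), targets) / num_accum  # scale so the summed gradient matches a full batch
        loss.backward()                                      # gradients accumulate into .grad
        if (it + 1) % num_accum == 0:
            optimizer.step()
            optimizer.zero_grad()
```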
#### 1. Cross-Entropy Training: Features generation

First we generate the features for the first training step:
```
cd ExpansionNet_v2_src
python data_generator.py \
    --save_model_path ./github_ignore_material/raw_data/swin_large_patch4_window12_384_22k.pth \
    --output_path ./github_ignore_material/raw_data/features.hdf5 \
    --images_path ./github_ignore_material/raw_data/MS_COCO_2014/ \
    --captions_path ./github_ignore_material/raw_data/ &> output_file.txt &
```
Although it is not recommended, the `output_path` argument can be replaced with any desired destination (this would also require
changing the `features_path` argument in the next commands). Since it is a pretty big
file (102GB), once the first training is completed it will be automatically overwritten by
the remaining operations if the default name is unchanged.

TIP: if 100GB is too much for your disk, add the option `--dtype fp16`,
which saves the arrays in FP16 and therefore requires only about 50GB. It should not affect the results much.
By default, we keep FP32 for conformity with the experimental setup of the paper.


#### 2. Cross-Entropy Training: Partial Training

In this step the model is trained using the Cross-Entropy loss and the features generated
in the previous step:
```
python train.py --N_enc 3 --N_dec 3 \
    --model_dim 512 --seed 775533 --optim_type radam --sched_type custom_warmup_anneal \
    --warmup 10000 --lr 2e-4 --anneal_coeff 0.8 --anneal_every_epoch 2 --enc_drop 0.3 \
    --dec_drop 0.3 --enc_input_drop 0.3 --dec_input_drop 0.3 --drop_other 0.3 \
    --batch_size 48 --num_accum 1 --num_gpus 1 --ddp_sync_port 11317 --eval_beam_sizes [3] \
    --save_path ./github_ignore_material/saves/ --save_every_minutes 60 --how_many_checkpoints 1 \
    --is_end_to_end False --features_path ./github_ignore_material/raw_data/features.hdf5 --partial_load False \
    --print_every_iter 11807 --eval_every_iter 999999 \
    --reinforce False --num_epochs 8 &> output_file.txt &
```
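The partial training above reads the precomputed features from `features.hdf5`, which stores one dataset per image keyed as `<image id>_features` (see `data_generator.py`). If you want to sanity-check the generated file before launching a long run, a minimal sketch (the key shown in the comment is only an example):
```
# Quick sanity check of the generated features file (the path is the default one used above).
import h5py

with h5py.File('./github_ignore_material/raw_data/features.hdf5', 'r') as f:
    some_key = next(iter(f.keys()))  # e.g. '391895_features'
    print("num datasets:", len(f.keys()))
    print(some_key, f[some_key].shape, f[some_key].dtype)
```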
#### 3. Cross-Entropy Training: End to End Training

The following command trains the entire network in end-to-end mode. However,
one argument needs to be changed according to the result of the previous step: the
checkpoint file name. Weights are stored in the directory `github_ignore_material/saves/`
with the prefix `checkpoint_ ... _xe.pth`; we will refer to this file as `phase2_checkpoint` below and in
the later steps:
```
python train.py --N_enc 3 --N_dec 3 \
    --model_dim 512 --optim_type radam --seed 775533 --sched_type custom_warmup_anneal \
    --warmup 1 --lr 3e-5 --anneal_coeff 0.55 --anneal_every_epoch 1 --enc_drop 0.3 \
    --dec_drop 0.3 --enc_input_drop 0.3 --dec_input_drop 0.3 --drop_other 0.3 \
    --batch_size 16 --num_accum 3 --num_gpus 1 --ddp_sync_port 11317 --eval_beam_sizes [3] \
    --save_path ./github_ignore_material/saves/ --save_every_minutes 60 --how_many_checkpoints 1 \
    --is_end_to_end True --images_path ./github_ignore_material/raw_data/MS_COCO_2014/ --partial_load True \
    --backbone_save_path ./github_ignore_material/raw_data/swin_large_patch4_window12_384_22k.pth \
    --body_save_path ./github_ignore_material/saves/phase2_checkpoint \
    --print_every_iter 15000 --eval_every_iter 999999 \
    --reinforce False --num_epochs 2 &> output_file.txt &
```
In case you are interested in the network's weights at the end of this stage,
before moving on to the self-critical learning, rename the checkpoint file from `checkpoint_ ... _xe.pth` to something
else like `phase3_checkpoint` (make sure to change the prefix), otherwise it will
be overwritten during step 5.

#### 4. CIDEr optimization: Features generation

This step generates the features for the reinforcement step:
```
python data_generator.py \
    --save_model_path ./github_ignore_material/saves/phase3_checkpoint \
    --output_path ./github_ignore_material/raw_data/features.hdf5 \
    --images_path ./github_ignore_material/raw_data/MS_COCO_2014/ \
    --captions_path ./github_ignore_material/raw_data/ &> output_file.txt &
```

#### 5. CIDEr optimization: Partial Training

The following command performs the partial training using self-critical learning:
```
python train.py --N_enc 3 --N_dec 3 \
    --model_dim 512 --optim_type radam --seed 775533 --sched_type custom_warmup_anneal \
    --warmup 1 --lr 1e-4 --anneal_coeff 0.8 --anneal_every_epoch 1 --enc_drop 0.1 \
    --dec_drop 0.1 --enc_input_drop 0.1 --dec_input_drop 0.1 --drop_other 0.1 \
    --batch_size 24 --num_accum 2 --num_gpus 1 --ddp_sync_port 11317 --eval_beam_sizes [5] \
    --save_path ./github_ignore_material/saves/ --save_every_minutes 60 --how_many_checkpoints 1 \
    --is_end_to_end False --partial_load True \
    --features_path ./github_ignore_material/raw_data/features.hdf5 \
    --body_save_path ./github_ignore_material/saves/phase3_checkpoint.pth \
    --print_every_iter 4000 --eval_every_iter 99999 \
    --reinforce True --num_epochs 9 &> output_file.txt &
```
We refer to the last checkpoint produced in this step as `phase5_checkpoint`;
it should already achieve around 139.5 CIDEr-D on both the Validation and Test sets, but
it can still be improved by a small margin with the following optional step.
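Before the optional final step, a brief note on what `--reinforce True` does: captions sampled from the model are scored with CIDEr-D (see `eval/cider/reinforce_cider.py` and `losses/reward.py`) and their log-probabilities are weighted by the gap with respect to a baseline caption (e.g. the greedy output or the mean of the sampled ones). The sketch below only illustrates the shape of this self-critical objective; it is not the repository's implementation:
```
# Illustrative SCST-style objective (simplified; not the repo's actual loss code).
# Sampled captions with a reward above the baseline get their log-probability increased,
# the others decreased.
import torch


def self_critical_loss(sample_logprobs, sample_rewards, baseline_rewards):
    """
    sample_logprobs:  (batch, seq_len) log-probabilities of the sampled tokens
    sample_rewards:   (batch,) CIDEr-D of each sampled caption
    baseline_rewards: (batch,) CIDEr-D of the baseline caption
    """
    advantage = (sample_rewards - baseline_rewards).unsqueeze(-1)  # (batch, 1)
    return -(advantage * sample_logprobs).mean()
```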
#### 6. CIDEr optimization: End to End Training

This last step trains the model once more in an end-to-end fashion. It is optional, since it only slightly improves performance:
```
python train.py --N_enc 3 --N_dec 3 \
    --model_dim 512 --optim_type radam --seed 775533 --sched_type custom_warmup_anneal \
    --warmup 1 --anneal_coeff 1.0 --lr 2e-6 --enc_drop 0.1 \
    --dec_drop 0.1 --enc_input_drop 0.1 --dec_input_drop 0.1 --drop_other 0.1 \
    --batch_size 24 --num_accum 2 --num_gpus 1 --ddp_sync_port 11317 --eval_beam_sizes [5] \
    --save_path ./github_ignore_material/saves/ --save_every_minutes 60 --how_many_checkpoints 1 \
    --is_end_to_end True --images_path ./github_ignore_material/raw_data/MS_COCO_2014/ --partial_load True \
    --backbone_save_path ./github_ignore_material/raw_data/phase3_checkpoint \
    --body_save_path ./github_ignore_material/saves/phase5_checkpoint \
    --print_every_iter 15000 --eval_every_iter 999999 \
    --reinforce True --num_epochs 1 &> output_file.txt &
```

## Evaluation

In this section we provide the evaluation script. We refer to the
last checkpoint as `phase6_checkpoint`. In case the previous training
procedures have been skipped,
the weights of one of the ensemble's models can be found [here](https://drive.google.com/drive/folders/1bBMH4-Fw1LcQZmSzkMCqpEl0piIP88Y3?usp=sharing).
```
python test.py --N_enc 3 --N_dec 3 --model_dim 512 \
    --num_gpus 1 --eval_beam_sizes [5] --is_end_to_end True \
    --eval_parallel_batch_size 4 \
    --images_path ./github_ignore_material/raw_data/ \
    --save_model_path ./github_ignore_material/saves/phase6_checkpoint
```
The option `is_end_to_end` should be set according to the model's type.<br>
You might need to grant execution permissions to the file `./eval/get_stanford_models.sh` (e.g. `chmod a+x -R ./eval/` on Linux).


## Citation

If you find this repository useful, please consider citing (no obligation):

```
@inproceedings{hu2023exploiting,
  title={Exploiting Multiple Sequence Lengths in Fast End to End Training for Image Captioning},
  author={Hu, Jia Cheng and Cavicchioli, Roberto and Capotondi, Alessandro},
  booktitle={2023 IEEE International Conference on Big Data (BigData)},
  pages={2173--2182},
  year={2023},
  organization={IEEE Computer Society}
}
```

## Acknowledgements

We thank the PyTorch team and the following repositories:
* https://github.com/microsoft/Swin-Transformer
* https://github.com/ruotianluo/ImageCaptioning.pytorch
* https://github.com/tylin/coco-caption

Special thanks to the work of [Yiyu Wang et al.](https://arxiv.org/abs/2203.15350).

We thank the user [@shahizat](https://github.com/shahizat) for suggesting the ONNX and TensorRT conversions.<br>
316 | We also thank the github users from the Issues section which provided valuable feedbacks, suggestions, 317 | and even found very insidious bugs. 318 | 319 | 320 | 321 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/data/__init__.py -------------------------------------------------------------------------------- /data/coco_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | from time import time 3 | from utils import language_utils 4 | 5 | import functools 6 | print = functools.partial(print, flush=True) 7 | 8 | 9 | class CocoDatasetKarpathy: 10 | 11 | TrainSet_ID = 1 12 | ValidationSet_ID = 2 13 | TestSet_ID = 3 14 | 15 | def __init__(self, 16 | images_path, 17 | coco_annotations_path, 18 | precalc_features_hdf5_filepath, 19 | preproc_images_hdf5_filepath=None, 20 | limited_num_train_images=None, 21 | limited_num_val_images=None, 22 | limited_num_test_images=None, 23 | dict_min_occurrences=5, 24 | verbose=True 25 | ): 26 | super(CocoDatasetKarpathy, self).__init__() 27 | 28 | self.use_images_instead_of_features = False 29 | if precalc_features_hdf5_filepath is None or precalc_features_hdf5_filepath == 'None' or \ 30 | precalc_features_hdf5_filepath == 'none' or precalc_features_hdf5_filepath == '': 31 | self.use_images_instead_of_features = True 32 | print("Warning: since no hdf5 path is provided using images instead of pre-calculated features.") 33 | print("Features path: " + str(precalc_features_hdf5_filepath)) 34 | 35 | self.preproc_images_hdf5_filepath = None 36 | if preproc_images_hdf5_filepath is not None: 37 | print("Preprocessed hdf5 file path not None: " + str(preproc_images_hdf5_filepath)) 38 | print("Using preprocessed hdf5 file instead.") 39 | self.preproc_images_hdf5_filepath = preproc_images_hdf5_filepath 40 | 41 | else: 42 | self.precalc_features_hdf5_filepath = precalc_features_hdf5_filepath 43 | print("Features path: " + str(self.precalc_features_hdf5_filepath)) 44 | print("Features path provided, images are provided in form of features.") 45 | 46 | if images_path is None: 47 | self.images_path = "" 48 | else: 49 | self.images_path = images_path 50 | 51 | self.karpathy_train_dict = dict() 52 | self.karpathy_val_dict = dict() 53 | self.karpathy_test_dict = dict() 54 | 55 | with open(coco_annotations_path, 'r') as f: 56 | json_file = json.load(f)['images'] 57 | 58 | if verbose: 59 | print("Initializing dataset... 
", end=" ") 60 | for json_item in json_file: 61 | new_item = dict() 62 | 63 | new_item['img_path'] = self.images_path + json_item['filepath'] + '/img/' + json_item['filename'] 64 | 65 | new_item_captions = [item['raw'] for item in json_item['sentences']] 66 | new_item['img_id'] = json_item['cocoid'] 67 | new_item['captions'] = new_item_captions 68 | 69 | if json_item['split'] == 'train' or json_item['split'] == 'restval': 70 | self.karpathy_train_dict[json_item['cocoid']] = new_item 71 | elif json_item['split'] == 'test': 72 | self.karpathy_test_dict[json_item['cocoid']] = new_item 73 | elif json_item['split'] == 'val': 74 | self.karpathy_val_dict[json_item['cocoid']] = new_item 75 | 76 | self.karpathy_train_list = [] 77 | self.karpathy_val_list = [] 78 | self.karpathy_test_list = [] 79 | for key in self.karpathy_train_dict.keys(): 80 | self.karpathy_train_list.append(self.karpathy_train_dict[key]) 81 | for key in self.karpathy_val_dict.keys(): 82 | self.karpathy_val_list.append(self.karpathy_val_dict[key]) 83 | for key in self.karpathy_test_dict.keys(): 84 | self.karpathy_test_list.append(self.karpathy_test_dict[key]) 85 | 86 | self.train_num_images = len(self.karpathy_train_list) 87 | self.val_num_images = len(self.karpathy_val_list) 88 | self.test_num_images = len(self.karpathy_test_list) 89 | 90 | if limited_num_train_images is not None: 91 | self.karpathy_train_list = self.karpathy_train_list[:limited_num_train_images] 92 | self.train_num_images = limited_num_train_images 93 | if limited_num_val_images is not None: 94 | self.karpathy_val_list = self.karpathy_val_list[:limited_num_val_images] 95 | self.val_num_images = limited_num_val_images 96 | if limited_num_test_images is not None: 97 | self.karpathy_test_list = self.karpathy_test_list[:limited_num_test_images] 98 | self.test_num_images = limited_num_test_images 99 | 100 | if verbose: 101 | print("Num train images: " + str(self.train_num_images)) 102 | print("Num val images: " + str(self.val_num_images)) 103 | print("Num test images: " + str(self.test_num_images)) 104 | 105 | tokenized_captions_list = [] 106 | for i in range(self.train_num_images): 107 | for caption in self.karpathy_train_list[i]['captions']: 108 | tmp = language_utils.lowercase_and_clean_trailing_spaces([caption]) 109 | tmp = language_utils.add_space_between_non_alphanumeric_symbols(tmp) 110 | tmp = language_utils.remove_punctuations(tmp) 111 | tokenized_caption = ['SOS'] + language_utils.tokenize(tmp)[0] + ['EOS'] 112 | tokenized_captions_list.append(tokenized_caption) 113 | 114 | counter_dict = dict() 115 | for i in range(len(tokenized_captions_list)): 116 | for word in tokenized_captions_list[i]: 117 | if word not in counter_dict: 118 | counter_dict[word] = 1 119 | else: 120 | counter_dict[word] += 1 121 | 122 | less_than_min_occurrences_set = set() 123 | for k, v in counter_dict.items(): 124 | if v < dict_min_occurrences: 125 | less_than_min_occurrences_set.add(k) 126 | if verbose: 127 | print("tot tokens " + str(len(counter_dict)) + 128 | " less than " + str(dict_min_occurrences) + ": " + str(len(less_than_min_occurrences_set)) + 129 | " remaining: " + str(len(counter_dict) - len(less_than_min_occurrences_set))) 130 | 131 | self.num_caption_vocab = 4 132 | self.max_seq_len = 0 133 | discovered_words = ['PAD', 'SOS', 'EOS', 'UNK'] 134 | for i in range(len(tokenized_captions_list)): 135 | caption = tokenized_captions_list[i] 136 | if len(caption) > self.max_seq_len: 137 | self.max_seq_len = len(caption) 138 | for word in caption: 139 | if (word not in 
discovered_words) and (not word in less_than_min_occurrences_set): 140 | discovered_words.append(word) 141 | self.num_caption_vocab += 1 142 | 143 | discovered_words.sort() 144 | self.caption_word2idx_dict = dict() 145 | self.caption_idx2word_list = [] 146 | for i in range(len(discovered_words)): 147 | self.caption_word2idx_dict[discovered_words[i]] = i 148 | self.caption_idx2word_list.append(discovered_words[i]) 149 | if verbose: 150 | print("There are " + str(self.num_caption_vocab) + " vocabs in dict") 151 | 152 | def get_image_path(self, img_idx, dataset_split): 153 | 154 | if dataset_split == CocoDatasetKarpathy.TestSet_ID: 155 | img_path = self.karpathy_test_list[img_idx]['img_path'] 156 | img_id = self.karpathy_test_list[img_idx]['img_id'] 157 | elif dataset_split == CocoDatasetKarpathy.ValidationSet_ID: 158 | img_path = self.karpathy_val_list[img_idx]['img_path'] 159 | img_id = self.karpathy_val_list[img_idx]['img_id'] 160 | else: 161 | img_path = self.karpathy_train_list[img_idx]['img_path'] 162 | img_id = self.karpathy_train_list[img_idx]['img_id'] 163 | 164 | return img_path, img_id 165 | 166 | def get_all_images_captions(self, dataset_split): 167 | all_image_references = [] 168 | 169 | if dataset_split == CocoDatasetKarpathy.TestSet_ID: 170 | dataset = self.karpathy_test_list 171 | elif dataset_split == CocoDatasetKarpathy.ValidationSet_ID: 172 | dataset = self.karpathy_val_list 173 | else: 174 | dataset = self.karpathy_train_list 175 | 176 | for img_idx in range(len(dataset)): 177 | all_image_references.append(dataset[img_idx]['captions']) 178 | return all_image_references 179 | 180 | def get_eos_token_idx(self): 181 | return self.caption_word2idx_dict['EOS'] 182 | 183 | def get_sos_token_idx(self): 184 | return self.caption_word2idx_dict['SOS'] 185 | 186 | def get_pad_token_idx(self): 187 | return self.caption_word2idx_dict['PAD'] 188 | 189 | def get_unk_token_idx(self): 190 | return self.caption_word2idx_dict['UNK'] 191 | 192 | def get_eos_token_str(self): 193 | return 'EOS' 194 | 195 | def get_sos_token_str(self): 196 | return 'SOS' 197 | 198 | def get_pad_token_str(self): 199 | return 'PAD' 200 | 201 | def get_unk_token_str(self): 202 | return 'UNK' 203 | -------------------------------------------------------------------------------- /data/transparent_data_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | This dataloader inherently represents also a training session that can be saved and loaded 3 | """ 4 | 5 | 6 | class TransparentDataLoader: 7 | def __init__(self): 8 | super(TransparentDataLoader, self).__init__() 9 | 10 | # initialize the training on a specific epoch 11 | def init_epoch(self, epoch, batch_size): 12 | raise NotImplementedError 13 | 14 | def get_next_batch(self): 15 | raise NotImplementedError 16 | 17 | # methods for progress saving and loading progress of the data loader 18 | def set_epoch_it(self, epoch): 19 | raise NotImplementedError 20 | 21 | def get_epoch_it(self): 22 | raise NotImplementedError 23 | 24 | def get_num_epoch(self): 25 | raise NotImplementedError 26 | 27 | def get_num_batches(self): 28 | raise NotImplementedError 29 | 30 | def set_batch_it(self, batch_it): 31 | raise NotImplementedError 32 | 33 | def get_batch_it(self): 34 | raise NotImplementedError 35 | 36 | def get_batch_size(self): 37 | raise NotImplementedError 38 | 39 | def save_state(self): 40 | raise NotImplementedError 41 | 42 | def load_state(self, state_dict): 43 | raise NotImplementedError 44 | 
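For illustration only, a toy loader satisfying the interface above might look like the sketch below (hypothetical code, not part of the repository; the actual loader implementing this contract is presumably `data/coco_dataloader.py`):
```
# Hypothetical minimal implementation of the TransparentDataLoader contract,
# shown only to illustrate how the resumable-training interface is meant to be used.
from data.transparent_data_loader import TransparentDataLoader


class ToyTransparentLoader(TransparentDataLoader):
    def __init__(self, samples, num_epochs):
        super().__init__()
        self.samples = samples
        self.num_epochs = num_epochs
        self.epoch_it, self.batch_it, self.batch_size = 0, 0, 1

    def init_epoch(self, epoch, batch_size):
        self.epoch_it, self.batch_it, self.batch_size = epoch, 0, batch_size

    def get_next_batch(self):
        start = self.batch_it * self.batch_size
        self.batch_it += 1
        return self.samples[start:start + self.batch_size]

    # progress getters / setters used for checkpointing the training session
    def set_epoch_it(self, epoch): self.epoch_it = epoch
    def get_epoch_it(self): return self.epoch_it
    def get_num_epoch(self): return self.num_epochs
    def get_num_batches(self): return (len(self.samples) + self.batch_size - 1) // self.batch_size
    def set_batch_it(self, batch_it): self.batch_it = batch_it
    def get_batch_it(self): return self.batch_it
    def get_batch_size(self): return self.batch_size

    def save_state(self):
        return {'epoch_it': self.epoch_it, 'batch_it': self.batch_it, 'batch_size': self.batch_size}

    def load_state(self, state_dict):
        self.epoch_it = state_dict['epoch_it']
        self.batch_it = state_dict['batch_it']
        self.batch_size = state_dict['batch_size']
```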
-------------------------------------------------------------------------------- /data_generator.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | from PIL import Image as PIL_Image 4 | import torchvision 5 | import torch 6 | import argparse 7 | from argparse import Namespace 8 | from torch.nn.parameter import Parameter 9 | from time import time 10 | 11 | from data.coco_dataset import CocoDatasetKarpathy 12 | 13 | torch.autograd.set_detect_anomaly(False) 14 | torch.set_num_threads(1) 15 | torch.set_num_interop_threads(1) 16 | import functools 17 | print = functools.partial(print, flush=True) 18 | DEFAULT_RANK = 0 19 | 20 | 21 | def convert_time_as_hhmmss(ticks): 22 | return str(int(ticks / 60)) + " m " + \ 23 | str(int(ticks) % 60) + " s" 24 | 25 | 26 | def generate_data(path_args): 27 | 28 | coco_dataset = CocoDatasetKarpathy(images_path=path_args.images_path, 29 | coco_annotations_path=args.captions_path + "dataset_coco.json", 30 | preproc_images_hdf5_filepath=None, 31 | precalc_features_hdf5_filepath=None, 32 | limited_num_train_images=None, 33 | limited_num_val_images=5000) 34 | 35 | from models.swin_transformer_mod import SwinTransformer 36 | model = SwinTransformer( 37 | img_size=384, patch_size=4, in_chans=3, 38 | embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], 39 | window_size=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, 40 | drop_rate=0.0, attn_drop_rate=0.0, 41 | drop_path_rate=0.0, 42 | norm_layer=torch.nn.LayerNorm, ape=False, patch_norm=True, 43 | use_checkpoint=False) 44 | 45 | def load_backbone_only_from_save(model, state_dict, prefix=None): 46 | own_state = model.state_dict() 47 | for name, param in state_dict.items(): 48 | if prefix is not None and name.startswith(prefix): 49 | name = name[len(prefix):] 50 | if name not in own_state: 51 | print("Not found: " + str(name)) 52 | continue 53 | if isinstance(param, Parameter): 54 | param = param.data 55 | own_state[name].copy_(param) 56 | print("Found: " + str(name)) 57 | 58 | save_model_path = path_args.save_model_path 59 | map_location = {'cuda:%d' % DEFAULT_RANK: 'cuda:%d' % DEFAULT_RANK} 60 | checkpoint = torch.load(save_model_path, map_location=map_location) 61 | if 'model_state_dict' in checkpoint.keys(): 62 | print("Custom save point found") 63 | load_backbone_only_from_save(model, checkpoint['model_state_dict'], prefix='swin_transf.') 64 | else: 65 | print("Custom save point not found") 66 | load_backbone_only_from_save(model, checkpoint['model'], prefix=None) 67 | print("Loading phase ended") 68 | 69 | model = model.to(DEFAULT_RANK) 70 | 71 | test_preprocess_layers_1 = [torchvision.transforms.Resize((384, 384))] 72 | test_preprocess_layers_2 = [torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])] 73 | test_preprocess_1 = torchvision.transforms.Compose(test_preprocess_layers_1) 74 | test_preprocess_2 = torchvision.transforms.Compose(test_preprocess_layers_2) 75 | 76 | model.eval() 77 | with torch.no_grad(): 78 | """ 79 | TIP: if you don't have 100GB of memory is too much for features allocation, 80 | try saving arrays into FP16. It shouldn't change affect much the result. 81 | Replace each line into: 82 | hdf5_file.create_dataset(str(img_id) + '_features', data=np.array(output.cpu())) 83 | into: 84 | hdf5_file.create_dataset(str(img_id) + '_features', data=np.array(output.cpu()), dtype=np.float16) 85 | we kept it FP32 for coherence with the experimental setup of the paper. 
86 | """ 87 | hdf5_file = h5py.File(path_args.output_path, 'w') 88 | 89 | def apply_model(model, file_path): 90 | pil_image = PIL_Image.open(file_path) 91 | if pil_image.mode != 'RGB': 92 | pil_image = PIL_Image.new("RGB", pil_image.size) 93 | preprocess_pil_image = test_preprocess_1(pil_image) 94 | tens_image = torchvision.transforms.ToTensor()(preprocess_pil_image) 95 | tens_image = test_preprocess_2(tens_image).to(DEFAULT_RANK) 96 | output = model(tens_image.unsqueeze(0)) 97 | return output.squeeze(0) 98 | 99 | for i in range(coco_dataset.train_num_images): 100 | img_path, img_id = coco_dataset.get_image_path(coco_dataset.train_num_images - i - 1, 101 | CocoDatasetKarpathy.TrainSet_ID) 102 | output = apply_model(model, img_path) 103 | if path_args.dtype == 'fp16': 104 | hdf5_file.create_dataset(str(img_id) + '_features', data=np.array(output.cpu()), dtype=np.float16) 105 | else: 106 | hdf5_file.create_dataset(str(img_id) + '_features', data=np.array(output.cpu())) 107 | if (i+1) % 5000 == 0 or (i+1) == coco_dataset.train_num_images: 108 | print("Train " + str(i+1) + " / " + str(coco_dataset.train_num_images) + " completed") 109 | 110 | for i in range(coco_dataset.test_num_images): 111 | img_path, img_id = coco_dataset.get_image_path(i, CocoDatasetKarpathy.TestSet_ID) 112 | output = apply_model(model, img_path) 113 | if path_args.dtype == 'fp16': 114 | hdf5_file.create_dataset(str(img_id) + '_features', data=np.array(output.cpu()), dtype=np.float16) 115 | else: 116 | hdf5_file.create_dataset(str(img_id) + '_features', data=np.array(output.cpu())) 117 | if (i+1) % 2500 == 0 or (i+1) == coco_dataset.test_num_images: 118 | print("Test " + str(i+1) + " / " + str(coco_dataset.test_num_images) + " completed") 119 | 120 | for i in range(coco_dataset.val_num_images): 121 | img_path, img_id = coco_dataset.get_image_path(i, CocoDatasetKarpathy.ValidationSet_ID) 122 | output = apply_model(model, img_path) 123 | if path_args.dtype == 'fp16': 124 | hdf5_file.create_dataset(str(img_id) + '_features', data=np.array(output.cpu()), dtype=np.float16) 125 | else: 126 | hdf5_file.create_dataset(str(img_id) + '_features', data=np.array(output.cpu())) 127 | if (i+1) % 2500 == 0 or (i+1) == coco_dataset.test_num_images: 128 | print("Val " + str(i+1) + " / " + str(coco_dataset.test_num_images) + " completed") 129 | 130 | print("[GPU: " + str(DEFAULT_RANK) + " ] Closing...") 131 | 132 | 133 | if __name__ == "__main__": 134 | 135 | parser = argparse.ArgumentParser(description='Image Captioning') 136 | parser.add_argument('--save_model_path', type=str, default='./github_ignore_material/saves/') 137 | parser.add_argument('--output_path', type=str, default='./github_ignore_material/raw_data/precalc_features.hdf5') 138 | parser.add_argument('--images_path', type=str, default='/tmp/images/') 139 | parser.add_argument('--captions_path', type=str, default='./github_ignore_material/raw_data/') 140 | parser.add_argument('--dtype', type=str, default='fp32', help='Decide data type of saved features') 141 | 142 | args = parser.parse_args() 143 | 144 | path_args = Namespace(save_model_path=args.save_model_path, 145 | output_path=args.output_path, 146 | images_path=args.images_path, 147 | captions_path=args.captions_path, 148 | dtype=args.dtype) 149 | generate_data(path_args=path_args) 150 | 151 | 152 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Small script for testing on few generic 
images given the model weights. 3 | In order to minimize the requirements, it runs only on CPU and images are 4 | processed one by one. 5 | """ 6 | 7 | import torch 8 | import argparse 9 | import pickle 10 | from argparse import Namespace 11 | 12 | from models.End_ExpansionNet_v2 import End_ExpansionNet_v2 13 | from utils.image_utils import preprocess_image 14 | from utils.language_utils import tokens2description 15 | 16 | 17 | if __name__ == "__main__": 18 | 19 | parser = argparse.ArgumentParser(description='Demo') 20 | parser.add_argument('--model_dim', type=int, default=512) 21 | parser.add_argument('--N_enc', type=int, default=3) 22 | parser.add_argument('--N_dec', type=int, default=3) 23 | parser.add_argument('--max_seq_len', type=int, default=74) 24 | parser.add_argument('--device', type=str, default='cpu') 25 | parser.add_argument('--load_path', type=str, default='./rf_model.pth') 26 | parser.add_argument('--image_paths', type=str, 27 | default=['./demo_material/tatin.jpg', 28 | './demo_material/micheal.jpg', 29 | './demo_material/napoleon.jpg', 30 | './demo_material/cat_girl.jpg'], 31 | nargs='+') 32 | parser.add_argument('--beam_size', type=int, default=5) 33 | 34 | args = parser.parse_args() 35 | 36 | drop_args = Namespace(enc=0.0, 37 | dec=0.0, 38 | enc_input=0.0, 39 | dec_input=0.0, 40 | other=0.0) 41 | model_args = Namespace(model_dim=args.model_dim, 42 | N_enc=args.N_enc, 43 | N_dec=args.N_dec, 44 | drop_args=drop_args) 45 | 46 | with open('./demo_material/demo_coco_tokens.pickle', 'rb') as f: 47 | coco_tokens = pickle.load(f) 48 | sos_idx = coco_tokens['word2idx_dict'][coco_tokens['sos_str']] 49 | eos_idx = coco_tokens['word2idx_dict'][coco_tokens['eos_str']] 50 | 51 | print("Dictionary loaded ...") 52 | 53 | img_size = 384 54 | model = End_ExpansionNet_v2(swin_img_size=img_size, swin_patch_size=4, swin_in_chans=3, 55 | swin_embed_dim=192, swin_depths=[2, 2, 18, 2], swin_num_heads=[6, 12, 24, 48], 56 | swin_window_size=12, swin_mlp_ratio=4., swin_qkv_bias=True, swin_qk_scale=None, 57 | swin_drop_rate=0.0, swin_attn_drop_rate=0.0, swin_drop_path_rate=0.0, 58 | swin_norm_layer=torch.nn.LayerNorm, swin_ape=False, swin_patch_norm=True, 59 | swin_use_checkpoint=False, 60 | final_swin_dim=1536, 61 | 62 | d_model=model_args.model_dim, N_enc=model_args.N_enc, 63 | N_dec=model_args.N_dec, num_heads=8, ff=2048, 64 | num_exp_enc_list=[32, 64, 128, 256, 512], 65 | num_exp_dec=16, 66 | output_word2idx=coco_tokens['word2idx_dict'], 67 | output_idx2word=coco_tokens['idx2word_list'], 68 | max_seq_len=args.max_seq_len, drop_args=model_args.drop_args, 69 | rank=args.device) 70 | checkpoint = torch.load(args.load_path, map_location=torch.device(args.device)) 71 | model.load_state_dict(checkpoint['model_state_dict']) 72 | print("Model loaded ...") 73 | 74 | input_images = [] 75 | for path in args.image_paths: 76 | input_images.append(preprocess_image(path, img_size)) 77 | 78 | print("Generating captions ...\n") 79 | for i in range(len(input_images)): 80 | path = args.image_paths[i] 81 | image = input_images[i] 82 | beam_search_kwargs = {'beam_size': args.beam_size, 83 | 'beam_max_seq_len': args.max_seq_len, 84 | 'sample_or_max': 'max', 85 | 'how_many_outputs': 1, 86 | 'sos_idx': sos_idx, 87 | 'eos_idx': eos_idx} 88 | with torch.no_grad(): 89 | pred, _ = model(enc_x=image, 90 | enc_x_num_pads=[0], 91 | mode='beam_search', **beam_search_kwargs) 92 | pred = tokens2description(pred[0][0], coco_tokens['idx2word_list'], sos_idx, eos_idx) 93 | print(path + ' \n\tDescription: ' + pred + '\n') 94 | 95 | 
print("Closed.") 96 | -------------------------------------------------------------------------------- /demo_material/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/demo_material/__init__.py -------------------------------------------------------------------------------- /demo_material/cat_girl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/demo_material/cat_girl.jpg -------------------------------------------------------------------------------- /demo_material/demo_coco_tokens.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/demo_material/demo_coco_tokens.pickle -------------------------------------------------------------------------------- /demo_material/micheal.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/demo_material/micheal.jpg -------------------------------------------------------------------------------- /demo_material/napoleon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/demo_material/napoleon.jpg -------------------------------------------------------------------------------- /demo_material/tatin.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/demo_material/tatin.jpg -------------------------------------------------------------------------------- /demo_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/demo_results.png -------------------------------------------------------------------------------- /eval/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /eval/bleu/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /eval/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /eval/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from eval.bleu.bleu_scorer import BleuScorer 12 | 13 | class Bleu: 14 | def __init__(self, n=4): 15 | # default compute Blue score up to 4 16 | self._n = n 17 | self._hypo_for_image = {} 18 | self.ref_for_image = {} 19 | 20 | def compute_score(self, gts, res): 21 | 22 | assert(gts.keys() == res.keys()) 23 | imgIds = gts.keys() 24 | 25 | bleu_scorer = BleuScorer(n=self._n) 26 | for id in imgIds: 27 | hypo = res[id] 28 | ref = gts[id] 29 | 30 | # Sanity check. 31 | assert(type(hypo) is list) 32 | assert(len(hypo) == 1) 33 | assert(type(ref) is list) 34 | assert(len(ref) >= 1) 35 | 36 | bleu_scorer += (hypo[0], ref) 37 | 38 | #score, scores = bleu_scorer.compute_score(option='shortest') 39 | score, scores = bleu_scorer.compute_score(option='closest', verbose=1) 40 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1) 41 | 42 | # return (bleu, bleu_info) 43 | return score, scores 44 | 45 | def method(self): 46 | return "Bleu" 47 | -------------------------------------------------------------------------------- /eval/bleu/bleu_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | '''Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 17 | ''' 18 | 19 | import copy 20 | import sys, math, re 21 | from collections import defaultdict 22 | 23 | def precook(s, n=4, out=False): 24 | """Takes a string as input and returns an object that can be given to 25 | either cook_refs or cook_test. 
This is optional: cook_refs and cook_test 26 | can take string arguments as well.""" 27 | words = s.split() 28 | counts = defaultdict(int) 29 | for k in range(1,n+1): 30 | for i in range(len(words)-k+1): 31 | ngram = tuple(words[i:i+k]) 32 | counts[ngram] += 1 33 | return (len(words), counts) 34 | 35 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 36 | '''Takes a list of reference sentences for a single segment 37 | and returns an object that encapsulates everything that BLEU 38 | needs to know about them.''' 39 | 40 | reflen = [] 41 | maxcounts = {} 42 | for ref in refs: 43 | rl, counts = precook(ref, n) 44 | reflen.append(rl) 45 | for (ngram,count) in counts.items(): 46 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 47 | 48 | # Calculate effective reference sentence length. 49 | if eff == "shortest": 50 | reflen = min(reflen) 51 | elif eff == "average": 52 | reflen = float(sum(reflen))/len(reflen) 53 | 54 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 55 | 56 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 57 | 58 | return (reflen, maxcounts) 59 | 60 | def cook_test(test, reflen, refmaxcounts, eff=None, n=4): 61 | '''Takes a test sentence and returns an object that 62 | encapsulates everything that BLEU needs to know about it.''' 63 | 64 | testlen, counts = precook(test, n, True) 65 | 66 | result = {} 67 | 68 | # Calculate effective reference sentence length. 69 | 70 | if eff == "closest": 71 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1] 72 | else: ## i.e., "average" or "shortest" or None 73 | result["reflen"] = reflen 74 | 75 | result["testlen"] = testlen 76 | 77 | result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)] 78 | 79 | result['correct'] = [0]*n 80 | for (ngram, count) in counts.items(): 81 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count) 82 | 83 | return result 84 | 85 | class BleuScorer(object): 86 | """Bleu scorer. 87 | """ 88 | 89 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 90 | # special_reflen is used in oracle (proportional effective ref len for a node). 
91 | 92 | def copy(self): 93 | ''' copy the refs.''' 94 | new = BleuScorer(n=self.n) 95 | new.ctest = copy.copy(self.ctest) 96 | new.crefs = copy.copy(self.crefs) 97 | new._score = None 98 | return new 99 | 100 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 101 | ''' singular instance ''' 102 | 103 | self.n = n 104 | self.crefs = [] 105 | self.ctest = [] 106 | self.cook_append(test, refs) 107 | self.special_reflen = special_reflen 108 | 109 | def cook_append(self, test, refs): 110 | '''called by constructor and __iadd__ to avoid creating new instances.''' 111 | 112 | if refs is not None: 113 | self.crefs.append(cook_refs(refs)) 114 | if test is not None: 115 | reflen, refmaxcounts = self.crefs[-1] # python2.7 to python3.5 ADAPTATION 116 | cooked_test = cook_test(test, reflen, refmaxcounts) 117 | self.ctest.append(cooked_test) ## N.B.: -1 118 | else: 119 | self.ctest.append(None) # lens of crefs and ctest have to match 120 | 121 | self._score = None ## need to recompute 122 | 123 | def ratio(self, option=None): 124 | self.compute_score(option=option) 125 | return self._ratio 126 | 127 | def score_ratio(self, option=None): 128 | '''return (bleu, len_ratio) pair''' 129 | return (self.fscore(option=option), self.ratio(option=option)) 130 | 131 | def score_ratio_str(self, option=None): 132 | return "%.4f (%.2f)" % self.score_ratio(option) 133 | 134 | def reflen(self, option=None): 135 | self.compute_score(option=option) 136 | return self._reflen 137 | 138 | def testlen(self, option=None): 139 | self.compute_score(option=option) 140 | return self._testlen 141 | 142 | def retest(self, new_test): 143 | if type(new_test) is str: 144 | new_test = [new_test] 145 | assert len(new_test) == len(self.crefs), new_test 146 | self.ctest = [] 147 | for t, rs in zip(new_test, self.crefs): 148 | self.ctest.append(cook_test(t, rs)) 149 | self._score = None 150 | 151 | return self 152 | 153 | def rescore(self, new_test): 154 | ''' replace test(s) with new test(s), and returns the new score.''' 155 | 156 | return self.retest(new_test).compute_score() 157 | 158 | def size(self): 159 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 160 | return len(self.crefs) 161 | 162 | def __iadd__(self, other): 163 | '''add an instance (e.g., from another sentence).''' 164 | 165 | if type(other) is tuple: 166 | ## avoid creating new BleuScorer instances 167 | self.cook_append(other[0], other[1]) 168 | else: 169 | assert self.compatible(other), "incompatible BLEUs." 
170 | self.ctest.extend(other.ctest) 171 | self.crefs.extend(other.crefs) 172 | self._score = None ## need to recompute 173 | 174 | return self 175 | 176 | def compatible(self, other): 177 | return isinstance(other, BleuScorer) and self.n == other.n 178 | 179 | def single_reflen(self, option="average"): 180 | return self._single_reflen(self.crefs[0][0], option) 181 | 182 | def _single_reflen(self, reflens, option=None, testlen=None): 183 | 184 | if option == "shortest": 185 | reflen = min(reflens) 186 | elif option == "average": 187 | reflen = float(sum(reflens))/len(reflens) 188 | elif option == "closest": 189 | reflen = min((abs(l-testlen), l) for l in reflens)[1] 190 | else: 191 | assert False, "unsupported reflen option %s" % option 192 | 193 | return reflen 194 | 195 | def recompute_score(self, option=None, verbose=0): 196 | self._score = None 197 | return self.compute_score(option, verbose) 198 | 199 | def compute_score(self, option=None, verbose=0): 200 | n = self.n 201 | small = 1e-9 202 | tiny = 1e-15 ## so that if guess is 0 still return 0 203 | bleu_list = [[] for _ in range(n)] 204 | 205 | if self._score is not None: 206 | return self._score 207 | 208 | if option is None: 209 | option = "average" if len(self.crefs) == 1 else "closest" 210 | 211 | self._testlen = 0 212 | self._reflen = 0 213 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} 214 | 215 | # for each sentence 216 | for comps in self.ctest: 217 | testlen = comps['testlen'] 218 | self._testlen += testlen 219 | 220 | if self.special_reflen is None: ## need computation 221 | reflen = self._single_reflen(comps['reflen'], option, testlen) 222 | else: 223 | reflen = self.special_reflen 224 | 225 | self._reflen += reflen 226 | 227 | for key in ['guess','correct']: 228 | for k in range(n): 229 | totalcomps[key][k] += comps[key][k] 230 | 231 | # append per image bleu score 232 | bleu = 1. 233 | for k in range(n): 234 | bleu *= (float(comps['correct'][k]) + tiny) \ 235 | /(float(comps['guess'][k]) + small) 236 | bleu_list[k].append(bleu ** (1./(k+1))) 237 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 238 | if ratio < 1: 239 | for k in range(n): 240 | bleu_list[k][-1] *= math.exp(1 - 1/ratio) 241 | 242 | if verbose > 1: 243 | print(comps, reflen) 244 | 245 | totalcomps['reflen'] = self._reflen 246 | totalcomps['testlen'] = self._testlen 247 | 248 | bleus = [] 249 | bleu = 1. 
250 | for k in range(n): 251 | bleu *= float(totalcomps['correct'][k] + tiny) \ 252 | / (totalcomps['guess'][k] + small) 253 | bleus.append(bleu ** (1./(k+1))) 254 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 255 | if ratio < 1: 256 | for k in range(n): 257 | bleus[k] *= math.exp(1 - 1/ratio) 258 | 259 | if verbose > 0: 260 | pass 261 | #print(totalcomps) 262 | #print("ratio:", ratio) 263 | 264 | self._score = bleus 265 | return self._score, bleu_list 266 | -------------------------------------------------------------------------------- /eval/cider/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /eval/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | 10 | from eval.cider.cider_scorer import CiderScorer 11 | 12 | 13 | class Cider: 14 | """ 15 | Main Class to compute the CIDEr metric 16 | 17 | """ 18 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 19 | # set cider to sum over 1 to 4-grams 20 | self._n = n 21 | # set the standard deviation parameter for gaussian penalty 22 | self._sigma = sigma 23 | 24 | def compute_score(self, gts, res): 25 | """ 26 | Main function to compute CIDEr score 27 | :param hypo_for_image (dict) : dictionary with key and value 28 | ref_for_image (dict) : dictionary with key and value 29 | :return: cider (float) : computed CIDEr score for the corpus 30 | """ 31 | 32 | assert(gts.keys() == res.keys()) 33 | imgIds = gts.keys() 34 | 35 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 36 | 37 | for id in imgIds: 38 | hypo = res[id] 39 | ref = gts[id] 40 | 41 | # Sanity check. 42 | assert(type(hypo) is list) 43 | assert(len(hypo) == 1) 44 | assert(type(ref) is list) 45 | assert(len(ref) > 0) 46 | 47 | cider_scorer += (hypo[0], ref) 48 | 49 | (score, scores) = cider_scorer.compute_score() 50 | 51 | return score, scores 52 | 53 | def method(self): 54 | return "CIDEr" -------------------------------------------------------------------------------- /eval/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | import copy 6 | from collections import defaultdict 7 | import numpy as np 8 | import pdb 9 | import math 10 | 11 | def precook(s, n=4, out=False): 12 | """ 13 | Takes a string as input and returns an object that can be given to 14 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 15 | can take string arguments as well. 
16 | :param s: string : sentence to be converted into ngrams 17 | :param n: int : number of ngrams for which representation is calculated 18 | :return: term frequency vector for occuring ngrams 19 | """ 20 | words = s.split() 21 | counts = defaultdict(int) 22 | for k in range(1,n+1): 23 | for i in range(len(words)-k+1): 24 | ngram = tuple(words[i:i+k]) 25 | counts[ngram] += 1 26 | return counts 27 | 28 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 29 | '''Takes a list of reference sentences for a single segment 30 | and returns an object that encapsulates everything that BLEU 31 | needs to know about them. 32 | :param refs: list of string : reference sentences for some image 33 | :param n: int : number of ngrams for which (ngram) representation is calculated 34 | :return: result (list of dict) 35 | ''' 36 | return [precook(ref, n) for ref in refs] 37 | 38 | def cook_test(test, n=4): 39 | '''Takes a test sentence and returns an object that 40 | encapsulates everything that BLEU needs to know about it. 41 | :param test: list of string : hypothesis sentence for some image 42 | :param n: int : number of ngrams for which (ngram) representation is calculated 43 | :return: result (dict) 44 | ''' 45 | return precook(test, n, True) 46 | 47 | class CiderScorer(object): 48 | """CIDEr scorer. 49 | """ 50 | 51 | def copy(self): 52 | ''' copy the refs.''' 53 | new = CiderScorer(n=self.n) 54 | new.ctest = copy.copy(self.ctest) 55 | new.crefs = copy.copy(self.crefs) 56 | return new 57 | 58 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 59 | ''' singular instance ''' 60 | self.n = n 61 | self.sigma = sigma 62 | self.crefs = [] 63 | self.ctest = [] 64 | self.document_frequency = defaultdict(float) 65 | self.cook_append(test, refs) 66 | self.ref_len = None 67 | 68 | def cook_append(self, test, refs): 69 | '''called by constructor and __iadd__ to avoid creating new instances.''' 70 | 71 | if refs is not None: 72 | self.crefs.append(cook_refs(refs)) 73 | if test is not None: 74 | self.ctest.append(cook_test(test)) ## N.B.: -1 75 | else: 76 | self.ctest.append(None) # lens of crefs and ctest have to match 77 | 78 | def size(self): 79 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 80 | return len(self.crefs) 81 | 82 | def __iadd__(self, other): 83 | '''add an instance (e.g., from another sentence).''' 84 | 85 | if type(other) is tuple: 86 | ## avoid creating new CiderScorer instances 87 | self.cook_append(other[0], other[1]) 88 | else: 89 | self.ctest.extend(other.ctest) 90 | self.crefs.extend(other.crefs) 91 | 92 | return self 93 | def compute_doc_freq(self): 94 | ''' 95 | Compute term frequency for reference data. 96 | This will be used to compute idf (inverse document frequency later) 97 | The term frequency is stored in the object 98 | :return: None 99 | ''' 100 | for refs in self.crefs: 101 | # refs, k ref captions of one image 102 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 103 | self.document_frequency[ngram] += 1 104 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 105 | 106 | def compute_cider(self): 107 | def counts2vec(cnts): 108 | """ 109 | Function maps counts of ngram to vector of tfidf weights. 110 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 111 | The n-th entry of array denotes length of n-grams. 
112 | :param cnts: 113 | :return: vec (array of dict), norm (array of float), length (int) 114 | """ 115 | vec = [defaultdict(float) for _ in range(self.n)] 116 | length = 0 117 | norm = [0.0 for _ in range(self.n)] 118 | for (ngram,term_freq) in cnts.items(): 119 | # give word count 1 if it doesn't appear in reference corpus 120 | df = np.log(max(1.0, self.document_frequency[ngram])) 121 | # ngram index 122 | n = len(ngram)-1 123 | # tf (term_freq) * idf (precomputed idf) for n-grams 124 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 125 | # compute norm for the vector. the norm will be used for computing similarity 126 | norm[n] += pow(vec[n][ngram], 2) 127 | 128 | if n == 1: 129 | length += term_freq 130 | norm = [np.sqrt(n) for n in norm] 131 | return vec, norm, length 132 | 133 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 134 | ''' 135 | Compute the cosine similarity of two vectors. 136 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 137 | :param vec_ref: array of dictionary for vector corresponding to reference 138 | :param norm_hyp: array of float for vector corresponding to hypothesis 139 | :param norm_ref: array of float for vector corresponding to reference 140 | :param length_hyp: int containing length of hypothesis 141 | :param length_ref: int containing length of reference 142 | :return: array of score for each n-grams cosine similarity 143 | ''' 144 | delta = float(length_hyp - length_ref) 145 | # measure consine similarity 146 | val = np.array([0.0 for _ in range(self.n)]) 147 | for n in range(self.n): 148 | # ngram 149 | for (ngram,count) in vec_hyp[n].items(): 150 | # vrama91 : added clipping 151 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 152 | 153 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 154 | val[n] /= (norm_hyp[n]*norm_ref[n]) 155 | 156 | assert(not math.isnan(val[n])) 157 | # vrama91: added a length based gaussian penalty 158 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 159 | return val 160 | 161 | # compute log reference length 162 | self.ref_len = np.log(float(len(self.crefs))) 163 | 164 | scores = [] 165 | for test, refs in zip(self.ctest, self.crefs): 166 | # compute vector for test captions 167 | vec, norm, length = counts2vec(test) 168 | # compute vector for ref captions 169 | score = np.array([0.0 for _ in range(self.n)]) 170 | for ref in refs: 171 | vec_ref, norm_ref, length_ref = counts2vec(ref) 172 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 173 | # change by vrama91 - mean of ngram scores, instead of sum 174 | score_avg = np.mean(score) 175 | # divide by number of references 176 | score_avg /= len(refs) 177 | # multiply score by 10 178 | score_avg *= 10.0 179 | # append score of an image to the score list 180 | scores.append(score_avg) 181 | return scores 182 | 183 | def compute_score(self, option=None, verbose=0): 184 | # compute idf 185 | self.compute_doc_freq() 186 | # assert to check document frequency 187 | assert(len(self.ctest) >= max(self.document_frequency.values())) 188 | # compute cider score 189 | score = self.compute_cider() 190 | # debug 191 | # print score 192 | return np.mean(np.array(score)), np.array(score) -------------------------------------------------------------------------------- /eval/cider/reinforce_cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image 
Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | 10 | # ReinforceCIDEr is an alternative implementation of CIDEr where the corpus is initialized in the constructor 11 | # so it doesn't need to be processed again every time we need to compute the cider score 12 | # in the Self Critical Learning Process. --- Jia Cheng 13 | 14 | from eval.cider.reinforce_cider_scorer import ReinforceCiderScorer 15 | import pdb 16 | 17 | 18 | class ReinforceCider: 19 | 20 | # The batch_ref_sentences will be a small sample of the original corpus, note however that there's no need of 21 | # correspondence of img_ids between img_ids in the corpus and the ones in the batch_ref_sentences, the img_ids 22 | # consistency is required between batch_ref_sentences and batch_test_sentences only. 23 | def __init__(self, corpus, n=4, sigma=6.0): 24 | ''' 25 | Corpus represents the collection of reference sentences for each image, this must be a dictionary with image 26 | ids as keys and a list of sentences as value. 27 | 28 | :param corpus: a dictionary with 29 | :param n: number of n-grams 30 | :param sigma: length penalty coefficient 31 | ''' 32 | # set cider to sum over 1 to 4-grams 33 | self._n = n 34 | # set the standard deviation parameter for gaussian penalty 35 | self._sigma = sigma 36 | self.cider_scorer = ReinforceCiderScorer(corpus, n=self._n, sigma=self._sigma) 37 | 38 | def compute_score(self, hypo, refs): 39 | """ 40 | Main function to compute CIDEr score 41 | :param hypo_for_image (dict) : dictionary with key and value 42 | ref_for_image (dict) : dictionary with key and value 43 | :return: cider (float) : computed CIDEr score for the corpus 44 | """ 45 | 46 | # assert(hypo.keys() == refs.keys()) 47 | 48 | (score, scores) = self.cider_scorer.compute_score(refs, hypo) 49 | 50 | return score, scores 51 | 52 | def method(self): 53 | return "Reinforce CIDEr" -------------------------------------------------------------------------------- /eval/cider/reinforce_cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | # ReinforceCIDErScorer is an alternative implementation of CIDEr Scorer according to my personal 6 | # input output format "tastes" and needs --- Jia Cheng 7 | 8 | import copy 9 | from collections import defaultdict 10 | import numpy as np 11 | import pdb 12 | import math 13 | 14 | 15 | def precook(s, n=4, out=False): 16 | """ 17 | Takes a string as input and returns an object that can be given to 18 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 19 | can take string arguments as well. 20 | :param s: string : sentence to be converted into ngrams 21 | :param n: int : number of ngrams for which representation is calculated 22 | :return: term frequency vector for occuring ngrams 23 | """ 24 | words = s.split() 25 | counts = defaultdict(int) 26 | for k in range(1, n + 1): 27 | for i in range(len(words) - k + 1): 28 | ngram = tuple(words[i:i + k]) 29 | counts[ngram] += 1 30 | return counts 31 | 32 | 33 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 34 | '''Takes a list of reference sentences for a single segment 35 | and returns an object that encapsulates everything that BLEU 36 | needs to know about them. 
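A hypothetical usage sketch of `ReinforceCider` (not from the repository, and assuming the repository root is on `PYTHONPATH`). It mirrors how `losses/reward.py` calls it, i.e. the corpus is passed as a list of per-image reference lists, and `compute_score` receives a list of hypothesis strings aligned with a list of reference lists:

```
# Hypothetical sketch: scoring a batch of sampled captions against their
# references; the corpus format follows the call site in losses/reward.py.
from eval.cider.reinforce_cider import ReinforceCider

corpus = [
    ["a dog runs on the grass EOS", "a brown dog is running EOS"],
    ["a man rides a bike EOS", "a person riding a bicycle EOS"],
]
scorer = ReinforceCider(corpus)      # document frequencies are built once here

hypo = ["a dog is running EOS", "a man on a bicycle EOS"]
refs = corpus                        # refs[i] are the references for hypo[i]
mean_score, per_image_scores = scorer.compute_score(hypo, refs)
print(mean_score, per_image_scores.shape)   # scalar mean and a (2,) array
```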
37 | :param refs: list of string : reference sentences for some image 38 | :param n: int : number of ngrams for which (ngram) representation is calculated 39 | :return: result (list of dict) 40 | ''' 41 | return [precook(ref, n) for ref in refs] 42 | 43 | 44 | def cook_test(test, n=4): 45 | '''Takes a test sentence and returns an object that 46 | encapsulates everything that BLEU needs to know about it. 47 | :param test: list of string : hypothesis sentence for some image 48 | :param n: int : number of ngrams for which (ngram) representation is calculated 49 | :return: result (dict) 50 | ''' 51 | return precook(test, n, True) 52 | 53 | 54 | class ReinforceCiderScorer(object): 55 | """CIDEr scorer. 56 | """ 57 | 58 | def __init__(self, corpus_crefs, n=4, sigma=6.0): 59 | ''' singular instance ''' 60 | self.n = n 61 | self.sigma = sigma 62 | 63 | df_crefs = [] 64 | for refs in corpus_crefs: 65 | df_crefs.append(cook_refs(refs)) 66 | 67 | self.document_frequency = self.compute_doc_freq(df_crefs) 68 | self.corpus_ref_len = np.log(float(len(df_crefs))) 69 | 70 | 71 | def compute_doc_freq(self, df_crefs): 72 | ''' 73 | Compute term frequency for reference data. 74 | This will be used to compute idf (inverse document frequency later) 75 | The term frequency is stored in the object 76 | :return: None 77 | ''' 78 | document_frequency = defaultdict(float) 79 | for refs in df_crefs: 80 | # refs, k ref captions of one image 81 | for ngram in set([ngram for ref in refs for (ngram, count) in ref.items()]): 82 | document_frequency[ngram] += 1 83 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 84 | return document_frequency 85 | 86 | def compute_cider(self, refs, tests): 87 | 88 | ctest = [] 89 | crefs = [] 90 | for idx in range(len(tests)): 91 | test = tests[idx] 92 | ref = refs[idx] 93 | ctest.append(cook_test(test)) 94 | crefs.append(cook_refs(ref)) 95 | 96 | def counts2vec(cnts): 97 | """ 98 | Function maps counts of ngram to vector of tfidf weights. 99 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 100 | The n-th entry of array denotes length of n-grams. 101 | :param cnts: 102 | :param ref_len: 103 | :return: vec (array of dict), norm (array of float), length (int) 104 | """ 105 | vec = [defaultdict(float) for _ in range(self.n)] 106 | length = 0 107 | norm = [0.0 for _ in range(self.n)] 108 | for (ngram, term_freq) in cnts.items(): 109 | # give word count 1 if it doesn't appear in reference corpus 110 | df = np.log(max(1.0, self.document_frequency[ngram])) 111 | # ngram index 112 | n = len(ngram) - 1 113 | # tf (term_freq) * idf (precomputed idf) for n-grams 114 | vec[n][ngram] = float(term_freq) * (self.corpus_ref_len - df) 115 | # compute norm for the vector. the norm will be used for computing similarity 116 | norm[n] += pow(vec[n][ngram], 2) 117 | 118 | if n == 1: 119 | length += term_freq 120 | norm = [np.sqrt(n) for n in norm] 121 | return vec, norm, length 122 | 123 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 124 | ''' 125 | Compute the cosine similarity of two vectors. 
126 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 127 | :param vec_ref: array of dictionary for vector corresponding to reference 128 | :param norm_hyp: array of float for vector corresponding to hypothesis 129 | :param norm_ref: array of float for vector corresponding to reference 130 | :param length_hyp: int containing length of hypothesis 131 | :param length_ref: int containing length of reference 132 | :return: array of score for each n-grams cosine similarity 133 | ''' 134 | delta = float(length_hyp - length_ref) 135 | # measure consine similarity 136 | val = np.array([0.0 for _ in range(self.n)]) 137 | for n in range(self.n): 138 | # ngram 139 | for (ngram, count) in vec_hyp[n].items(): 140 | # vrama91 : added clipping 141 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 142 | 143 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 144 | val[n] /= (norm_hyp[n] * norm_ref[n]) 145 | 146 | assert (not math.isnan(val[n])) 147 | # vrama91: added a length based gaussian penalty 148 | val[n] *= np.e ** (-(delta ** 2) / (2 * self.sigma ** 2)) 149 | return val 150 | 151 | # compute log reference length 152 | 153 | scores = [] 154 | for test, refs in zip(ctest, crefs): 155 | # compute vector for test captions 156 | vec, norm, length = counts2vec(test) 157 | # compute vector for ref captions 158 | score = np.array([0.0 for _ in range(self.n)]) 159 | for ref in refs: 160 | vec_ref, norm_ref, length_ref = counts2vec(ref) 161 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 162 | # change by vrama91 - mean of ngram scores, instead of sum 163 | score_avg = np.mean(score) 164 | # divide by number of references 165 | score_avg /= len(refs) 166 | # multiply score by 10 167 | score_avg *= 10.0 168 | # append score of an image to the score list 169 | scores.append(score_avg) 170 | return scores 171 | 172 | """ 173 | refs must be contained in the dataset for document frequency 174 | """ 175 | def compute_score(self, refs, tests, option=None, verbose=0): 176 | # compute idf 177 | # self.compute_doc_freq() 178 | # assert to check document frequency 179 | # assert(len(self.ctest) >= max(self.document_frequency.values())) 180 | # compute cider score 181 | score = self.compute_cider(refs, tests) 182 | # debug 183 | # print score 184 | return np.mean(np.array(score)), np.array(score) -------------------------------------------------------------------------------- /eval/eval.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | from eval.tokenizer.ptbtokenizer import PTBTokenizer 3 | from eval.bleu.bleu import Bleu 4 | from eval.meteor.meteor import Meteor 5 | from eval.rouge.rouge import Rouge 6 | from eval.cider.cider import Cider 7 | from eval.spice.spice import Spice 8 | 9 | 10 | """ 11 | I do not own the rights of this code, I just modified it according to my needs. 
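To summarize the arithmetic in the two CIDEr scorers above: each n-gram weight is tf · idf with idf = log(#reference images) − log(max(1, df)); the clipped cosine similarities are damped by a Gaussian penalty on the length gap, averaged over n and over references, and multiplied by 10. A small numeric illustration of the weight and the penalty (all numbers hypothetical, only the formulas come from the code, with the default sigma = 6.0):

```
# Illustration only: tf-idf weight as in counts2vec and the Gaussian length
# penalty as in sim(), using made-up corpus statistics.
import numpy as np

num_ref_images = 1000            # hypothetical corpus size
df = 40                          # hypothetical document frequency of an n-gram
tf = 2                           # its term frequency in one caption
weight = tf * (np.log(num_ref_images) - np.log(max(1.0, df)))
print(round(weight, 3))          # 6.438 -> rarer n-grams get larger weights

sigma = 6.0
for delta in (0, 3, 6, 12):      # |len(hyp) - len(ref)| in words
    print(delta, round(np.e ** (-(delta ** 2) / (2 * sigma ** 2)), 3))
# 0 1.0 / 3 0.882 / 6 0.607 / 12 0.135 -> large length gaps are damped
```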
12 | The original version can be found in: 13 | https://github.com/cocodataset/cocoapi 14 | """ 15 | class COCOEvalCap: 16 | def __init__(self, dataset_gts_anns, pred_anns, pred_img_ids, get_stanford_models_path=None): 17 | self.evalImgs = [] 18 | self.eval = {} 19 | self.imgToEval = {} 20 | self.dataset_gts_anns = dataset_gts_anns 21 | self.pred_anns = pred_anns 22 | self.pred_img_ids = pred_img_ids 23 | 24 | import subprocess 25 | # print("invoking " + str(get_stanford_models_path)) 26 | rc = subprocess.call(get_stanford_models_path) 27 | 28 | def evaluate(self, bleu=True, rouge=True, cider=True, spice=True, meteor=True, verbose=True): 29 | # imgIds = self.coco.getImgIds() 30 | gts = {} 31 | res = {} 32 | for imgId in self.pred_img_ids: 33 | gts[imgId] = self.dataset_gts_anns[imgId] 34 | res[imgId] = self.pred_anns[imgId] 35 | 36 | # ================================================= 37 | # Set up scorers 38 | # ================================================= 39 | #if verbose: 40 | # print('tokenization...') 41 | tokenizer = PTBTokenizer() 42 | gts = tokenizer.tokenize(gts) 43 | res = tokenizer.tokenize(res) 44 | 45 | # ================================================= 46 | # Set up scorers 47 | # ================================================= 48 | #if verbose: 49 | # print('setting up scorers...') 50 | scorers = [] 51 | if cider: 52 | scorers.append((Cider(), "CIDEr")) 53 | if bleu: 54 | scorers.append((Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"])) 55 | if rouge: 56 | scorers.append((Rouge(), "ROUGE_L")) 57 | if spice: 58 | scorers.append((Spice(), "SPICE")) 59 | if meteor: 60 | scorers.append((Meteor(), "METEOR")) 61 | """ 62 | scorers = [ 63 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 64 | (Rouge(), "ROUGE_L"), 65 | (Cider(), "CIDEr"), 66 | (Spice(), "SPICE"), 67 | (Meteor(), "METEOR"), 68 | ] 69 | """ 70 | 71 | # ================================================= 72 | # Compute scores 73 | # ================================================= 74 | return_scores = [] 75 | for scorer, method in scorers: 76 | if verbose: 77 | # print('computing %s score...'%(scorer.method())) 78 | pass 79 | score, scores = scorer.compute_score(gts, res) 80 | if type(method) == list: 81 | for sc, scs, m in zip(score, scores, method): 82 | self.setEval(sc, m) 83 | self.setImgToEvalImgs(scs, gts.keys(), m) 84 | if verbose: 85 | # print("%s: %0.3f"%(m, sc)) 86 | pass 87 | return_scores.append((m, round(sc, 4))) 88 | else: 89 | self.setEval(score, method) 90 | self.setImgToEvalImgs(scores, gts.keys(), method) 91 | if verbose: 92 | # print("%s: %0.3f"%(method, score)) 93 | pass 94 | return_scores.append((method, round(score, 4))) 95 | self.setEvalImgs() 96 | 97 | return return_scores 98 | 99 | def setEval(self, score, method): 100 | self.eval[method] = score 101 | 102 | def setImgToEvalImgs(self, scores, imgIds, method): 103 | for imgId, score in zip(imgIds, scores): 104 | if not imgId in self.imgToEval: 105 | self.imgToEval[imgId] = {} 106 | self.imgToEval[imgId]["image_id"] = imgId 107 | self.imgToEval[imgId][method] = score 108 | 109 | def setEvalImgs(self): 110 | self.evalImgs = [eval for imgId, eval in self.imgToEval.items()] -------------------------------------------------------------------------------- /eval/get_stanford_models.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | # This script downloads the Stanford CoreNLP models. 
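A hypothetical usage sketch of `COCOEvalCap` above (not from the repository). The annotation format is inferred from the `PTBTokenizer` it feeds, which reads the `'caption'` field of each entry; ids and paths below are made up for illustration:

```
# Hypothetical sketch: evaluating one prediction against its references.
# Java must be installed; the constructor invokes get_stanford_models.sh,
# which downloads the CoreNLP jars on first use. Heavy metrics can be
# switched off to keep the run light.
from eval.eval import COCOEvalCap

gts = {1: [{'caption': 'A dog runs on the grass.'},
           {'caption': 'A brown dog is running.'}]}
preds = {1: [{'caption': 'a dog is running'}]}

evaluator = COCOEvalCap(gts, preds, pred_img_ids=[1],
                        get_stanford_models_path='./eval/get_stanford_models.sh')
scores = evaluator.evaluate(spice=False, meteor=False)
print(scores)   # e.g. [('CIDEr', ...), ('Bleu_1', ...), ..., ('ROUGE_L', ...)]
```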
3 | 4 | CORENLP=stanford-corenlp-full-2015-12-09 5 | SPICELIB=./spice/lib 6 | JAR=stanford-corenlp-3.6.0 7 | 8 | DIR="$( cd "$(dirname "$0")" ; pwd -P )" 9 | cd $DIR 10 | 11 | if [ -f $SPICELIB/$JAR.jar ]; then 12 | # echo "Found Stanford CoreNLP." 13 | : 14 | else 15 | echo "Downloading..." 16 | wget http://nlp.stanford.edu/software/$CORENLP.zip 17 | echo "Unzipping..." 18 | unzip $CORENLP.zip -d $SPICELIB/ 19 | mv $SPICELIB/$CORENLP/$JAR.jar $SPICELIB/ 20 | mv $SPICELIB/$CORENLP/$JAR-models.jar $SPICELIB/ 21 | rm -f $CORENLP.zip 22 | rm -rf $SPICELIB/$CORENLP/ 23 | echo "Done." 24 | fi 25 | -------------------------------------------------------------------------------- /eval/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /eval/meteor/data/paraphrase-en.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/meteor/data/paraphrase-en.gz -------------------------------------------------------------------------------- /eval/meteor/meteor-1.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/meteor/meteor-1.5.jar -------------------------------------------------------------------------------- /eval/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python wrapper for METEOR implementation, by Xinlei Chen 4 | # Acknowledge Michael Denkowski for the generous discussion and help 5 | 6 | # modified according to: https://github.com/tylin/coco-caption/issues/27 7 | # to support python3.5 8 | 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | import os 14 | import sys 15 | import subprocess 16 | import threading 17 | 18 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed. 
19 | METEOR_JAR = 'meteor-1.5.jar' 20 | 21 | 22 | # print METEOR_JAR 23 | 24 | class Meteor: 25 | 26 | def __init__(self): 27 | self.env = os.environ 28 | self.env['LC_ALL'] = 'en_US.UTF_8' 29 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, 30 | '-', '-', '-stdio', '-l', 'en', '-norm'] 31 | self.meteor_p = subprocess.Popen(self.meteor_cmd, 32 | cwd=os.path.dirname(os.path.abspath(__file__)), 33 | stdin=subprocess.PIPE, 34 | stdout=subprocess.PIPE, 35 | stderr=subprocess.PIPE, 36 | env=self.env, universal_newlines=True, bufsize=1) 37 | # Used to guarantee thread safety 38 | self.lock = threading.Lock() 39 | 40 | def compute_score(self, gts, res): 41 | assert (gts.keys() == res.keys()) 42 | imgIds = sorted(list(gts.keys())) 43 | scores = [] 44 | 45 | eval_line = 'EVAL' 46 | self.lock.acquire() 47 | for i in imgIds: 48 | assert (len(res[i]) == 1) 49 | # There's a situation that the prediction is all punctuations 50 | # (see definition of PUNCTUATIONS in pycocoevalcap/tokenizer/ptbtokenizer.py) 51 | # then the prediction will become [''] after tokenization 52 | # which means res[i][0] == '' and self._stat will failed with this input 53 | if len(res[i][0]) == 0: 54 | res[i][0] = 'a' 55 | stat = self._stat(res[i][0], gts[i]) 56 | eval_line += ' ||| {}'.format(stat) 57 | 58 | # Send to METEOR 59 | self.meteor_p.stdin.write(eval_line + '\n') 60 | 61 | # Collect segment scores 62 | for i in range(len(imgIds)): 63 | score = float(self.meteor_p.stdout.readline().strip()) 64 | scores.append(score) 65 | 66 | # Final score 67 | final_score = float(self.meteor_p.stdout.readline().strip()) 68 | self.lock.release() 69 | 70 | return final_score, scores 71 | 72 | def method(self): 73 | return "METEOR" 74 | 75 | def _stat(self, hypothesis_str, reference_list): 76 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 77 | hypothesis_str = hypothesis_str.replace('|||', '').replace(' ', ' ') 78 | if sys.version_info[0] == 2: # python2 79 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)).encode('utf-8').strip() 80 | self.meteor_p.stdin.write(str(score_line + b'\n')) 81 | else: # assume python3+ 82 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)).strip() 83 | self.meteor_p.stdin.write(score_line + '\n') 84 | return self.meteor_p.stdout.readline().strip() 85 | 86 | def __del__(self): 87 | self.lock.acquire() 88 | self.meteor_p.stdin.close() 89 | self.meteor_p.kill() 90 | self.meteor_p.wait() 91 | self.lock.release() -------------------------------------------------------------------------------- /eval/rouge/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'vrama91' 2 | -------------------------------------------------------------------------------- /eval/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | def my_lcs(string, sub): 14 | """ 15 | Calculates longest common subsequence for a pair of tokenized strings 16 | :param string : list of str : tokens from a string split using whitespace 17 | :param sub : list of str : shorter string, also split using whitespace 18 | :returns: length (list of int): length of the longest common 
subsequence between the two strings 19 | 20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 21 | """ 22 | if(len(string)< len(sub)): 23 | sub, string = string, sub 24 | 25 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)] 26 | 27 | for j in range(1,len(sub)+1): 28 | for i in range(1,len(string)+1): 29 | if(string[i-1] == sub[j-1]): 30 | lengths[i][j] = lengths[i-1][j-1] + 1 31 | else: 32 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1]) 33 | 34 | return lengths[len(string)][len(sub)] 35 | 36 | class Rouge(): 37 | ''' 38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 39 | 40 | ''' 41 | def __init__(self): 42 | # vrama91: updated the value below based on discussion with Hovey 43 | self.beta = 1.2 44 | 45 | def calc_score(self, candidate, refs): 46 | """ 47 | Compute ROUGE-L score given one candidate and references for an image 48 | :param candidate: str : candidate sentence to be evaluated 49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 50 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 51 | """ 52 | assert(len(candidate)==1) 53 | assert(len(refs)>0) 54 | prec = [] 55 | rec = [] 56 | 57 | # split into tokens 58 | token_c = candidate[0].split(" ") 59 | 60 | for reference in refs: 61 | # split into tokens 62 | token_r = reference.split(" ") 63 | # compute the longest common subsequence 64 | lcs = my_lcs(token_r, token_c) 65 | prec.append(lcs/float(len(token_c))) 66 | rec.append(lcs/float(len(token_r))) 67 | 68 | prec_max = max(prec) 69 | rec_max = max(rec) 70 | 71 | if(prec_max!=0 and rec_max !=0): 72 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max) 73 | else: 74 | score = 0.0 75 | return score 76 | 77 | def compute_score(self, gts, res): 78 | """ 79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 80 | Invoked by evaluate_captions.py 81 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 82 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 83 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 84 | """ 85 | assert(gts.keys() == res.keys()) 86 | imgIds = gts.keys() 87 | 88 | score = [] 89 | for id in imgIds: 90 | hypo = res[id] 91 | ref = gts[id] 92 | 93 | score.append(self.calc_score(hypo, ref)) 94 | 95 | # Sanity check. 
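A worked example of the ROUGE-L arithmetic above (illustration only): precision and recall come from the LCS length, then an F-score with beta = 1.2 is taken.

```
# Illustration only, assuming the repository root is on PYTHONPATH.
from eval.rouge.rouge import Rouge, my_lcs

cand = "a dog runs on grass"
ref = "a brown dog runs on the grass"
lcs = my_lcs(ref.split(" "), cand.split(" "))          # 5 shared tokens in order
prec, rec = lcs / len(cand.split(" ")), lcs / len(ref.split(" "))
beta = 1.2
f_score = ((1 + beta ** 2) * prec * rec) / (rec + beta ** 2 * prec)
print(lcs, round(f_score, 3))                          # 5 0.809

# Same result through the class interface (candidate must be a 1-element list):
print(round(Rouge().calc_score([cand], [ref]), 3))     # 0.809
```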
96 | assert(type(hypo) is list) 97 | assert(len(hypo) == 1) 98 | assert(type(ref) is list) 99 | assert(len(ref) > 0) 100 | 101 | average_score = np.mean(np.array(score)) 102 | return average_score, np.array(score) 103 | 104 | def method(self): 105 | return "Rouge" 106 | -------------------------------------------------------------------------------- /eval/spice/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/__init__.py -------------------------------------------------------------------------------- /eval/spice/cache/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/cache/.gitkeep -------------------------------------------------------------------------------- /eval/spice/lib/Meteor-1.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/Meteor-1.5.jar -------------------------------------------------------------------------------- /eval/spice/lib/SceneGraphParser-1.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/SceneGraphParser-1.0.jar -------------------------------------------------------------------------------- /eval/spice/lib/ejml-0.23.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/ejml-0.23.jar -------------------------------------------------------------------------------- /eval/spice/lib/fst-2.47.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/fst-2.47.jar -------------------------------------------------------------------------------- /eval/spice/lib/guava-19.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/guava-19.0.jar -------------------------------------------------------------------------------- /eval/spice/lib/hamcrest-core-1.3.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/hamcrest-core-1.3.jar -------------------------------------------------------------------------------- /eval/spice/lib/jackson-core-2.5.3.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/jackson-core-2.5.3.jar -------------------------------------------------------------------------------- /eval/spice/lib/javassist-3.19.0-GA.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/javassist-3.19.0-GA.jar 
-------------------------------------------------------------------------------- /eval/spice/lib/json-simple-1.1.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/json-simple-1.1.1.jar -------------------------------------------------------------------------------- /eval/spice/lib/junit-4.12.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/junit-4.12.jar -------------------------------------------------------------------------------- /eval/spice/lib/lmdbjni-0.4.6.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/lmdbjni-0.4.6.jar -------------------------------------------------------------------------------- /eval/spice/lib/lmdbjni-linux64-0.4.6.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/lmdbjni-linux64-0.4.6.jar -------------------------------------------------------------------------------- /eval/spice/lib/lmdbjni-osx64-0.4.6.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/lmdbjni-osx64-0.4.6.jar -------------------------------------------------------------------------------- /eval/spice/lib/lmdbjni-win64-0.4.6.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/lmdbjni-win64-0.4.6.jar -------------------------------------------------------------------------------- /eval/spice/lib/objenesis-2.4.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/objenesis-2.4.jar -------------------------------------------------------------------------------- /eval/spice/lib/slf4j-api-1.7.12.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/slf4j-api-1.7.12.jar -------------------------------------------------------------------------------- /eval/spice/lib/slf4j-simple-1.7.21.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/slf4j-simple-1.7.21.jar -------------------------------------------------------------------------------- /eval/spice/spice-1.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/spice-1.0.jar -------------------------------------------------------------------------------- /eval/spice/spice.py: -------------------------------------------------------------------------------- 1 | from __future__ import 
division 2 | import os 3 | import sys 4 | import subprocess 5 | import threading 6 | import json 7 | import numpy as np 8 | import ast 9 | import tempfile 10 | import random 11 | 12 | # Assumes spice.jar is in the same directory as spice.py. Change as needed. 13 | SPICE_JAR = 'spice-1.0.jar' 14 | TEMP_DIR = 'tmp' 15 | CACHE_DIR = 'cache' 16 | 17 | 18 | class Spice: 19 | """ 20 | Main Class to compute the SPICE metric 21 | """ 22 | 23 | def float_convert(self, obj): 24 | try: 25 | return float(obj) 26 | except: 27 | return np.nan 28 | 29 | def compute_score(self, gts, res): 30 | assert (sorted(gts.keys()) == sorted(res.keys())) 31 | imgIds = sorted(gts.keys()) 32 | 33 | # Prepare temp input file for the SPICE scorer 34 | input_data = [] 35 | for id in imgIds: 36 | hypo = res[id] 37 | ref = gts[id] 38 | 39 | # Sanity check. 40 | assert (type(hypo) is list) 41 | assert (len(hypo) == 1) 42 | assert (type(ref) is list) 43 | assert (len(ref) >= 1) 44 | 45 | input_data.append({ 46 | "image_id": id, 47 | "test": hypo[0], 48 | "refs": ref 49 | }) 50 | 51 | cwd = os.path.dirname(os.path.abspath(__file__)) 52 | temp_dir = os.path.join(cwd, TEMP_DIR) 53 | if not os.path.exists(temp_dir): 54 | os.makedirs(temp_dir) 55 | 56 | # the generation of random names avoid very unlucky synchronization situations 57 | # with respect to the original implementation 58 | 59 | # python2.7 to python3.5 adaptation 60 | # in_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) 61 | import time 62 | random.seed(time.time()) 63 | random_int_str = str(random.randint(0, 9999999)) 64 | in_file_name = random_int_str + '_pid' + str(os.getpid()) + '_in_tmp_file.json' 65 | # print(temp_dir + '/' + in_file_name) 66 | with open(temp_dir + '/' + in_file_name, 'w') as in_file: 67 | json.dump(input_data, in_file, indent=2) 68 | 69 | # Start job 70 | # out_file_name = 'out_tmp_file.tmp' 71 | # with open(temp_dir + '/' + out_file_name, 'w') as out_file: 72 | 73 | out_file_name = random_int_str + '_pid' + str(os.getpid()) + '_out_tmp_file.json' 74 | out_file_path = temp_dir + '/' + out_file_name 75 | # create file 76 | with open(out_file_path, 'w') as f: 77 | f.write('') 78 | 79 | cache_dir = os.path.join(cwd, CACHE_DIR) 80 | if not os.path.exists(cache_dir): 81 | os.makedirs(cache_dir) 82 | spice_cmd = ['java', '-jar', '-Xmx8G', SPICE_JAR, 83 | temp_dir + '/' + in_file_name, 84 | '-cache', cache_dir, 85 | '-out', out_file_path, 86 | '-subset', 87 | '-silent' 88 | ] 89 | subprocess.check_call(spice_cmd, cwd=os.path.dirname(os.path.abspath(__file__)), 90 | stdout=subprocess.DEVNULL, 91 | stderr=subprocess.DEVNULL) 92 | 93 | # Read and process results 94 | with open(temp_dir + '/' + out_file_name, 'r') as data_file: 95 | results = json.load(data_file) 96 | 97 | os.remove(temp_dir + '/' + in_file_name) 98 | os.remove(temp_dir + '/' + out_file_name) 99 | 100 | imgId_to_scores = {} 101 | spice_scores = [] 102 | for item in results: 103 | imgId_to_scores[item['image_id']] = item['scores'] 104 | spice_scores.append(self.float_convert(item['scores']['All']['f'])) 105 | average_score = np.mean(np.array(spice_scores)) 106 | scores = [] 107 | for image_id in imgIds: 108 | # Convert none to NaN before saving scores over subcategories 109 | score_set = {} 110 | for category, score_tuple in imgId_to_scores[image_id].items(): 111 | score_set[category] = {k: self.float_convert(v) for k, v in score_tuple.items()} 112 | scores.append(score_set) 113 | return average_score, scores 114 | 115 | def method(self): 116 | return "SPICE" 117 | 118 | 119 
| -------------------------------------------------------------------------------- /eval/spice/tmp/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/tmp/.gitkeep -------------------------------------------------------------------------------- /eval/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'hfang' 2 | -------------------------------------------------------------------------------- /eval/tokenizer/ptbtokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import sys 13 | import subprocess 14 | import random 15 | import tempfile 16 | import itertools 17 | 18 | # path to the stanford corenlp jar 19 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar' 20 | 21 | # punctuations to be removed from the sentences 22 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 23 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 24 | 25 | class PTBTokenizer: 26 | """Python wrapper of Stanford PTBTokenizer""" 27 | 28 | def tokenize(self, captions_for_image): 29 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \ 30 | 'edu.stanford.nlp.process.PTBTokenizer', \ 31 | '-preserveLines', '-lowerCase'] 32 | 33 | # ====================================================== 34 | # prepare data for PTB Tokenizer 35 | # ====================================================== 36 | final_tokenized_captions_for_image = {} 37 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))] 38 | sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v]) 39 | 40 | # ====================================================== 41 | # save sentences to temporary file 42 | # ====================================================== 43 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 44 | tmp_file_name = path_to_jar_dirname + '/temp_file' + str(random.randint(0, 999999)) + '_pid_' + str(os.getpid()) + '.tmp' 45 | with open(tmp_file_name, 'w') as tmp_file: 46 | tmp_file.write(sentences) 47 | 48 | # ====================================================== 49 | # tokenize sentence 50 | # ====================================================== 51 | cmd.append(tmp_file_name) 52 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 53 | stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) 54 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 55 | # print(token_lines) 56 | token_lines = token_lines.decode("utf-8") 57 | lines = token_lines.split('\n') 58 | # remove temp file 59 | if os.path.isfile(tmp_file_name): 60 | os.remove(tmp_file_name) 61 | 62 | # ====================================================== 63 | # create dictionary for tokenized captions 64 | # ====================================================== 65 | for k, line in zip(image_id, lines): 66 | if not k in final_tokenized_captions_for_image: 67 | final_tokenized_captions_for_image[k] = [] 68 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 69 | if w not in PUNCTUATIONS]) 70 | 
final_tokenized_captions_for_image[k].append(tokenized_caption) 71 | 72 | return final_tokenized_captions_for_image 73 | -------------------------------------------------------------------------------- /eval/tokenizer/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/tokenizer/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /github_ignore_material/raw_data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/github_ignore_material/raw_data/.gitkeep -------------------------------------------------------------------------------- /github_ignore_material/saves/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/github_ignore_material/saves/.gitkeep -------------------------------------------------------------------------------- /license.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jia Cheng Hu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
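Going back to `eval/tokenizer/ptbtokenizer.py` shown earlier, a hypothetical usage sketch (not from the repository): the wrapper expects a dict of image_id to a list of `{'caption': ...}` entries and returns a dict of image_id to lowercased, punctuation-free token strings.

```
# Hypothetical sketch: Java and eval/tokenizer/stanford-corenlp-3.4.1.jar
# must be available for the subprocess call to succeed.
from eval.tokenizer.ptbtokenizer import PTBTokenizer

captions = {42: [{'caption': "A man, riding a horse!"},
                 {'caption': "Someone rides a brown horse."}]}
tokenized = PTBTokenizer().tokenize(captions)
print(tokenized[42])   # e.g. ['a man riding a horse', 'someone rides a brown horse']
```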
22 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/losses/__init__.py -------------------------------------------------------------------------------- /losses/loss.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class LabelSmoothingLoss(nn.Module): 7 | def __init__(self, smoothing_coeff, rank='cuda:0'): 8 | assert 0.0 <= smoothing_coeff <= 1.0 9 | super().__init__() 10 | self.smoothing_coeff = smoothing_coeff 11 | self.kl_div = nn.KLDivLoss(reduction='none') 12 | self.log_softmax = nn.LogSoftmax(dim=-1) 13 | 14 | self.rank = rank 15 | 16 | def forward(self, pred, target, ignore_index, divide_by_non_zeros=True): 17 | pred = self.log_softmax(pred) 18 | 19 | batch_size, seq_len, num_classes = pred.shape 20 | uniform_confidence = self.smoothing_coeff / (num_classes - 1) # minus one cause of PAD token 21 | confidence = 1 - self.smoothing_coeff 22 | one_hot = torch.full((num_classes,), uniform_confidence).to(self.rank) 23 | model_prob = one_hot.repeat(batch_size, seq_len, 1) 24 | model_prob.scatter_(2, target.unsqueeze(2), confidence) 25 | model_prob.masked_fill_((target == ignore_index).unsqueeze(2), 0) 26 | 27 | tot_loss_tensor = self.kl_div(pred, model_prob) 28 | 29 | # divide the loss of each sequence by the number of non pads 30 | pads_matrix = torch.as_tensor(target == ignore_index) 31 | tot_loss_tensor.masked_fill_(pads_matrix.unsqueeze(2), 0.0) 32 | if divide_by_non_zeros: 33 | num_non_pads = (~pads_matrix).sum().type(torch.cuda.FloatTensor) 34 | tot_loss = tot_loss_tensor.sum() / num_non_pads 35 | else: 36 | tot_loss = tot_loss_tensor.sum() 37 | 38 | return tot_loss 39 | -------------------------------------------------------------------------------- /losses/reward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from eval.cider.reinforce_cider import ReinforceCider 3 | 4 | import itertools 5 | from utils import language_utils 6 | 7 | 8 | class ReinforceCiderReward: 9 | def __init__(self, training_references, eos_token, num_sampled_captions, rank): 10 | super(ReinforceCiderReward).__init__() 11 | self.rank = rank 12 | self.num_sampled_captions = num_sampled_captions 13 | 14 | preprocessed_training_references = [] 15 | for i in range(len(training_references)): 16 | preprocessed_captions = [] 17 | for caption in training_references[i]: 18 | # it's faster than invoking the tokenizer 19 | caption = language_utils.lowercase_and_clean_trailing_spaces([caption]) 20 | caption = language_utils.add_space_between_non_alphanumeric_symbols(caption) 21 | caption = language_utils.remove_punctuations(caption) 22 | caption = " ".join(caption[0].split() + [eos_token]) 23 | preprocessed_captions.append(caption) 24 | preprocessed_training_references.append(preprocessed_captions) 25 | self.training_references = preprocessed_training_references 26 | self.reinforce_cider = ReinforceCider(self.training_references) 27 | 28 | def compute_reward(self, all_images_pred_caption, all_images_logprob, all_images_idx, all_images_base_caption=None): 29 | 30 | batch_size = len(all_images_pred_caption) 31 | num_sampled_captions = len(all_images_pred_caption[0]) 32 | 33 | # Important for Correct and Fair results: keep EOS in the Log loss 
computation 34 | all_images_pred_caption = [' '.join(caption[1:]) 35 | for pred_one_image in all_images_pred_caption 36 | for caption in pred_one_image] 37 | 38 | # repeat the references for the number of outputs 39 | all_images_ref_caption = [self.training_references[idx] for idx in all_images_idx] 40 | all_images_ref_caption = list(itertools.chain.from_iterable(itertools.repeat(ref, self.num_sampled_captions) 41 | for ref in all_images_ref_caption)) 42 | 43 | cider_result = self.reinforce_cider.compute_score(hypo=all_images_pred_caption, 44 | refs=all_images_ref_caption) 45 | reward = torch.tensor(cider_result[1]).to(self.rank).view(batch_size, num_sampled_captions) 46 | 47 | 48 | if all_images_base_caption is None: # Mean base 49 | reward_base = (reward.sum(dim=-1, keepdim=True) - reward) / (num_sampled_captions - 1) 50 | else: # anything else like Greedy 51 | all_images_base_caption = [' '.join(caption[1:]) 52 | for pred_one_image in all_images_base_caption 53 | for caption in pred_one_image] 54 | 55 | base_cider_result = self.reinforce_cider.compute_score(hypo=all_images_base_caption, 56 | refs=all_images_ref_caption) 57 | reward_base = torch.tensor(base_cider_result[1]).to(self.rank).view(batch_size, num_sampled_captions) 58 | 59 | reward_loss = (reward - reward_base) * torch.sum(-all_images_logprob, dim=-1) 60 | reward_loss = reward_loss.mean() 61 | return reward_loss, reward, reward_base 62 | -------------------------------------------------------------------------------- /models/End_ExpansionNet_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.layers import EmbeddingLayer, DecoderLayer, EncoderLayer 5 | from utils.masking import create_pad_mask, create_no_peak_and_pad_mask 6 | from models.captioning_model import CaptioningModel 7 | from models.swin_transformer_mod import SwinTransformer 8 | 9 | 10 | class End_ExpansionNet_v2(CaptioningModel): 11 | def __init__(self, 12 | 13 | # swin transf 14 | swin_img_size, swin_patch_size, swin_in_chans, 15 | swin_embed_dim, swin_depths, swin_num_heads, 16 | swin_window_size, swin_mlp_ratio, swin_qkv_bias, swin_qk_scale, 17 | swin_drop_rate, swin_attn_drop_rate, swin_drop_path_rate, 18 | swin_norm_layer, swin_ape, swin_patch_norm, 19 | swin_use_checkpoint, 20 | 21 | # linear_size, 22 | final_swin_dim, 23 | 24 | # captioning 25 | d_model, N_enc, N_dec, ff, num_heads, num_exp_enc_list, num_exp_dec, 26 | output_word2idx, output_idx2word, max_seq_len, drop_args, rank=0): 27 | super(End_ExpansionNet_v2, self).__init__() 28 | 29 | self.swin_transf = SwinTransformer( 30 | img_size=swin_img_size, patch_size=swin_patch_size, in_chans=swin_in_chans, 31 | embed_dim=swin_embed_dim, depths=swin_depths, num_heads=swin_num_heads, 32 | window_size=swin_window_size, mlp_ratio=swin_mlp_ratio, qkv_bias=swin_qkv_bias, qk_scale=swin_qk_scale, 33 | drop_rate=swin_drop_rate, attn_drop_rate=swin_attn_drop_rate, drop_path_rate=swin_drop_path_rate, 34 | norm_layer=swin_norm_layer, ape=swin_ape, patch_norm=swin_patch_norm, 35 | use_checkpoint=swin_use_checkpoint) 36 | 37 | self.output_word2idx = output_word2idx 38 | self.output_idx2word = output_idx2word 39 | self.max_seq_len = max_seq_len 40 | 41 | self.num_exp_dec = num_exp_dec 42 | self.num_exp_enc_list = num_exp_enc_list 43 | 44 | self.N_enc = N_enc 45 | self.N_dec = N_dec 46 | self.d_model = d_model 47 | 48 | self.encoders = nn.ModuleList([EncoderLayer(d_model, ff, num_exp_enc_list, drop_args.enc) for _ in 
range(N_enc)]) 49 | self.decoders = nn.ModuleList([DecoderLayer(d_model, num_heads, ff, num_exp_dec, drop_args.dec) for _ in range(N_dec)]) 50 | 51 | self.input_embedder_dropout = nn.Dropout(drop_args.enc_input) 52 | self.input_linear = torch.nn.Linear(final_swin_dim, d_model) 53 | self.vocab_linear = torch.nn.Linear(d_model, len(output_word2idx)) 54 | self.log_softmax = nn.LogSoftmax(dim=-1) 55 | 56 | self.out_enc_dropout = nn.Dropout(drop_args.other) 57 | self.out_dec_dropout = nn.Dropout(drop_args.other) 58 | 59 | self.out_embedder = EmbeddingLayer(len(output_word2idx), d_model, drop_args.dec_input) 60 | self.pos_encoder = nn.Embedding(max_seq_len, d_model) 61 | 62 | self.enc_reduce_group = nn.Linear(d_model * self.N_enc, d_model) 63 | self.enc_reduce_norm = nn.LayerNorm(d_model) 64 | self.dec_reduce_group = nn.Linear(d_model * self.N_dec, d_model) 65 | self.dec_reduce_norm = nn.LayerNorm(d_model) 66 | 67 | for p in self.parameters(): 68 | if p.dim() > 1: 69 | nn.init.xavier_uniform_(p) 70 | 71 | self.trained_steps = 0 72 | self.rank = rank 73 | 74 | self.check_required_attributes() 75 | 76 | def forward_enc(self, enc_input, enc_input_num_pads): 77 | 78 | assert (enc_input_num_pads is None or enc_input_num_pads == ([0] * enc_input.size(0))), \ 79 | "End to End case have no padding" 80 | x = self.swin_transf(enc_input) 81 | 82 | enc_input = self.input_embedder_dropout(self.input_linear(x)) 83 | x = enc_input 84 | enc_input_num_pads = [0] * enc_input.size(0) 85 | 86 | max_num_enc = sum(self.num_exp_enc_list) 87 | pos_x = torch.arange(max_num_enc).unsqueeze(0).expand(enc_input.size(0), max_num_enc).to(self.rank) 88 | pad_mask = create_pad_mask(mask_size=(enc_input.size(0), max_num_enc, enc_input.size(1)), 89 | pad_row=[0] * enc_input.size(0), 90 | pad_column=enc_input_num_pads, 91 | rank=self.rank) 92 | 93 | x_list = [] 94 | for i in range(self.N_enc): 95 | x = self.encoders[i](x=x, n_indexes=pos_x, mask=pad_mask) 96 | x_list.append(x) 97 | x_list = torch.cat(x_list, dim=-1) 98 | x = x + self.out_enc_dropout(self.enc_reduce_group(x_list)) 99 | x = self.enc_reduce_norm(x) 100 | 101 | return x 102 | 103 | def forward_dec(self, cross_input, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False): 104 | assert (enc_input_num_pads is None or enc_input_num_pads == ([0] * cross_input.size(0))), \ 105 | "enc_input_num_pads should be no None" 106 | 107 | enc_input_num_pads = [0] * dec_input.size(0) 108 | no_peak_and_pad_mask = create_no_peak_and_pad_mask( 109 | mask_size=(dec_input.size(0), dec_input.size(1), dec_input.size(1)), 110 | num_pads=dec_input_num_pads, 111 | rank=self.rank) 112 | pad_mask = create_pad_mask(mask_size=(dec_input.size(0), dec_input.size(1), cross_input.size(1)), 113 | pad_row=dec_input_num_pads, 114 | pad_column=enc_input_num_pads, 115 | rank=self.rank) 116 | 117 | y = self.out_embedder(dec_input) 118 | pos_x = torch.arange(self.num_exp_dec).unsqueeze(0).expand(dec_input.size(0), self.num_exp_dec).to(self.rank) 119 | pos_y = torch.arange(dec_input.size(1)).unsqueeze(0).expand(dec_input.size(0), dec_input.size(1)).to(self.rank) 120 | y = y + self.pos_encoder(pos_y) 121 | y_list = [] 122 | for i in range(self.N_dec): 123 | y = self.decoders[i](x=y, 124 | n_indexes=pos_x, 125 | cross_connection_x=cross_input, 126 | input_attention_mask=no_peak_and_pad_mask, 127 | cross_attention_mask=pad_mask) 128 | y_list.append(y) 129 | y_list = torch.cat(y_list, dim=-1) 130 | y = y + self.out_dec_dropout(self.dec_reduce_group(y_list)) 131 | y = self.dec_reduce_norm(y) 
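A minimal standalone sketch (illustration only, with stand-in layers rather than the model's EncoderLayer/DecoderLayer) of the aggregation pattern used by `forward_enc` and `forward_dec` above: every layer's output is kept, concatenated on the feature axis, projected back to d_model by the "reduce group" linear, added as a residual and layer-normalized.

```
# Illustration only: per-layer output fusion with a reduce-group projection.
import torch
import torch.nn as nn

d_model, num_layers, batch, seq = 8, 3, 2, 5
layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)])
reduce_group = nn.Linear(d_model * num_layers, d_model)
norm = nn.LayerNorm(d_model)

x = torch.randn(batch, seq, d_model)
outs = []
for layer in layers:                            # stand-in for the layer stack
    x = torch.relu(layer(x))
    outs.append(x)
x = x + reduce_group(torch.cat(outs, dim=-1))   # fuse all intermediate outputs
x = norm(x)
print(x.shape)                                  # torch.Size([2, 5, 8])
```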
132 | 133 | y = self.vocab_linear(y) 134 | 135 | if apply_log_softmax: 136 | y = self.log_softmax(y) 137 | 138 | return y 139 | 140 | 141 | def get_batch_multiple_sampled_prediction(self, enc_input, enc_input_num_pads, num_outputs, 142 | sos_idx, eos_idx, max_seq_len): 143 | 144 | bs = enc_input.size(0) 145 | x = self.forward_enc(enc_input=enc_input, enc_input_num_pads=enc_input_num_pads) 146 | enc_seq_len = x.size(1) 147 | x = x.unsqueeze(1).expand(-1, num_outputs, -1, -1).reshape(bs * num_outputs, enc_seq_len, x.shape[-1]) 148 | 149 | upperbound_vector = torch.tensor([max_seq_len] * bs * num_outputs, dtype=torch.int).to(self.rank) 150 | where_is_eos_vector = upperbound_vector.clone() 151 | eos_vector = torch.tensor([eos_idx] * bs * num_outputs, dtype=torch.long).to(self.rank) 152 | finished_flag_vector = torch.zeros(bs * num_outputs).type(torch.int) 153 | 154 | predicted_caption = torch.tensor([sos_idx] * (bs * num_outputs), dtype=torch.long).to(self.rank).unsqueeze(-1) 155 | predicted_caption_prob = torch.zeros(bs * num_outputs).to(self.rank).unsqueeze(-1) 156 | 157 | dec_input_num_pads = [0]*(bs*num_outputs) 158 | time_step = 0 159 | while (finished_flag_vector.sum() != bs * num_outputs) and time_step < max_seq_len: 160 | dec_input = predicted_caption 161 | log_probs = self.forward_dec(x, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=True) 162 | 163 | prob_dist = torch.distributions.Categorical(torch.exp(log_probs[:, time_step])) 164 | sampled_word_indexes = prob_dist.sample() 165 | 166 | predicted_caption = torch.cat((predicted_caption, sampled_word_indexes.unsqueeze(-1)), dim=-1) 167 | predicted_caption_prob = torch.cat((predicted_caption_prob, 168 | log_probs[:, time_step].gather(index=sampled_word_indexes.unsqueeze(-1), dim=-1)), dim=-1) 169 | time_step += 1 170 | 171 | where_is_eos_vector = torch.min(where_is_eos_vector, 172 | upperbound_vector.masked_fill(sampled_word_indexes == eos_vector, time_step)) 173 | finished_flag_vector = torch.max(finished_flag_vector, 174 | (sampled_word_indexes == eos_vector).type(torch.IntTensor)) 175 | 176 | res_predicted_caption = [] 177 | for i in range(bs): 178 | res_predicted_caption.append([]) 179 | for j in range(num_outputs): 180 | index = i*num_outputs + j 181 | res_predicted_caption[i].append( 182 | predicted_caption[index, :where_is_eos_vector[index].item()+1].tolist()) 183 | 184 | where_is_eos_vector = where_is_eos_vector.unsqueeze(-1).expand(-1, time_step+1) 185 | arange_tensor = torch.arange(time_step+1).unsqueeze(0).expand(bs * num_outputs, -1).to(self.rank) 186 | predicted_caption_prob.masked_fill_(arange_tensor > where_is_eos_vector, 0.0) 187 | res_predicted_caption_prob = predicted_caption_prob.reshape(bs, num_outputs, -1) 188 | 189 | return res_predicted_caption, res_predicted_caption_prob 190 | -------------------------------------------------------------------------------- /models/ExpansionNet_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.layers import EmbeddingLayer, EncoderLayer, DecoderLayer 3 | from utils.masking import create_pad_mask, create_no_peak_and_pad_mask 4 | from models.captioning_model import CaptioningModel 5 | 6 | import torch.nn as nn 7 | 8 | 9 | class ExpansionNet_v2(CaptioningModel): 10 | def __init__(self, d_model, N_enc, N_dec, ff, num_heads, num_exp_enc_list, num_exp_dec, 11 | output_word2idx, output_idx2word, max_seq_len, drop_args, img_feature_dim=2048, rank=0): 12 | super().__init__() 13 | 
self.output_word2idx = output_word2idx 14 | self.output_idx2word = output_idx2word 15 | self.max_seq_len = max_seq_len 16 | 17 | self.num_exp_dec = num_exp_dec 18 | self.num_exp_enc_list = num_exp_enc_list 19 | 20 | self.N_enc = N_enc 21 | self.N_dec = N_dec 22 | self.d_model = d_model 23 | 24 | self.encoders = nn.ModuleList([EncoderLayer(d_model, ff, num_exp_enc_list, drop_args.enc) for _ in range(N_enc)]) 25 | self.decoders = nn.ModuleList([DecoderLayer(d_model, num_heads, ff, num_exp_dec, drop_args.dec) for _ in range(N_dec)]) 26 | 27 | self.input_embedder_dropout = nn.Dropout(drop_args.enc_input) 28 | self.input_linear = torch.nn.Linear(img_feature_dim, d_model) 29 | self.vocab_linear = torch.nn.Linear(d_model, len(output_word2idx)) 30 | self.log_softmax = nn.LogSoftmax(dim=-1) 31 | 32 | self.out_enc_dropout = nn.Dropout(drop_args.other) 33 | self.out_dec_dropout = nn.Dropout(drop_args.other) 34 | 35 | self.out_embedder = EmbeddingLayer(len(output_word2idx), d_model, drop_args.dec_input) 36 | self.pos_encoder = nn.Embedding(max_seq_len, d_model) 37 | 38 | self.enc_reduce_group = nn.Linear(d_model * self.N_enc, d_model) 39 | self.enc_reduce_norm = nn.LayerNorm(d_model) 40 | self.dec_reduce_group = nn.Linear(d_model * self.N_dec, d_model) 41 | self.dec_reduce_norm = nn.LayerNorm(d_model) 42 | 43 | for p in self.parameters(): 44 | if p.dim() > 1: 45 | nn.init.xavier_uniform_(p) 46 | 47 | self.trained_steps = 0 48 | self.rank = rank 49 | 50 | def forward_enc(self, enc_input, enc_input_num_pads): 51 | 52 | x = self.input_embedder_dropout(self.input_linear(enc_input.float())) 53 | 54 | sum_num_enc = sum(self.num_exp_enc_list) 55 | pos_x = torch.arange(sum_num_enc).unsqueeze(0).expand(enc_input.size(0), sum_num_enc).to(self.rank) 56 | pad_mask = create_pad_mask(mask_size=(enc_input.size(0), sum_num_enc, enc_input.size(1)), 57 | pad_row=[0] * enc_input.size(0), 58 | pad_column=enc_input_num_pads, 59 | rank=self.rank) 60 | 61 | x_list = [] 62 | for i in range(self.N_enc): 63 | x = self.encoders[i](x=x, n_indexes=pos_x, mask=pad_mask) 64 | x_list.append(x) 65 | x_list = torch.cat(x_list, dim=-1) 66 | x = x + self.out_enc_dropout(self.enc_reduce_group(x_list)) 67 | x = self.enc_reduce_norm(x) 68 | return x 69 | 70 | def forward_dec(self, cross_input, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False): 71 | 72 | no_peak_and_pad_mask = create_no_peak_and_pad_mask( 73 | mask_size=(dec_input.size(0), dec_input.size(1), dec_input.size(1)), 74 | num_pads=torch.tensor(dec_input_num_pads), 75 | rank=self.rank) 76 | 77 | pad_mask = create_pad_mask(mask_size=(dec_input.size(0), dec_input.size(1), cross_input.size(1)), 78 | pad_row=torch.tensor(dec_input_num_pads), 79 | pad_column=torch.tensor(enc_input_num_pads), 80 | rank=self.rank) 81 | 82 | y = self.out_embedder(dec_input) 83 | pos_x = torch.arange(self.num_exp_dec).unsqueeze(0).expand(dec_input.size(0), self.num_exp_dec).to(self.rank) 84 | pos_y = torch.arange(dec_input.size(1)).unsqueeze(0).expand(dec_input.size(0), dec_input.size(1)).to(self.rank) 85 | y = y + self.pos_encoder(pos_y) 86 | y_list = [] 87 | for i in range(self.N_dec): 88 | y = self.decoders[i](x=y, 89 | n_indexes=pos_x, 90 | cross_connection_x=cross_input, 91 | input_attention_mask=no_peak_and_pad_mask, 92 | cross_attention_mask=pad_mask) 93 | y_list.append(y) 94 | y_list = torch.cat(y_list, dim=-1) 95 | y = y + self.out_dec_dropout(self.dec_reduce_group(y_list)) 96 | y = self.dec_reduce_norm(y) 97 | 98 | y = self.vocab_linear(y) 99 | 100 | if 
apply_log_softmax: 101 | y = self.log_softmax(y) 102 | 103 | return y 104 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/models/__init__.py -------------------------------------------------------------------------------- /models/captioning_model.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class CaptioningModel(nn.Module): 8 | def __init__(self): 9 | super(CaptioningModel, self).__init__() 10 | # mandatory attributes 11 | # rank: to enable multiprocessing 12 | self.rank = None 13 | 14 | def check_required_attributes(self): 15 | if self.rank is None: 16 | raise NotImplementedError("Subclass must assign the rank integer according to the GPU group") 17 | 18 | def forward_enc(self, enc_input, enc_input_num_pads): 19 | raise NotImplementedError 20 | 21 | def forward_dec(self, cross_input, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False): 22 | raise NotImplementedError 23 | 24 | def forward(self, enc_x, dec_x=None, 25 | enc_x_num_pads=[0], dec_x_num_pads=[0], apply_log_softmax=False, 26 | mode='forward', **kwargs): 27 | if mode == 'forward': 28 | x = self.forward_enc(enc_x, enc_x_num_pads) 29 | y = self.forward_dec(x, enc_x_num_pads, dec_x, dec_x_num_pads, apply_log_softmax) 30 | return y 31 | else: 32 | assert ('sos_idx' in kwargs.keys() or 'eos_idx' in kwargs.keys()), \ 33 | 'sos and eos must be provided in case of batch sampling or beam search' 34 | sos_idx = kwargs.get('sos_idx', -999) 35 | eos_idx = kwargs.get('eos_idx', -999) 36 | if mode == 'beam_search': 37 | beam_size_arg = kwargs.get('beam_size', 5) 38 | how_many_outputs_per_beam = kwargs.get('how_many_outputs', 1) 39 | beam_max_seq_len = kwargs.get('beam_max_seq_len', 20) 40 | sample_or_max = kwargs.get('sample_or_max', 'max') 41 | out_classes, out_logprobs = self.beam_search( 42 | enc_x, enc_x_num_pads, 43 | beam_size=beam_size_arg, 44 | sos_idx=sos_idx, 45 | eos_idx=eos_idx, 46 | how_many_outputs=how_many_outputs_per_beam, 47 | max_seq_len=beam_max_seq_len, 48 | sample_or_max=sample_or_max) 49 | return out_classes, out_logprobs 50 | if mode == 'sampling': 51 | how_many_outputs = kwargs.get('how_many_outputs', 1) 52 | sample_max_seq_len = kwargs.get('sample_max_seq_len', 20) 53 | out_classes, out_logprobs = self.get_batch_multiple_sampled_prediction( 54 | enc_x, enc_x_num_pads, num_outputs=how_many_outputs, 55 | sos_idx=sos_idx, eos_idx=eos_idx, 56 | max_seq_len=sample_max_seq_len) 57 | return out_classes, out_logprobs 58 | 59 | def get_batch_multiple_sampled_prediction(self, enc_input, enc_input_num_pads, num_outputs, 60 | sos_idx, eos_idx, max_seq_len): 61 | bs, enc_seq_len, _ = enc_input.shape 62 | 63 | enc_input_num_pads = [enc_input_num_pads[i] for i in range(bs) for _ in range(num_outputs)] 64 | 65 | x = self.forward_enc(enc_input=enc_input, enc_input_num_pads=enc_input_num_pads) 66 | x = x.unsqueeze(1).expand(-1, num_outputs, -1, -1).reshape(bs * num_outputs, enc_seq_len, x.shape[-1]) 67 | 68 | upperbound_vector = torch.tensor([max_seq_len] * bs * num_outputs, dtype=torch.int).to(self.rank) 69 | where_is_eos_vector = upperbound_vector.clone() 70 | eos_vector = torch.tensor([eos_idx] * bs * num_outputs, dtype=torch.long).to(self.rank) 71 | finished_flag_vector = 
torch.zeros(bs * num_outputs).type(torch.int) 72 | 73 | predicted_caption = torch.tensor([sos_idx] * (bs * num_outputs), dtype=torch.long).to(self.rank).unsqueeze(-1) 74 | predicted_caption_prob = torch.zeros(bs * num_outputs).to(self.rank).unsqueeze(-1) 75 | 76 | dec_input_num_pads = [0]*(bs*num_outputs) 77 | time_step = 0 78 | while (finished_flag_vector.sum() != bs * num_outputs) and time_step < max_seq_len: 79 | dec_input = predicted_caption 80 | log_probs = self.forward_dec(x, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=True) 81 | 82 | prob_dist = torch.distributions.Categorical(torch.exp(log_probs[:, time_step])) 83 | sampled_word_indexes = prob_dist.sample() 84 | 85 | predicted_caption = torch.cat((predicted_caption, sampled_word_indexes.unsqueeze(-1)), dim=-1) 86 | predicted_caption_prob = torch.cat((predicted_caption_prob, 87 | log_probs[:, time_step].gather(index=sampled_word_indexes.unsqueeze(-1), dim=-1)), dim=-1) 88 | time_step += 1 89 | 90 | where_is_eos_vector = torch.min(where_is_eos_vector, 91 | upperbound_vector.masked_fill(sampled_word_indexes == eos_vector, time_step)) 92 | finished_flag_vector = torch.max(finished_flag_vector, 93 | (sampled_word_indexes == eos_vector).type(torch.IntTensor)) 94 | 95 | # remove the elements that come after the first eos from the sequence 96 | res_predicted_caption = [] 97 | for i in range(bs): 98 | res_predicted_caption.append([]) 99 | for j in range(num_outputs): 100 | index = i*num_outputs + j 101 | res_predicted_caption[i].append( 102 | predicted_caption[index, :where_is_eos_vector[index].item()+1].tolist()) 103 | 104 | where_is_eos_vector = where_is_eos_vector.unsqueeze(-1).expand(-1, time_step+1) 105 | arange_tensor = torch.arange(time_step+1).unsqueeze(0).expand(bs * num_outputs, -1).to(self.rank) 106 | predicted_caption_prob.masked_fill_(arange_tensor > where_is_eos_vector, 0.0) 107 | res_predicted_caption_prob = predicted_caption_prob.reshape(bs, num_outputs, -1) 108 | 109 | return res_predicted_caption, res_predicted_caption_prob 110 | 111 | def beam_search(self, enc_input, enc_input_num_pads, sos_idx, eos_idx, 112 | beam_size=3, how_many_outputs=1, max_seq_len=20, sample_or_max='max',): 113 | """ 114 | TO-DO (Maybe?): The code is not very elegant (can be shorter) and optimized. 115 | E.g. caching can be implemented for a slight inference time improvement. 116 | However, I think it would become less readable / friendly, so not sure if it is worth it. 
117 | """ 118 | assert (how_many_outputs <= beam_size), "requested output per sequence must be lower than beam width" 119 | assert (sample_or_max == 'max' or sample_or_max == 'sample'), "argument must be chosen between \'max\' and \'sample\'" 120 | bs = enc_input.shape[0] 121 | 122 | cross_enc_output = self.forward_enc(enc_input, enc_input_num_pads) 123 | 124 | # init: ------------------------------------------------------------------ 125 | init_dec_class = torch.tensor([sos_idx] * bs).unsqueeze(1).type(torch.long).to(self.rank) 126 | init_dec_logprob = torch.tensor([0.0] * bs).unsqueeze(1).type(torch.float).to(self.rank) 127 | log_probs = self.forward_dec(cross_input=cross_enc_output, enc_input_num_pads=enc_input_num_pads, 128 | dec_input=init_dec_class, dec_input_num_pads=[0] * bs, 129 | apply_log_softmax=True) 130 | if sample_or_max == 'max': 131 | _, topi = torch.topk(log_probs, k=beam_size, sorted=True) 132 | else: # sample 133 | topi = torch.exp(log_probs[:, 0, :]).multinomial(num_samples=beam_size, replacement=False) 134 | topi = topi.unsqueeze(1) 135 | 136 | init_dec_class = init_dec_class.repeat(1, beam_size) 137 | init_dec_class = init_dec_class.unsqueeze(-1) 138 | top_beam_size_class = topi.transpose(-2, -1) 139 | init_dec_class = torch.cat((init_dec_class, top_beam_size_class), dim=-1) 140 | 141 | init_dec_logprob = init_dec_logprob.repeat(1, beam_size) 142 | init_dec_logprob = init_dec_logprob.unsqueeze(-1) 143 | top_beam_size_logprob = log_probs.gather(dim=-1, index=topi) 144 | top_beam_size_logprob = top_beam_size_logprob.transpose(-2, -1) 145 | init_dec_logprob = torch.cat((init_dec_logprob, top_beam_size_logprob), dim=-1) 146 | 147 | bs, enc_seq_len, d_model = cross_enc_output.shape 148 | cross_enc_output = cross_enc_output.unsqueeze(1) 149 | cross_enc_output = cross_enc_output.expand(-1, beam_size, -1, -1) 150 | cross_enc_output = cross_enc_output.reshape(bs * beam_size, enc_seq_len, d_model).contiguous() 151 | enc_input_num_pads = [enc_input_num_pads[i] for i in range(bs) for _ in range(beam_size)] 152 | 153 | # loop: ----------------------------------------------------------------- 154 | loop_dec_classes = init_dec_class 155 | loop_dec_logprobs = init_dec_logprob 156 | loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True) 157 | 158 | loop_num_elem_vector = torch.tensor([2] * (bs * beam_size)).to(self.rank) 159 | 160 | for time_step in range(2, max_seq_len): 161 | loop_dec_classes = loop_dec_classes.reshape(bs * beam_size, time_step).contiguous() 162 | 163 | log_probs = self.forward_dec(cross_input=cross_enc_output, enc_input_num_pads=enc_input_num_pads, 164 | dec_input=loop_dec_classes, 165 | dec_input_num_pads=(time_step-loop_num_elem_vector).tolist(), 166 | apply_log_softmax=True) 167 | if sample_or_max == 'max': 168 | _, topi = torch.topk(log_probs[:, time_step-1, :], k=beam_size, sorted=True) 169 | else: # sample 170 | topi = torch.exp(log_probs[:, time_step-1, :]).multinomial(num_samples=beam_size, 171 | replacement=False) 172 | 173 | top_beam_size_word_classes = topi.reshape(bs, beam_size, beam_size) 174 | 175 | top_beam_size_word_logprobs = log_probs[:, time_step-1, :].gather(dim=-1, index=topi) 176 | top_beam_size_word_logprobs = top_beam_size_word_logprobs.reshape(bs, beam_size, beam_size) 177 | 178 | # each sequence have now its best prediction, but some sequence may have already been terminated with EOS, 179 | # in that case its candidates are simply ignored, and do not sum up in the "loop_dec_logprobs" their value 180 | # are set to zero 181 | 
there_is_eos_mask = (loop_dec_classes.view(bs, beam_size, time_step) == eos_idx). \ 182 | sum(dim=-1, keepdims=True).type(torch.bool) 183 | 184 | # if we pad with -999 its candidates logprobabilities, also the sequence containing EOS would be 185 | # straightforwardly discarded, instead we want to keep it in the exploration. Therefore we mask with 0.0 186 | # one arbitrary candidate word probability so the sequence probability is unchanged but it 187 | # can still be discarded when a better candidate sequence is found 188 | top_beam_size_word_logprobs[:, :, 0:1].masked_fill_(there_is_eos_mask, 0.0) 189 | top_beam_size_word_logprobs[:, :, 1:].masked_fill_(there_is_eos_mask, -999.0) 190 | 191 | comparison_logprobs = loop_cumul_logprobs + top_beam_size_word_logprobs 192 | 193 | comparison_logprobs = comparison_logprobs.contiguous().view(bs, beam_size * beam_size) 194 | _, topi = torch.topk(comparison_logprobs, k=beam_size, sorted=True) 195 | which_sequence = topi // beam_size 196 | which_word = topi % beam_size 197 | 198 | loop_dec_classes = loop_dec_classes.view(bs, beam_size, -1) 199 | loop_dec_logprobs = loop_dec_logprobs.view(bs, beam_size, -1) 200 | 201 | bs_idxes = torch.arange(bs).unsqueeze(-1) 202 | new_loop_dec_classes = loop_dec_classes[[bs_idxes, which_sequence]] 203 | new_loop_dec_logprobs = loop_dec_logprobs[[bs_idxes, which_sequence]] 204 | 205 | which_sequence_top_beam_size_word_classes = top_beam_size_word_classes[[bs_idxes, which_sequence]] 206 | which_sequence_top_beam_size_word_logprobs = top_beam_size_word_logprobs[ 207 | [bs_idxes, which_sequence]] 208 | which_word = which_word.unsqueeze(-1) 209 | 210 | lastword_top_beam_size_classes = which_sequence_top_beam_size_word_classes.gather(dim=-1, 211 | index=which_word) 212 | lastword_top_beam_size_logprobs = which_sequence_top_beam_size_word_logprobs.gather(dim=-1, index=which_word) 213 | 214 | new_loop_dec_classes = torch.cat((new_loop_dec_classes, lastword_top_beam_size_classes), dim=-1) 215 | new_loop_dec_logprobs = torch.cat((new_loop_dec_logprobs, lastword_top_beam_size_logprobs), dim=-1) 216 | loop_dec_classes = new_loop_dec_classes 217 | loop_dec_logprobs = new_loop_dec_logprobs 218 | 219 | loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True) 220 | 221 | # -----------------------update loop_num_elem_vector ---------------------------- 222 | loop_num_elem_vector = loop_num_elem_vector.view(bs, beam_size)[[bs_idxes, which_sequence]].view(bs * beam_size) 223 | there_was_eos_mask = (loop_dec_classes[:, :, :-1].view(bs, beam_size, time_step) == eos_idx). 
\ 224 | sum(dim=-1).type(torch.bool).view(bs * beam_size) 225 | loop_num_elem_vector = loop_num_elem_vector + (1 * (1 - there_was_eos_mask.type(torch.int))) 226 | 227 | if (loop_num_elem_vector != time_step + 1).sum() == (bs * beam_size): 228 | break 229 | 230 | # sort out the best result 231 | loop_cumul_logprobs /= loop_num_elem_vector.reshape(bs, beam_size, 1) 232 | _, topi = torch.topk(loop_cumul_logprobs.squeeze(-1), k=beam_size) 233 | res_caption_pred = [[] for _ in range(bs)] 234 | res_caption_logprob = [[] for _ in range(bs)] 235 | for i in range(bs): 236 | for j in range(how_many_outputs): 237 | idx = topi[i, j].item() 238 | res_caption_pred[i].append( 239 | loop_dec_classes[i, idx, :loop_num_elem_vector[i * beam_size + idx]].tolist()) 240 | res_caption_logprob[i].append(loop_dec_logprobs[i, idx, :loop_num_elem_vector[i * beam_size + idx]]) 241 | 242 | flatted_res_caption_logprob = [logprobs for i in range(bs) for logprobs in res_caption_logprob[i]] 243 | flatted_res_caption_logprob = torch.nn.utils.rnn.pad_sequence(flatted_res_caption_logprob, batch_first=True) 244 | res_caption_logprob = flatted_res_caption_logprob.view(bs, how_many_outputs, -1) 245 | 246 | return res_caption_pred, res_caption_logprob 247 | -------------------------------------------------------------------------------- /models/ensemble_captioning_model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from models.captioning_model import CaptioningModel 5 | 6 | 7 | class EsembleCaptioningModel(CaptioningModel): 8 | def __init__(self, models_list, rank): 9 | super().__init__() 10 | self.num_models = len(models_list) 11 | self.models_list = models_list 12 | self.rank = rank 13 | 14 | self.dummy_linear = nn.Linear(1, 1) 15 | 16 | for model in self.models_list: 17 | model.eval() 18 | 19 | def forward(self, enc_x, dec_x=None, 20 | enc_x_num_pads=[0], dec_x_num_pads=[0], apply_log_softmax=False, 21 | mode='beam_search', **kwargs): 22 | assert (mode == 'beam_search'), "this class supports only beam search." 
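# The single-model beam search above and the ensemble wrapper in this file are
# both driven through the same CaptioningModel.forward interface. A minimal
# usage sketch, assuming `model` is a loaded captioning model, `enc_x` a batch
# of visual inputs and `enc_x_num_pads` its padding counts (the helper name is
# hypothetical; the keyword arguments mirror the ones used in this repository):
import torch

def _beam_search_usage_sketch(model, enc_x, enc_x_num_pads, sos_idx, eos_idx):
    beam_search_kwargs = {'beam_size': 5,
                          'beam_max_seq_len': 20,
                          'sample_or_max': 'max',
                          'how_many_outputs': 1,
                          'sos_idx': sos_idx,
                          'eos_idx': eos_idx}
    with torch.no_grad():
        out_classes, out_logprobs = model(enc_x=enc_x,
                                          enc_x_num_pads=enc_x_num_pads,
                                          mode='beam_search', **beam_search_kwargs)
    # out_classes[i][0] holds the highest-scoring token-index sequence for image i
    return out_classes, out_logprobs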
23 | sos_idx = kwargs.get('sos_idx', -999) 24 | eos_idx = kwargs.get('eos_idx', -999) 25 | if mode == 'beam_search': 26 | beam_size_arg = kwargs.get('beam_size', 5) 27 | how_many_outputs_per_beam = kwargs.get('how_many_outputs', 1) 28 | beam_max_seq_len = kwargs.get('beam_max_seq_len', 20) 29 | sample_or_max = kwargs.get('sample_or_max', 'max') 30 | out_classes, out_logprobs = self.ensemble_beam_search( 31 | enc_x, enc_x_num_pads, 32 | beam_size=beam_size_arg, 33 | sos_idx=sos_idx, 34 | eos_idx=eos_idx, 35 | how_many_outputs=how_many_outputs_per_beam, 36 | max_seq_len=beam_max_seq_len, 37 | sample_or_max=sample_or_max) 38 | return out_classes, out_logprobs 39 | 40 | def forward_enc(self, enc_input, enc_input_num_pads): 41 | x_outputs_list = [] 42 | for i in range(self.num_models): 43 | x_outputs = self.models_list[i].forward_enc(enc_input, enc_input_num_pads) 44 | x_outputs_list.append(x_outputs) 45 | return x_outputs_list 46 | 47 | def forward_dec(self, cross_input_list, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False): 48 | 49 | import torch.nn.functional as F 50 | y_outputs = [] 51 | for i in range(self.num_models): 52 | y_outputs.append( 53 | F.softmax(self.models_list[i].forward_dec( 54 | cross_input_list[i], enc_input_num_pads, 55 | dec_input, dec_input_num_pads, False).unsqueeze(0), dim=-1)) 56 | avg = torch.cat(y_outputs, dim=0).mean(dim=0).log() 57 | 58 | return avg 59 | 60 | # quite unclean coding, to be re-factored in the future... 61 | # since it's a bit similar to the single model case 62 | def ensemble_beam_search(self, enc_input, enc_input_num_pads, sos_idx, eos_idx, 63 | beam_size=3, how_many_outputs=1, max_seq_len=20, sample_or_max='max',): 64 | assert (how_many_outputs <= beam_size), "requested output per sequence must be lower than beam width" 65 | assert (sample_or_max == 'max' or sample_or_max == 'sample'), "argument must be chosen between \'max\' and \'sample\'" 66 | bs = enc_input.shape[0] 67 | 68 | # the cross_dec_input is computed once 69 | cross_enc_output_list = self.forward_enc(enc_input, enc_input_num_pads) 70 | 71 | # init: ------------------------------------------------------------------ 72 | init_dec_class = torch.tensor([sos_idx] * bs).unsqueeze(1).type(torch.long).to(self.rank) 73 | init_dec_logprob = torch.tensor([0.0] * bs).unsqueeze(1).type(torch.float).to(self.rank) 74 | log_probs = self.forward_dec(cross_input_list=cross_enc_output_list, enc_input_num_pads=enc_input_num_pads, 75 | dec_input=init_dec_class, dec_input_num_pads=[0] * bs, 76 | apply_log_softmax=True) 77 | if sample_or_max == 'max': 78 | _, topi = torch.topk(log_probs, k=beam_size, sorted=True) 79 | else: # sample 80 | topi = torch.exp(log_probs[:, 0, :]).multinomial(num_samples=beam_size, replacement=False) 81 | topi = topi.unsqueeze(1) 82 | 83 | init_dec_class = init_dec_class.repeat(1, beam_size) 84 | init_dec_class = init_dec_class.unsqueeze(-1) 85 | top_beam_size_class = topi.transpose(-2, -1) 86 | init_dec_class = torch.cat((init_dec_class, top_beam_size_class), dim=-1) 87 | 88 | init_dec_logprob = init_dec_logprob.repeat(1, beam_size) 89 | init_dec_logprob = init_dec_logprob.unsqueeze(-1) 90 | top_beam_size_logprob = log_probs.gather(dim=-1, index=topi) 91 | top_beam_size_logprob = top_beam_size_logprob.transpose(-2, -1) 92 | init_dec_logprob = torch.cat((init_dec_logprob, top_beam_size_logprob), dim=-1) 93 | 94 | tmp_cross_enc_output_list = [] 95 | for cross_enc_output in cross_enc_output_list: 96 | bs, enc_seq_len, d_model = cross_enc_output.shape 97 | 
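# A small self-contained sketch of the tiling step performed just below: the
# encoder output is replicated beam_size times per batch element via
# unsqueeze -> expand -> reshape (shapes here are arbitrary demo values, the
# helper name is illustrative only).
import torch

def _tile_for_beams_demo():
    bs, beam_size, enc_seq_len, d_model = 2, 3, 5, 4
    enc_out = torch.randn(bs, enc_seq_len, d_model)
    tiled = enc_out.unsqueeze(1).expand(-1, beam_size, -1, -1) \
                   .reshape(bs * beam_size, enc_seq_len, d_model).contiguous()
    # rows i*beam_size, ..., i*beam_size + beam_size - 1 are copies of enc_out[i]
    assert tiled.shape == (bs * beam_size, enc_seq_len, d_model)
    return tiled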
cross_enc_output = cross_enc_output.unsqueeze(1) 98 | cross_enc_output = cross_enc_output.expand(-1, beam_size, -1, -1) 99 | cross_enc_output = cross_enc_output.reshape(bs * beam_size, enc_seq_len, d_model).contiguous() 100 | tmp_cross_enc_output_list.append(cross_enc_output) 101 | cross_enc_output_list = tmp_cross_enc_output_list 102 | enc_input_num_pads = [enc_input_num_pads[i] for i in range(bs) for _ in range(beam_size)] 103 | 104 | loop_dec_classes = init_dec_class 105 | loop_dec_logprobs = init_dec_logprob 106 | loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True) 107 | 108 | loop_num_elem_vector = torch.tensor([2] * (bs * beam_size)).to(self.rank) 109 | 110 | for time_step in range(2, max_seq_len): 111 | loop_dec_classes = loop_dec_classes.reshape(bs * beam_size, time_step).contiguous() 112 | 113 | log_probs = self.forward_dec(cross_input_list=cross_enc_output_list, enc_input_num_pads=enc_input_num_pads, 114 | dec_input=loop_dec_classes, 115 | dec_input_num_pads=(time_step-loop_num_elem_vector).tolist(), 116 | apply_log_softmax=True) 117 | if sample_or_max == 'max': 118 | _, topi = torch.topk(log_probs[:, time_step-1, :], k=beam_size, sorted=True) 119 | else: # sample 120 | topi = torch.exp(log_probs[:, time_step-1, :]).multinomial(num_samples=beam_size, 121 | replacement=False) 122 | 123 | top_beam_size_word_classes = topi.reshape(bs, beam_size, beam_size) 124 | 125 | top_beam_size_word_logprobs = log_probs[:, time_step-1, :].gather(dim=-1, index=topi) 126 | top_beam_size_word_logprobs = top_beam_size_word_logprobs.reshape(bs, beam_size, beam_size) 127 | 128 | there_is_eos_mask = (loop_dec_classes.view(bs, beam_size, time_step) == eos_idx). \ 129 | sum(dim=-1, keepdims=True).type(torch.bool) 130 | 131 | top_beam_size_word_logprobs[:, :, 0:1].masked_fill_(there_is_eos_mask, 0.0) 132 | top_beam_size_word_logprobs[:, :, 1:].masked_fill_(there_is_eos_mask, -999.0) 133 | 134 | comparison_logprobs = loop_cumul_logprobs + top_beam_size_word_logprobs 135 | 136 | comparison_logprobs = comparison_logprobs.contiguous().view(bs, beam_size * beam_size) 137 | _, topi = torch.topk(comparison_logprobs, k=beam_size, sorted=True) 138 | which_sequence = topi // beam_size 139 | which_word = topi % beam_size 140 | 141 | loop_dec_classes = loop_dec_classes.view(bs, beam_size, -1) 142 | loop_dec_logprobs = loop_dec_logprobs.view(bs, beam_size, -1) 143 | 144 | bs_idxes = torch.arange(bs).unsqueeze(-1) 145 | new_loop_dec_classes = loop_dec_classes[[bs_idxes, which_sequence]] 146 | new_loop_dec_logprobs = loop_dec_logprobs[[bs_idxes, which_sequence]] 147 | 148 | which_sequence_top_beam_size_word_classes = top_beam_size_word_classes[[bs_idxes, which_sequence]] 149 | which_sequence_top_beam_size_word_logprobs = top_beam_size_word_logprobs[ 150 | [bs_idxes, which_sequence]] 151 | which_word = which_word.unsqueeze(-1) 152 | 153 | lastword_top_beam_size_classes = which_sequence_top_beam_size_word_classes.gather(dim=-1, 154 | index=which_word) 155 | lastword_top_beam_size_logprobs = which_sequence_top_beam_size_word_logprobs.gather(dim=-1, index=which_word) 156 | 157 | new_loop_dec_classes = torch.cat((new_loop_dec_classes, lastword_top_beam_size_classes), dim=-1) 158 | new_loop_dec_logprobs = torch.cat((new_loop_dec_logprobs, lastword_top_beam_size_logprobs), dim=-1) 159 | loop_dec_classes = new_loop_dec_classes 160 | loop_dec_logprobs = new_loop_dec_logprobs 161 | 162 | loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True) 163 | 164 | loop_num_elem_vector = 
loop_num_elem_vector.view(bs, beam_size)[[bs_idxes, which_sequence]].view(bs * beam_size) 165 | there_was_eos_mask = (loop_dec_classes[:, :, :-1].view(bs, beam_size, time_step) == eos_idx). \ 166 | sum(dim=-1).type(torch.bool).view(bs * beam_size) 167 | loop_num_elem_vector = loop_num_elem_vector + (1 * (1 - there_was_eos_mask.type(torch.int))) 168 | 169 | if (loop_num_elem_vector != time_step + 1).sum() == (bs * beam_size): 170 | break 171 | 172 | loop_cumul_logprobs /= loop_num_elem_vector.reshape(bs, beam_size, 1) 173 | _, topi = torch.topk(loop_cumul_logprobs.squeeze(-1), k=beam_size) 174 | res_caption_pred = [[] for _ in range(bs)] 175 | res_caption_logprob = [[] for _ in range(bs)] 176 | for i in range(bs): 177 | for j in range(how_many_outputs): 178 | idx = topi[i, j].item() 179 | res_caption_pred[i].append( 180 | loop_dec_classes[i, idx, :loop_num_elem_vector[i * beam_size + idx]].tolist()) 181 | res_caption_logprob[i].append(loop_dec_logprobs[i, idx, :loop_num_elem_vector[i * beam_size + idx]]) 182 | 183 | flatted_res_caption_logprob = [logprobs for i in range(bs) for logprobs in res_caption_logprob[i]] 184 | flatted_res_caption_logprob = torch.nn.utils.rnn.pad_sequence(flatted_res_caption_logprob, batch_first=True) 185 | res_caption_logprob = flatted_res_caption_logprob.view(bs, how_many_outputs, -1) 186 | 187 | return res_caption_pred, res_caption_logprob 188 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | import numpy as np 6 | import torch.nn.functional as F 7 | 8 | 9 | class EmbeddingLayer(nn.Module): 10 | def __init__(self, vocab_size, d_model, dropout_perc): 11 | super(EmbeddingLayer, self).__init__() 12 | self.dropout = nn.Dropout(dropout_perc) 13 | self.embed = nn.Embedding(vocab_size, d_model) 14 | self.d_model = d_model 15 | 16 | def forward(self, x): 17 | return self.dropout(self.embed(x)) * math.sqrt(float(self.d_model)) 18 | 19 | 20 | class StaticExpansionBlock(nn.Module): 21 | def __init__(self, d_model, num_enc_exp_list, dropout_perc, eps): 22 | super().__init__() 23 | self.d_model = d_model 24 | self.num_enc_exp_list = num_enc_exp_list 25 | 26 | self.query_exp_vectors = nn.Embedding(sum(num_enc_exp_list), d_model) 27 | self.bias_exp_vectors = nn.Embedding(sum(num_enc_exp_list), d_model) 28 | 29 | self.key_embed = nn.Linear(d_model, d_model) 30 | self.class_a_embed = nn.Linear(d_model, d_model) 31 | self.class_b_embed = nn.Linear(d_model, d_model) 32 | 33 | self.selector_embed = nn.Linear(d_model, d_model) 34 | 35 | self.dropout_class_a_fw = nn.Dropout(dropout_perc) 36 | self.dropout_class_b_fw = nn.Dropout(dropout_perc) 37 | 38 | self.dropout_class_a_bw = nn.Dropout(dropout_perc) 39 | self.dropout_class_b_bw = nn.Dropout(dropout_perc) 40 | 41 | self.Z_dropout = nn.Dropout(dropout_perc) 42 | 43 | self.eps = eps 44 | 45 | def forward(self, x, n_indexes, mask): 46 | bs, enc_len, _ = x.shape 47 | 48 | query_exp = self.query_exp_vectors(n_indexes) 49 | bias_exp = self.bias_exp_vectors(n_indexes) 50 | x_key = self.key_embed(x) 51 | 52 | z = torch.matmul(query_exp, x_key.transpose(-1, -2)) / (self.d_model ** 0.5) 53 | z = self.Z_dropout(z) 54 | 55 | class_a_fw = F.relu(z) 56 | class_b_fw = F.relu(-z) 57 | class_a_fw = class_a_fw.masked_fill(mask == 0, 0.0) 58 | class_b_fw = class_b_fw.masked_fill(mask == 0, 0.0) 59 | class_a_fw = class_a_fw / (class_a_fw.sum(dim=-1, 
keepdim=True) + self.eps) 60 | class_b_fw = class_b_fw / (class_b_fw.sum(dim=-1, keepdim=True) + self.eps) 61 | 62 | class_a = torch.matmul(class_a_fw, self.class_a_embed(x)) + bias_exp 63 | class_b = torch.matmul(class_b_fw, self.class_b_embed(x)) + bias_exp 64 | class_a = self.dropout_class_a_fw(class_a) 65 | class_b = self.dropout_class_b_fw(class_b) 66 | 67 | class_a_bw = F.relu(z.transpose(-2, -1)) 68 | class_b_bw = F.relu(-z.transpose(-2, -1)) 69 | 70 | accum = 0 71 | class_a_bw_list = [] 72 | class_b_bw_list = [] 73 | for j in range(len(self.num_enc_exp_list)): 74 | from_idx = accum 75 | to_idx = accum + self.num_enc_exp_list[j] 76 | accum += self.num_enc_exp_list[j] 77 | class_a_bw_list.append(class_a_bw[:, :, from_idx:to_idx] / (class_a_bw[:, :, from_idx:to_idx].sum(dim=-1, keepdim=True) + self.eps)) 78 | class_b_bw_list.append(class_b_bw[:, :, from_idx:to_idx] / (class_b_bw[:, :, from_idx:to_idx].sum(dim=-1, keepdim=True) + self.eps)) 79 | class_a_bw = torch.cat(class_a_bw_list, dim=-1) 80 | class_b_bw = torch.cat(class_b_bw_list, dim=-1) 81 | 82 | class_a = torch.matmul(class_a_bw, class_a) / len(self.num_enc_exp_list) 83 | class_b = torch.matmul(class_b_bw, class_b) / len(self.num_enc_exp_list) 84 | class_a = self.dropout_class_a_bw(class_a) 85 | class_b = self.dropout_class_b_bw(class_b) 86 | 87 | selector = torch.sigmoid(self.selector_embed(x)) 88 | x_result = selector * class_a + (1 - selector) * class_b 89 | 90 | return x_result 91 | 92 | 93 | class EncoderLayer(nn.Module): 94 | def __init__(self, d_model, d_ff, num_enc_exp_list, dropout_perc, eps=1e-9): 95 | super().__init__() 96 | self.norm_1 = nn.LayerNorm(d_model) 97 | self.norm_2 = nn.LayerNorm(d_model) 98 | self.dropout_1 = nn.Dropout(dropout_perc) 99 | self.dropout_2 = nn.Dropout(dropout_perc) 100 | 101 | self.stc_exp = StaticExpansionBlock(d_model, num_enc_exp_list, dropout_perc, eps) 102 | self.ff = FeedForward(d_model, d_ff, dropout_perc) 103 | 104 | def forward(self, x, n_indexes, mask): 105 | x2 = self.norm_1(x) 106 | x = x + self.dropout_1(self.stc_exp(x=x2, n_indexes=n_indexes, mask=mask)) 107 | x2 = self.norm_2(x) 108 | x = x + self.dropout_2(self.ff(x2)) 109 | return x 110 | 111 | 112 | class DynamicExpansionBlock(nn.Module): 113 | def __init__(self, d_model, num_exp, dropout_perc, eps): 114 | super().__init__() 115 | self.d_model = d_model 116 | 117 | self.num_exp = num_exp 118 | self.cond_embed = nn.Linear(d_model, d_model) 119 | 120 | self.query_exp_vectors = nn.Embedding(self.num_exp, d_model) 121 | self.bias_exp_vectors = nn.Embedding(self.num_exp, d_model) 122 | 123 | self.key_linear = nn.Linear(d_model, d_model) 124 | self.class_a_embed = nn.Linear(d_model, d_model) 125 | self.class_b_embed = nn.Linear(d_model, d_model) 126 | 127 | self.selector_embed = nn.Linear(d_model, d_model) 128 | 129 | self.dropout_class_a_fw = nn.Dropout(dropout_perc) 130 | self.dropout_class_b_fw = nn.Dropout(dropout_perc) 131 | self.dropout_class_a_bw = nn.Dropout(dropout_perc) 132 | self.dropout_class_b_bw = nn.Dropout(dropout_perc) 133 | 134 | self.Z_dropout = nn.Dropout(dropout_perc) 135 | 136 | self.eps = eps 137 | 138 | def forward(self, x, n_indexes, mask): 139 | bs, dec_len, _ = x.shape 140 | 141 | cond = self.cond_embed(x).view(bs, dec_len, 1, self.d_model) 142 | query_exp = self.query_exp_vectors(n_indexes).unsqueeze(1) 143 | bias_exp = self.bias_exp_vectors(n_indexes).unsqueeze(1) 144 | query_exp = (query_exp + cond).view(bs, dec_len * self.num_exp, self.d_model) 145 | bias_exp = (bias_exp + cond).view(bs, dec_len 
* self.num_exp, self.d_model) 146 | 147 | x_key = self.key_linear(x) 148 | z = torch.matmul(query_exp, x_key.transpose(-1, -2)) / (self.d_model ** 0.5) 149 | z = self.Z_dropout(z) 150 | 151 | mod_mask_1 = mask.unsqueeze(2).expand(bs, dec_len, self.num_exp, dec_len).contiguous(). \ 152 | view(bs, dec_len * self.num_exp, dec_len) 153 | 154 | class_a_fw = F.relu(z) 155 | class_b_fw = F.relu(-z) 156 | class_a_fw = class_a_fw.masked_fill(mod_mask_1 == 0, 0.0) 157 | class_b_fw = class_b_fw.masked_fill(mod_mask_1 == 0, 0.0) 158 | class_a_fw = class_a_fw / (class_a_fw.sum(dim=-1, keepdim=True) + self.eps) 159 | class_b_fw = class_b_fw / (class_b_fw.sum(dim=-1, keepdim=True) + self.eps) 160 | class_a = torch.matmul(class_a_fw, self.class_a_embed(x)) 161 | class_b = torch.matmul(class_b_fw, self.class_b_embed(x)) 162 | class_a = self.dropout_class_a_fw(class_a) 163 | class_b = self.dropout_class_b_fw(class_b) 164 | 165 | mod_mask_2 = mask.unsqueeze(-1).expand(bs, dec_len, dec_len, self.num_exp).contiguous(). \ 166 | view(bs, dec_len, dec_len * self.num_exp) 167 | 168 | class_a_bw = F.relu(z.transpose(-2, -1)) 169 | class_b_bw = F.relu(-z.transpose(-2, -1)) 170 | class_a_bw = class_a_bw.masked_fill(mod_mask_2 == 0, 0.0) 171 | class_b_bw = class_b_bw.masked_fill(mod_mask_2 == 0, 0.0) 172 | class_a_bw = class_a_bw / (class_a_bw.sum(dim=-1, keepdim=True) + self.eps) 173 | class_b_bw = class_b_bw / (class_b_bw.sum(dim=-1, keepdim=True) + self.eps) 174 | class_a = torch.matmul(class_a_bw, class_a + bias_exp) 175 | class_b = torch.matmul(class_b_bw, class_b + bias_exp) 176 | class_a = self.dropout_class_a_bw(class_a) 177 | class_b = self.dropout_class_b_bw(class_b) 178 | 179 | selector = torch.sigmoid(self.selector_embed(x)) 180 | x_result = selector * class_a + (1 - selector) * class_b 181 | 182 | return x_result 183 | 184 | 185 | class DecoderLayer(nn.Module): 186 | def __init__(self, d_model, num_heads, d_ff, num_exp, dropout_perc, eps=1e-9): 187 | super().__init__() 188 | self.norm_1 = nn.LayerNorm(d_model) 189 | self.norm_2 = nn.LayerNorm(d_model) 190 | self.norm_3 = nn.LayerNorm(d_model) 191 | 192 | self.dropout_1 = nn.Dropout(dropout_perc) 193 | self.dropout_2 = nn.Dropout(dropout_perc) 194 | self.dropout_3 = nn.Dropout(dropout_perc) 195 | 196 | self.mha = MultiHeadAttention(d_model, num_heads, dropout_perc) 197 | self.dyn_exp = DynamicExpansionBlock(d_model, num_exp, dropout_perc, eps) 198 | self.ff = FeedForward(d_model, d_ff, dropout_perc) 199 | 200 | def forward(self, x, n_indexes, cross_connection_x, input_attention_mask, cross_attention_mask): 201 | 202 | # Pre-LayerNorm 203 | x2 = self.norm_1(x) 204 | x = x + self.dropout_1(self.dyn_exp(x=x2, n_indexes=n_indexes, mask=input_attention_mask)) 205 | 206 | x2 = self.norm_2(x) 207 | x = x + self.dropout_2(self.mha(q=x2, k=cross_connection_x, v=cross_connection_x, 208 | mask=cross_attention_mask)) 209 | 210 | x2 = self.norm_3(x) 211 | x = x + self.dropout_3(self.ff(x2)) 212 | return x 213 | 214 | 215 | 216 | class MultiHeadAttention(nn.Module): 217 | def __init__(self, d_model, num_heads, dropout_perc): 218 | super(MultiHeadAttention, self).__init__() 219 | assert d_model % num_heads == 0, "num heads must be multiple of d_model" 220 | 221 | self.d_model = d_model 222 | self.d_k = int(d_model / num_heads) 223 | self.num_heads = num_heads 224 | 225 | self.Wq = nn.Linear(d_model, self.d_k * num_heads) 226 | self.Wk = nn.Linear(d_model, self.d_k * num_heads) 227 | self.Wv = nn.Linear(d_model, self.d_k * num_heads) 228 | 229 | self.out_linear = 
nn.Linear(d_model, d_model) 230 | 231 | def forward(self, q, k, v, mask=None): 232 | batch_size, q_seq_len, _ = q.shape 233 | k_seq_len = k.size(1) 234 | v_seq_len = v.size(1) 235 | 236 | k_proj = self.Wk(k).view(batch_size, k_seq_len, self.num_heads, self.d_k) 237 | q_proj = self.Wq(q).view(batch_size, q_seq_len, self.num_heads, self.d_k) 238 | v_proj = self.Wv(v).view(batch_size, v_seq_len, self.num_heads, self.d_k) 239 | 240 | k_proj = k_proj.transpose(2, 1) 241 | q_proj = q_proj.transpose(2, 1) 242 | v_proj = v_proj.transpose(2, 1) 243 | 244 | sim_scores = torch.matmul(q_proj, k_proj.transpose(3, 2)) 245 | sim_scores = sim_scores / self.d_k ** 0.5 246 | 247 | if mask is not None: 248 | mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1) 249 | sim_scores = sim_scores.masked_fill(mask == 0, value=-1e4) 250 | sim_scores = F.softmax(input=sim_scores, dim=-1) 251 | 252 | attention_applied = torch.matmul(sim_scores, v_proj) 253 | attention_applied_concatenated = attention_applied.permute(0, 2, 1, 3).contiguous()\ 254 | .view(batch_size, q_seq_len, self.d_model) 255 | 256 | out = self.out_linear(attention_applied_concatenated) 257 | return out 258 | 259 | 260 | class FeedForward(nn.Module): 261 | def __init__(self, d_model, d_ff, dropout_perc): 262 | super(FeedForward, self).__init__() 263 | self.linear_1 = nn.Linear(d_model, d_ff) 264 | self.dropout = nn.Dropout(dropout_perc) 265 | self.linear_2 = nn.Linear(d_ff, d_model) 266 | 267 | def forward(self, x): 268 | x = self.dropout(F.relu(self.linear_1(x))) 269 | x = self.linear_2(x) 270 | return x 271 | -------------------------------------------------------------------------------- /onnx4tensorrt/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Although this package provide the scripts and implementations to 3 | support ONNX conversion and deploy on TensorRT. 4 | 5 | It kinda breaks the "DRY" principle" but the the model and backbone need some tweak 6 | and more attention than the pure pytorch counterpart. Since issue related to the 7 | particular contexts of ONNX and TensorRT may be raised in the future, it's safer 8 | to separate the two versions for the moment. 
9 | """ -------------------------------------------------------------------------------- /onnx4tensorrt/convert2onnx.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import argparse 5 | import pickle 6 | import copy 7 | import onnx 8 | from argparse import Namespace 9 | 10 | 11 | import functools 12 | print = functools.partial(print, flush=True) 13 | 14 | from utils.saving_utils import partially_load_state_dict 15 | from utils.image_utils import preprocess_image 16 | from utils.language_utils import tokens2description 17 | from onnx4tensorrt.End_ExpansionNet_v2_onnx_tensorrt import create_pad_mask, create_no_peak_and_pad_mask 18 | from onnx4tensorrt.End_ExpansionNet_v2_onnx_tensorrt import End_ExpansionNet_v2_ONNX_TensorRT, \ 19 | NUM_FEATURES, MAX_DECODE_STEPS 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser(description='Conversion PyTorch to ONNX') 24 | parser.add_argument('--model_dim', type=int, default=512) 25 | parser.add_argument('--N_enc', type=int, default=3) 26 | parser.add_argument('--N_dec', type=int, default=3) 27 | parser.add_argument('--vocab_path', type=str, default='./demo_material/demo_coco_tokens.pickle') 28 | parser.add_argument('--max_seq_len', type=int, default=74) 29 | parser.add_argument('--image_path_1', type=str, default='./demo_material/tatin.jpg') 30 | parser.add_argument('--image_path_2', type=str, default='./demo_material/micheal.jpg') 31 | parser.add_argument('--load_model_path', type=str, default='./github_ignore_material/saves/rf_model.pth') 32 | parser.add_argument('--output_onnx_path', type=str, default='./rf_model.onnx') 33 | parser.add_argument('--onnx_simplify', type=bool, default=False) 34 | parser.add_argument('--onnx_runtime_test', type=bool, default=False) 35 | parser.add_argument('--onnx_tensorrt_test', type=bool, default=False) 36 | parser.add_argument('--max_worker_size', type=int, default=10000) 37 | parser.add_argument('--onnx_opset', type=int, default=14) 38 | # parser.add_argument('--beam_size', type=int, default=5) 39 | 40 | args = parser.parse_args() 41 | 42 | with open(args.vocab_path, 'rb') as f: 43 | coco_tokens = pickle.load(f) 44 | sos_idx = coco_tokens['word2idx_dict'][coco_tokens['sos_str']] 45 | eos_idx = coco_tokens['word2idx_dict'][coco_tokens['eos_str']] 46 | 47 | # test the generalization of the graph using two images 48 | img_size = 384 49 | image_1 = preprocess_image(args.image_path_1, img_size) 50 | image_2 = preprocess_image(args.image_path_2, img_size) 51 | 52 | drop_args = Namespace(enc=0.0, 53 | dec=0.0, 54 | enc_input=0.0, 55 | dec_input=0.0, 56 | other=0.0) 57 | enc_exp_list = [32, 64, 128, 256, 512] 58 | dec_exp = 16 59 | num_heads = 8 60 | model = End_ExpansionNet_v2_ONNX_TensorRT( 61 | swin_img_size=img_size, swin_patch_size=4, swin_in_chans=3, 62 | swin_embed_dim=192, swin_depths=[2, 2, 18, 2], swin_num_heads=[6, 12, 24, 48], 63 | swin_window_size=12, swin_mlp_ratio=4., swin_qkv_bias=True, swin_qk_scale=None, 64 | swin_drop_rate=0.0, swin_attn_drop_rate=0.0, swin_drop_path_rate=0.0, 65 | swin_norm_layer=torch.nn.LayerNorm, swin_patch_norm=True, 66 | 67 | d_model=args.model_dim, N_enc=args.N_enc, 68 | N_dec=args.N_dec, num_heads=num_heads, ff=2048, 69 | num_exp_enc_list=enc_exp_list, 70 | num_exp_dec=16, 71 | output_word2idx=coco_tokens['word2idx_dict'], 72 | output_idx2word=coco_tokens['idx2word_list'], 73 | max_seq_len=args.max_seq_len, drop_args=drop_args) 74 | 75 | print("Loading model...") 76 | 
checkpoint = torch.load(args.load_model_path) 77 | partially_load_state_dict(model, checkpoint['model_state_dict']) 78 | 79 | print("===============================================") 80 | print("|| ONNX conversion ||") 81 | print("===============================================") 82 | 83 | # Masks creation - - - - - 84 | batch_size = 1 85 | enc_mask = create_pad_mask(mask_size=(batch_size, sum(enc_exp_list), NUM_FEATURES), 86 | pad_row=[0], pad_column=[0]).contiguous() 87 | no_peak_mask = create_no_peak_and_pad_mask(mask_size=(batch_size, MAX_DECODE_STEPS, MAX_DECODE_STEPS), 88 | num_pads=[0]).contiguous() 89 | cross_mask = create_pad_mask(mask_size=(batch_size, MAX_DECODE_STEPS, NUM_FEATURES), 90 | pad_row=[0], pad_column=[0]).contiguous() 91 | # contrary to the other masks, we put 1 in correspondence to the values to be masked 92 | cross_mask = 1 - cross_mask 93 | 94 | fw_dec_mask = no_peak_mask.unsqueeze(2).expand(batch_size, MAX_DECODE_STEPS, 95 | dec_exp, MAX_DECODE_STEPS).contiguous(). \ 96 | view(batch_size, MAX_DECODE_STEPS * dec_exp, MAX_DECODE_STEPS) 97 | 98 | bw_dec_mask = no_peak_mask.unsqueeze(-1).expand(batch_size, 99 | MAX_DECODE_STEPS, MAX_DECODE_STEPS, dec_exp).contiguous(). \ 100 | view(batch_size, MAX_DECODE_STEPS, MAX_DECODE_STEPS * dec_exp) 101 | 102 | atten_mask = cross_mask.unsqueeze(1).repeat(1, num_heads, 1, 1) 103 | 104 | # - - - - - - - - - - 105 | 106 | print("Exporting...") 107 | model.eval() 108 | my_script = torch.jit.script(model) 109 | torch.onnx.export( 110 | my_script, 111 | (image_1, torch.tensor([sos_idx]), enc_mask, fw_dec_mask, bw_dec_mask, atten_mask), 112 | args.output_onnx_path, 113 | input_names=['enc_x', 'sos_idx', 'enc_mask', 'fw_dec_mask', 'bw_dec_mask', 'cross_mask'], 114 | output_names=['pred', 'logprobs'], 115 | export_params=True, 116 | opset_version=args.onnx_opset) 117 | print("ONNX graph conversion done. Destination: " + args.output_onnx_path) 118 | onnx_model_fp32 = onnx.load(args.output_onnx_path) 119 | onnx.checker.check_model(onnx_model_fp32) 120 | print("ONNX graph checked.") 121 | 122 | # TO-DO: FP16 code was written but does not work for reason yet unknown, 123 | # tests were made on model trained in FP32 124 | # what if the model was trained in FP16 instead? 125 | 126 | # from onnxconverter_common import float16 127 | # onnx_model_fp16 = float16.convert_float_to_float16(onnx_model_fp32, op_block_list=["Topk", "Normalizer"]) 128 | # onnx.save(onnx_model_fp16, args.output_onnx_path + '_fp16') 129 | # print("ONNX graph FP16 version done. Destination: " + args.output_onnx_path + '_fp16') 130 | 131 | if args.onnx_simplify: 132 | print("===============================================") 133 | print("|| ONNX graph simplifcation phase ||") 134 | print("===============================================") 135 | from onnxsim import simplify 136 | onnx_model_fp32 = onnx.load(args.output_onnx_path) 137 | # onnx_model_fp16 = onnx.load(args.output_onnx_path + '_fp16') 138 | try: 139 | simplified_onnx_model_fp32, check_fp32 = simplify(onnx_model_fp32) 140 | # simplified_onnx_model_fp16, check_fp16 = simplify(onnx_model_fp16) 141 | except: 142 | print("The simplification failed. 
In this case, we suggest to try the command line version.") 143 | assert check_fp32, "Simplified fp32 ONNX model could not be validated" 144 | # assert check_fp16, "Simplified fp16 ONNX model could not be validated" 145 | onnx.save(simplified_onnx_model_fp32, args.output_onnx_path) 146 | # onnx.save(simplified_onnx_model_fp16, args.output_onnx_path + '_fp16') 147 | 148 | if True: #args.onnx_runtime_test: 149 | import onnxruntime as ort 150 | print("===============================================") 151 | print("|| Testing on ONNX Runtime ||") 152 | print("===============================================") 153 | 154 | ort_sess = ort.InferenceSession(args.output_onnx_path) 155 | input_dict_1 = {'enc_x': image_1.numpy(), 'sos_idx': np.array([sos_idx]), 156 | 'enc_mask': enc_mask.numpy(), 'fw_dec_mask': fw_dec_mask.numpy(), 157 | 'bw_dec_mask': bw_dec_mask.numpy(), 'cross_mask': atten_mask.numpy()} 158 | outputs_ort = ort_sess.run(None, input_dict_1) 159 | output_caption = tokens2description(outputs_ort[0][0].tolist(), coco_tokens['idx2word_list'], sos_idx, eos_idx) 160 | print("ONNX Runtime result on 1st image:\n\t\t" + output_caption) 161 | 162 | input_dict_2 = copy.copy(input_dict_1) 163 | input_dict_2['enc_x'] = image_2.numpy() 164 | outputs_ort = ort_sess.run(None, input_dict_2) 165 | output_caption = tokens2description(outputs_ort[0][0].tolist(), coco_tokens['idx2word_list'], sos_idx, eos_idx) 166 | print("ONNX Runtime result on 2nd image:\n\t\t" + output_caption) 167 | print("Done.", end="\n\n") 168 | 169 | if args.onnx_tensorrt_test: 170 | import onnx_tensorrt.backend as backend 171 | print("===============================================") 172 | print("|| Testing on ONNX-TensorRT backend ||") 173 | print("===============================================") 174 | 175 | engine = backend.prepare(onnx_model, device='CUDA:0', max_worker_size=args.max_worker_size) 176 | 177 | input_data = [image_1.numpy(), np.array([sos_idx]), 178 | enc_mask.numpy(), fw_dec_mask.numpy(), bw_dec_mask.numpy(), atten_mask.numpy()] 179 | output_data = engine.run(input_data)[0][0] 180 | output_caption = tokens2description(output_data.tolist(), coco_tokens['idx2word_list'], sos_idx, eos_idx) 181 | print("TensorRT result on 1st image:\n\t\t" + output_caption) 182 | 183 | input_data[0] = image_2.numpy() 184 | output_data = engine.run(input_data)[0][0] 185 | output_caption = tokens2description(output_data.tolist(), coco_tokens['idx2word_list'], sos_idx, eos_idx) 186 | print("TensorRT result on 2nd image:\n\t\t" + output_caption) 187 | print("Done.", end="\n\n") 188 | 189 | print("Closing.") 190 | 191 | -------------------------------------------------------------------------------- /onnx4tensorrt/onnx2tensorrt.py: -------------------------------------------------------------------------------- 1 | # Script for ONNX 2 Tensorrt conversion 2 | # credits to: Shakhizat Nurgaliyev (https://github.com/shahizat) 3 | 4 | import os 5 | import torch 6 | import argparse 7 | import numpy as np 8 | import pickle 9 | import tensorrt as trt 10 | import pycuda.driver as cuda 11 | import pycuda.autoinit # this is important 12 | 13 | from utils.args_utils import str2type 14 | from utils.language_utils import tokens2description 15 | from utils.image_utils import preprocess_image 16 | 17 | from onnx4tensorrt.End_ExpansionNet_v2_onnx_tensorrt import create_pad_mask, create_no_peak_and_pad_mask 18 | from onnx4tensorrt.End_ExpansionNet_v2_onnx_tensorrt import NUM_FEATURES, MAX_DECODE_STEPS 19 | 20 | 21 | if __name__ == "__main__": 22 | parser 
= argparse.ArgumentParser(description='ONNX 2 TensorRT') 23 | parser.add_argument('--onnx_path', type=str, default='./rf_model.onnx') 24 | parser.add_argument('--engine_path', type=str, default='./model_engine.trt') 25 | parser.add_argument('--data_type', type=str2type, default='fp32') 26 | args = parser.parse_args() 27 | 28 | with open('./demo_material/demo_coco_tokens.pickle', 'rb') as f: 29 | coco_tokens = pickle.load(f) 30 | sos_idx = coco_tokens['word2idx_dict'][coco_tokens['sos_str']] 31 | eos_idx = coco_tokens['word2idx_dict'][coco_tokens['eos_str']] 32 | 33 | # Build TensorRT engine 34 | TRT_LOGGER = trt.Logger(trt.Logger.WARNING) 35 | trt_runtime = trt.Runtime(TRT_LOGGER) 36 | 37 | # The Onnx path is used for Onnx models. 38 | def build_engine_onnx(onnx_file_path, data_type): 39 | builder = trt.Builder(TRT_LOGGER) 40 | network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) 41 | config = builder.create_builder_config() 42 | parser = trt.OnnxParser(network, TRT_LOGGER) 43 | 44 | builder.max_batch_size = 1 45 | config.max_workspace_size = 1 << 32 # sono in bytes quindi 2^32 -> 4GB 46 | # Load the Onnx model and parse it in order to populate the TensorRT network. 47 | 48 | if data_type == 'fp32': 49 | pass # do nothing, is the default 50 | elif data_type == 'fp16': 51 | if builder.platform_has_fast_fp16: 52 | config.set_flag(trt.BuilderFlag.FP16) 53 | print("fp16 is supported. Setting up fp16.") 54 | else: 55 | print("fp16 is not supported. Using the fp32 instead.") 56 | else: 57 | raise ValueError("Unsupported type. Only the following types are supported: " + 58 | "fp32, fp16, int8.") 59 | 60 | with open(onnx_file_path, "rb") as onnx_file: 61 | if not parser.parse(onnx_file.read()): 62 | print("ERROR: Failed to parse the ONNX file.") 63 | for error in range(parser.num_errors): 64 | print(parser.get_error(error)) 65 | return None 66 | 67 | return builder.build_engine(network, config) 68 | 69 | img_size = 384 70 | image = preprocess_image('./demo_material/napoleon.jpg', img_size) 71 | 72 | # generate optimized graph 73 | file_already_exist = os.path.isfile(args.engine_path) 74 | if file_already_exist: 75 | print("Engine File:" + str(args.engine_path) + " already exists, loading engine instead of ONNX file.") 76 | with open(args.engine_path, "rb") as f: 77 | engine_data = f.read() 78 | engine = trt_runtime.deserialize_cuda_engine(engine_data) 79 | else: 80 | print("Building TensorRT engine from ONNX") 81 | if args.data_type == 'fp32': 82 | engine = build_engine_onnx(args.onnx_path, args.data_type) 83 | elif args.data_type == 'fp16': 84 | engine = build_engine_onnx(args.onnx_path + '_fp16', args.data_type) 85 | with open(args.engine_path, "wb") as f: 86 | f.write(engine.serialize()) 87 | print("Engine written in: " + str(args.engine_path)) 88 | 89 | print("Finished Building.") 90 | # engine = build_engine('./trt_fp.engine') 91 | context = engine.create_execution_context() 92 | print("Created execution context.") 93 | 94 | print("Testing first image on TensorRT") 95 | batch_size = 1 96 | inputs, outputs, bindings, stream = [], [], [], cuda.Stream() 97 | for binding in engine: 98 | size = trt.volume(engine.get_binding_shape(binding)) * batch_size 99 | dtype = trt.nptype(engine.get_binding_dtype(binding)) 100 | # Allocate host and device buffers 101 | host_mem = cuda.pagelocked_empty(size, dtype) 102 | device_mem = cuda.mem_alloc(host_mem.nbytes) 103 | # Append the device buffer to device bindings. 
104 | bindings.append(int(device_mem)) 105 | # Append to the appropriate list. 106 | print("Binding type: " + str(dtype) + " is input: " + str(engine.binding_is_input(binding))) 107 | if engine.binding_is_input(binding): 108 | inputs.append({'host': host_mem, 'device': device_mem}) 109 | else: 110 | outputs.append({'host': host_mem, 'device': device_mem}) 111 | 112 | # Masking creation - - - - - 113 | enc_exp_list = [32, 64, 128, 256, 512] 114 | dec_exp = 16 115 | num_heads = 8 116 | enc_mask = create_pad_mask(mask_size=(batch_size, sum(enc_exp_list), NUM_FEATURES), 117 | pad_row=[0], pad_column=[0]).contiguous() 118 | no_peak_mask = create_no_peak_and_pad_mask(mask_size=(batch_size, MAX_DECODE_STEPS, MAX_DECODE_STEPS), 119 | num_pads=[0]).contiguous() 120 | cross_mask = create_pad_mask(mask_size=(batch_size, MAX_DECODE_STEPS, NUM_FEATURES), 121 | pad_row=[0], pad_column=[0]).contiguous() 122 | cross_mask = 1 - cross_mask 123 | 124 | fw_dec_mask = no_peak_mask.unsqueeze(2).expand(batch_size, MAX_DECODE_STEPS, 125 | dec_exp, MAX_DECODE_STEPS).contiguous(). \ 126 | view(batch_size, MAX_DECODE_STEPS * dec_exp, MAX_DECODE_STEPS) 127 | 128 | bw_dec_mask = no_peak_mask.unsqueeze(-1).expand(batch_size, 129 | MAX_DECODE_STEPS, MAX_DECODE_STEPS, 130 | dec_exp).contiguous(). \ 131 | view(batch_size, MAX_DECODE_STEPS, MAX_DECODE_STEPS * dec_exp) 132 | 133 | atten_mask = cross_mask.unsqueeze(1).repeat(1, num_heads, 1, 1) 134 | 135 | # - - - - - - - - - - 136 | 137 | # Set input values 138 | if args.data_type == 'fp32': 139 | inputs[0]['host'] = np.ravel(image).astype(np.float32) 140 | inputs[1]['host'] = np.array([sos_idx]).astype(np.int32) 141 | inputs[2]['host'] = np.array(enc_mask).astype(np.int32) 142 | inputs[3]['host'] = np.array(fw_dec_mask).astype(np.int32) 143 | inputs[4]['host'] = np.array(bw_dec_mask).astype(np.int32) 144 | inputs[5]['host'] = np.array(atten_mask).astype(np.int32) 145 | elif args.data_type == 'fp16': 146 | inputs[0]['host'] = np.ravel(image).astype(np.float16) 147 | inputs[1]['host'] = np.array([sos_idx]).astype(np.int32) 148 | inputs[2]['host'] = np.array(enc_mask).astype(np.int32) 149 | inputs[3]['host'] = np.array(fw_dec_mask).astype(np.float16) 150 | inputs[4]['host'] = np.array(bw_dec_mask).astype(np.float16) 151 | inputs[5]['host'] = np.array(atten_mask).astype(np.int32) 152 | 153 | # Transfer input data to the GPU. 154 | for inp in inputs: 155 | cuda.memcpy_htod_async(inp['device'], inp['host'], stream) 156 | # Execute model 157 | context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle) 158 | # Transfer predictions back from the GPU. 
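# Condensed sketch of the host/device round trip used in this script: copy the
# host buffers to the GPU, run the engine asynchronously, copy the results back
# and synchronize the stream. `inputs` and `outputs` are lists of
# {'host': ..., 'device': ...} pairs as allocated above (the helper name is
# illustrative only).
import pycuda.driver as cuda

def _trt_inference_round_trip(context, bindings, inputs, outputs, stream, batch_size=1):
    for inp in inputs:
        cuda.memcpy_htod_async(inp['device'], inp['host'], stream)
    context.execute_async(batch_size=batch_size, bindings=bindings,
                          stream_handle=stream.handle)
    for out in outputs:
        cuda.memcpy_dtoh_async(out['host'], out['device'], stream)
    stream.synchronize()
    return [out['host'] for out in outputs]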
159 | for out in outputs: 160 | cuda.memcpy_dtoh_async(out['host'], out['device'], stream) 161 | # Synchronize the stream 162 | stream.synchronize() 163 | print(outputs[0]['host'].tolist()) 164 | output_caption = tokens2description(outputs[0]['host'].tolist(), coco_tokens['idx2word_list'], sos_idx, eos_idx) 165 | output_probs = outputs[1]['host'].tolist() 166 | print(output_caption) 167 | print(output_probs) 168 | -------------------------------------------------------------------------------- /optims/radam.py: -------------------------------------------------------------------------------- 1 | # Credits to Liyuan Liu 2 | # https://github.com/LiyuanLucasLiu/RAdam/blob/master/radam/radam.py 3 | 4 | import math 5 | import torch 6 | from torch.optim.optimizer import Optimizer, required 7 | 8 | 9 | class RAdam(Optimizer): 10 | 11 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False): 12 | if not 0.0 <= lr: 13 | raise ValueError("Invalid learning rate: {}".format(lr)) 14 | if not 0.0 <= eps: 15 | raise ValueError("Invalid epsilon value: {}".format(eps)) 16 | if not 0.0 <= betas[0] < 1.0: 17 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 18 | if not 0.0 <= betas[1] < 1.0: 19 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 20 | 21 | self.degenerated_to_sgd = degenerated_to_sgd 22 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 23 | for param in params: 24 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 25 | param['buffer'] = [[None, None, None] for _ in range(10)] 26 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 27 | buffer=[[None, None, None] for _ in range(10)]) 28 | super(RAdam, self).__init__(params, defaults) 29 | 30 | def __setstate__(self, state): 31 | super(RAdam, self).__setstate__(state) 32 | 33 | def step(self, closure=None): 34 | 35 | loss = None 36 | if closure is not None: 37 | loss = closure() 38 | 39 | for group in self.param_groups: 40 | 41 | for p in group['params']: 42 | if p.grad is None: 43 | continue 44 | grad = p.grad.data.float() 45 | if grad.is_sparse: 46 | raise RuntimeError('RAdam does not support sparse gradients') 47 | 48 | p_data_fp32 = p.data.float() 49 | 50 | state = self.state[p] 51 | 52 | if len(state) == 0: 53 | state['step'] = 0 54 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 55 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 56 | else: 57 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 58 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 59 | 60 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 61 | beta1, beta2 = group['betas'] 62 | 63 | exp_avg_sq.mul_(beta2).addcmul_(value=1 - beta2, tensor1=grad, tensor2=grad) 64 | exp_avg.mul_(beta1).add_(alpha=1 - beta1, other=grad) 65 | 66 | state['step'] += 1 67 | buffered = group['buffer'][int(state['step'] % 10)] 68 | if state['step'] == buffered[0]: 69 | N_sma, step_size = buffered[1], buffered[2] 70 | else: 71 | buffered[0] = state['step'] 72 | beta2_t = beta2 ** state['step'] 73 | N_sma_max = 2 / (1 - beta2) - 1 74 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 75 | buffered[1] = N_sma 76 | 77 | # more conservative since it's an approximated value 78 | if N_sma >= 5: 79 | step_size = math.sqrt( 80 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 81 | N_sma_max - 2)) / (1 - beta1 ** 
state['step']) 82 | elif self.degenerated_to_sgd: 83 | step_size = 1.0 / (1 - beta1 ** state['step']) 84 | else: 85 | step_size = -1 86 | buffered[2] = step_size 87 | 88 | # more conservative since it's an approximated value 89 | if N_sma >= 5: 90 | if group['weight_decay'] != 0: 91 | p_data_fp32.add_(alpha=-group['weight_decay'] * group['lr'], other=p_data_fp32) 92 | denom = exp_avg_sq.sqrt().add_(group['eps']) 93 | p_data_fp32.addcdiv_(value=-step_size * group['lr'], tensor1=exp_avg, tensor2=denom) 94 | p.data.copy_(p_data_fp32) 95 | elif step_size > 0: 96 | if group['weight_decay'] != 0: 97 | p_data_fp32.add_(alpha=-group['weight_decay'] * group['lr'], other=p_data_fp32) 98 | p_data_fp32.add_(alpha=-step_size * group['lr'], other=exp_avg) 99 | p.data.copy_(p_data_fp32) 100 | 101 | return loss 102 | 103 | 104 | class PlainRAdam(Optimizer): 105 | 106 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False): 107 | if not 0.0 <= lr: 108 | raise ValueError("Invalid learning rate: {}".format(lr)) 109 | if not 0.0 <= eps: 110 | raise ValueError("Invalid epsilon value: {}".format(eps)) 111 | if not 0.0 <= betas[0] < 1.0: 112 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 113 | if not 0.0 <= betas[1] < 1.0: 114 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 115 | 116 | self.degenerated_to_sgd = degenerated_to_sgd 117 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 118 | 119 | super(PlainRAdam, self).__init__(params, defaults) 120 | 121 | def __setstate__(self, state): 122 | super(PlainRAdam, self).__setstate__(state) 123 | 124 | def step(self, closure=None): 125 | 126 | loss = None 127 | if closure is not None: 128 | loss = closure() 129 | 130 | for group in self.param_groups: 131 | 132 | for p in group['params']: 133 | if p.grad is None: 134 | continue 135 | grad = p.grad.data.float() 136 | if grad.is_sparse: 137 | raise RuntimeError('RAdam does not support sparse gradients') 138 | 139 | p_data_fp32 = p.data.float() 140 | 141 | state = self.state[p] 142 | 143 | if len(state) == 0: 144 | state['step'] = 0 145 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 146 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 147 | else: 148 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 149 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 150 | 151 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 152 | beta1, beta2 = group['betas'] 153 | 154 | exp_avg_sq.mul_(beta2).addcmul_(value=1 - beta2, tensor1=grad, tensor2=grad) 155 | exp_avg.mul_(beta1).add_(alpha=1 - beta1, other=grad) 156 | 157 | state['step'] += 1 158 | beta2_t = beta2 ** state['step'] 159 | N_sma_max = 2 / (1 - beta2) - 1 160 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 161 | 162 | # more conservative since it's an approximated value 163 | if N_sma >= 5: 164 | if group['weight_decay'] != 0: 165 | p_data_fp32.add_(alpha=-group['weight_decay'] * group['lr'], other=p_data_fp32) 166 | step_size = group['lr'] * math.sqrt( 167 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 168 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 169 | denom = exp_avg_sq.sqrt().add_(group['eps']) 170 | p_data_fp32.addcdiv_(value=-step_size, tensor1=exp_avg, tensor2=denom) 171 | p.data.copy_(p_data_fp32) 172 | elif self.degenerated_to_sgd: 173 | if group['weight_decay'] != 0: 174 | p_data_fp32.add_(alpha=-group['weight_decay'] * 
group['lr'], other=p_data_fp32) 175 | step_size = group['lr'] / (1 - beta1 ** state['step']) 176 | p_data_fp32.add_(alpha=-step_size, other=exp_avg) 177 | p.data.copy_(p_data_fp32) 178 | 179 | return loss 180 | 181 | 182 | class AdamW(Optimizer): 183 | 184 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0): 185 | if not 0.0 <= lr: 186 | raise ValueError("Invalid learning rate: {}".format(lr)) 187 | if not 0.0 <= eps: 188 | raise ValueError("Invalid epsilon value: {}".format(eps)) 189 | if not 0.0 <= betas[0] < 1.0: 190 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 191 | if not 0.0 <= betas[1] < 1.0: 192 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 193 | 194 | defaults = dict(lr=lr, betas=betas, eps=eps, 195 | weight_decay=weight_decay, warmup=warmup) 196 | super(AdamW, self).__init__(params, defaults) 197 | 198 | def __setstate__(self, state): 199 | super(AdamW, self).__setstate__(state) 200 | 201 | def step(self, closure=None): 202 | loss = None 203 | if closure is not None: 204 | loss = closure() 205 | 206 | for group in self.param_groups: 207 | 208 | for p in group['params']: 209 | if p.grad is None: 210 | continue 211 | grad = p.grad.data.float() 212 | if grad.is_sparse: 213 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 214 | 215 | p_data_fp32 = p.data.float() 216 | 217 | state = self.state[p] 218 | 219 | if len(state) == 0: 220 | state['step'] = 0 221 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 222 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 223 | else: 224 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 225 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 226 | 227 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 228 | beta1, beta2 = group['betas'] 229 | 230 | state['step'] += 1 231 | 232 | exp_avg_sq.mul_(beta2).addcmul_(value=1 - beta2, tensor1=grad, tensor2=grad) 233 | exp_avg.mul_(beta1).add_(alpha=1 - beta1, other=grad) 234 | 235 | denom = exp_avg_sq.sqrt().add_(group['eps']) 236 | bias_correction1 = 1 - beta1 ** state['step'] 237 | bias_correction2 = 1 - beta2 ** state['step'] 238 | 239 | if group['warmup'] > state['step']: 240 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 241 | else: 242 | scheduled_lr = group['lr'] 243 | 244 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 245 | 246 | if group['weight_decay'] != 0: 247 | p_data_fp32.add_(alpha=-group['weight_decay'] * scheduled_lr, other=p_data_fp32) 248 | 249 | p_data_fp32.addcdiv_(value=-step_size, tensor1=exp_avg, tensor2=denom) 250 | 251 | p.data.copy_(p_data_fp32) 252 | 253 | return loss 254 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # NOTE: Installing whatever version of torch, torchvision, h5py, Pillow 2 | # fit your machine should work in most cases. Here is reported one possible 3 | # working configuration but these specific versions are not mandatory. 
4 | 5 | h5py==3.8.0 6 | numpy==1.21.6 7 | Pillow==10.2.0 8 | torch==1.12.1+cu113 9 | torchvision==0.13.1+cu113 10 | -------------------------------------------------------------------------------- /requirements_wTensorRT.txt: -------------------------------------------------------------------------------- 1 | # NOTE: Installing whatever version of torch, torchvision, h5py, Pillow 2 | # fit your machine should work in most cases. Here is reported one possible 3 | # working configuration but these specific versions are not mandatory. 4 | 5 | h5py==3.8.0 6 | numpy==1.21.6 7 | onnx==1.14.0 8 | onnxruntime==1.14.1 9 | onnxsim==0.4.17 10 | Pillow==10.2.0 11 | pycuda==2022.1 12 | tensorrt==8.0.1.6 13 | torch==1.12.1+cu113 14 | torchvision==0.13.1+cu113 15 | -------------------------------------------------------------------------------- /results_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/results_image.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import math 4 | import torch 5 | import argparse 6 | from argparse import Namespace 7 | from utils.args_utils import str2list, str2bool 8 | import copy 9 | from time import time 10 | 11 | import torch.distributed as dist 12 | import torch.multiprocessing as mp 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | 15 | from models.ensemble_captioning_model import EsembleCaptioningModel 16 | from data.coco_dataloader import CocoDataLoader 17 | from data.coco_dataset import CocoDatasetKarpathy 18 | from utils import language_utils 19 | from utils.language_utils import compute_num_pads as compute_num_pads 20 | from eval.eval import COCOEvalCap 21 | 22 | import warnings 23 | warnings.filterwarnings("ignore", category=UserWarning) 24 | 25 | import functools 26 | print = functools.partial(print, flush=True) 27 | 28 | 29 | def convert_time_as_hhmmss(ticks): 30 | return str(int(ticks / 60)) + " m " + \ 31 | str(int(ticks) % 60) + " s" 32 | 33 | 34 | def compute_evaluation_loss(loss_function, 35 | model, 36 | data_set, 37 | data_loader, 38 | num_samples, 39 | sub_batch_size, 40 | dataset_split, 41 | rank=0, 42 | verbose=False): 43 | model.eval() 44 | 45 | sb_size = sub_batch_size 46 | 47 | tot_loss = 0 48 | num_sub_batch = math.ceil(num_samples / sb_size) 49 | tot_num_tokens = 0 50 | for sb_it in range(num_sub_batch): 51 | from_idx = sb_it * sb_size 52 | to_idx = min((sb_it + 1) * sb_size, num_samples) 53 | 54 | sub_batch_input_x, sub_batch_target_y, sub_batch_input_x_num_pads, sub_batch_target_y_num_pads, \ 55 | = data_loader.get_batch_samples(img_idx_batch_list=list(range(from_idx, to_idx)), 56 | dataset_split=dataset_split) 57 | sub_batch_input_x = sub_batch_input_x.to(rank) 58 | sub_batch_target_y = sub_batch_target_y.to(rank) 59 | 60 | sub_batch_input_x = sub_batch_input_x 61 | sub_batch_target_y = sub_batch_target_y 62 | tot_num_tokens += sub_batch_target_y.size(1)*sub_batch_target_y.size(0) - \ 63 | sum(sub_batch_target_y_num_pads) 64 | pred = model(enc_x=sub_batch_input_x, 65 | dec_x=sub_batch_target_y[:, :-1], 66 | enc_x_num_pads=sub_batch_input_x_num_pads, 67 | dec_x_num_pads=sub_batch_target_y_num_pads, 68 | apply_softmax=False) 69 | tot_loss += loss_function(pred, sub_batch_target_y[:, 
1:], 70 | data_set.get_pad_token_idx(), 71 | divide_by_non_zeros=False).item() 72 | del sub_batch_input_x, sub_batch_target_y, pred 73 | torch.cuda.empty_cache() 74 | tot_loss /= tot_num_tokens 75 | if verbose and rank == 0: 76 | print("Validation Loss on " + str(num_samples) + " samples: " + str(tot_loss)) 77 | 78 | return tot_loss 79 | 80 | 81 | def evaluate_model(ddp_model, 82 | y_idx2word_list, 83 | beam_size, max_seq_len, 84 | sos_idx, eos_idx, 85 | rank, ddp_sync_port, 86 | parallel_batches=16, 87 | 88 | indexes=[0], 89 | data_loader=None, 90 | dataset_split=CocoDatasetKarpathy.TrainSet_ID, 91 | use_images_instead_of_features=False, 92 | 93 | verbose=True, 94 | stanford_model_path="./eval/get_stanford_models.sh"): 95 | 96 | start_time = time() 97 | 98 | sub_list_predictions = [] 99 | validate_y = [] 100 | num_samples = len(indexes) 101 | 102 | ddp_model.eval() 103 | with torch.no_grad(): 104 | sb_size = parallel_batches 105 | num_iter_sub_batches = math.ceil(len(indexes) / sb_size) 106 | for sb_it in range(num_iter_sub_batches): 107 | last_iter = sb_it == num_iter_sub_batches - 1 108 | if last_iter: 109 | from_idx = sb_it * sb_size 110 | to_idx = num_samples 111 | else: 112 | from_idx = sb_it * sb_size 113 | to_idx = (sb_it + 1) * sb_size 114 | 115 | if use_images_instead_of_features: 116 | sub_batch_x = [data_loader.get_images_by_idx(i, dataset_split=dataset_split, transf_mode='test').unsqueeze(0) 117 | for i in list(range(from_idx, to_idx))] 118 | sub_batch_x = torch.cat(sub_batch_x).to(rank) 119 | sub_batch_x_num_pads = [0] * sub_batch_x.size(0) 120 | else: 121 | sub_batch_x = [data_loader.get_vis_features_by_idx(i, dataset_split=dataset_split) 122 | for i in list(range(from_idx, to_idx))] 123 | sub_batch_x = torch.nn.utils.rnn.pad_sequence(sub_batch_x, batch_first=True).to(rank) 124 | sub_batch_x_num_pads = compute_num_pads(sub_batch_x) 125 | 126 | validate_y += [data_loader.get_all_image_captions_by_idx(i, dataset_split=dataset_split) \ 127 | for i in list(range(from_idx, to_idx))] 128 | 129 | beam_search_kwargs = {'beam_size': beam_size, 130 | 'beam_max_seq_len': max_seq_len, 131 | 'sample_or_max': 'max', 132 | 'how_many_outputs': 1, 133 | 'sos_idx': sos_idx, 134 | 'eos_idx': eos_idx} 135 | 136 | output_words, _ = ddp_model(enc_x=sub_batch_x, 137 | enc_x_num_pads=sub_batch_x_num_pads, 138 | mode='beam_search', **beam_search_kwargs) 139 | 140 | output_words = [output_words[i][0] for i in range(len(output_words))] 141 | 142 | pred_sentence = language_utils.convert_allsentences_idx2word(output_words, y_idx2word_list) 143 | for sentence in pred_sentence: 144 | sub_list_predictions.append(' '.join(sentence[1:-1])) # remove EOS and SOS 145 | 146 | del sub_batch_x, sub_batch_x_num_pads, output_words 147 | 148 | ddp_model.train() 149 | 150 | if rank == 0 and verbose: 151 | # dirty code to leave the evaluation code untouched 152 | list_predictions = [sub_predictions for sub_predictions in sub_list_predictions] 153 | list_list_references = [[validate_y[i][j] for j in range(len(validate_y[i]))] for i in range(len(validate_y))] 154 | 155 | gts_dict = dict() 156 | for i in range(len(list_list_references)): 157 | gts_dict[i] = [{u'image_id': i, u'caption': list_list_references[i][j]} 158 | for j in range(len(list_list_references[i]))] 159 | 160 | pred_dict = dict() 161 | for i in range(len(list_predictions)): 162 | pred_dict[i] = [{u'image_id': i, u'caption': list_predictions[i]}] 163 | 164 | coco_eval = COCOEvalCap(gts_dict, pred_dict, list(range(len(list_predictions))), 165 | 
get_stanford_models_path=stanford_model_path) 166 | score_results = coco_eval.evaluate(bleu=True, rouge=True, cider=True, spice=True, meteor=True, verbose=False) 167 | elapsed_ticks = time() - start_time 168 | print("Evaluation Phase over " + str(len(validate_y)) + " BeamSize: " + str(beam_size) + 169 | " elapsed: " + str(int(elapsed_ticks/60)) + " m " + str(int(elapsed_ticks % 60)) + ' s') 170 | print(score_results) 171 | 172 | if rank == 0: 173 | return pred_dict, gts_dict 174 | 175 | return None, None 176 | 177 | 178 | def evaluate_model_on_set(ddp_model, 179 | caption_idx2word_list, 180 | sos_idx, eos_idx, 181 | num_samples, 182 | data_loader, 183 | dataset_split, 184 | eval_max_len, 185 | rank, ddp_sync_port, 186 | parallel_batches=16, 187 | beam_sizes=[1], 188 | stanford_model_path='./eval/get_stanford_models.sh', 189 | use_images_instead_of_features=False, 190 | get_predictions=False): 191 | 192 | with torch.no_grad(): 193 | ddp_model.eval() 194 | 195 | for beam in beam_sizes: 196 | pred_dict, gts_dict = evaluate_model(ddp_model, 197 | y_idx2word_list=caption_idx2word_list, 198 | beam_size=beam, max_seq_len=eval_max_len, 199 | sos_idx=sos_idx, eos_idx=eos_idx, 200 | rank=rank, 201 | ddp_sync_port=ddp_sync_port, 202 | parallel_batches=parallel_batches, 203 | indexes=list(range(num_samples)), 204 | data_loader=data_loader, 205 | dataset_split=dataset_split, 206 | use_images_instead_of_features=use_images_instead_of_features, 207 | verbose=True, 208 | stanford_model_path=stanford_model_path) 209 | 210 | if rank == 0 and get_predictions: 211 | return pred_dict, gts_dict 212 | 213 | return None, None 214 | 215 | 216 | def get_ensemble_model(reference_model, 217 | checkpoints_paths, 218 | rank=0): 219 | model_list = [] 220 | for i in range(len(checkpoints_paths)): 221 | model = copy.deepcopy(reference_model) 222 | model.to(rank) 223 | map_location = {'cuda:%d' % 0: 'cuda:%d' % rank} 224 | checkpoint = torch.load(checkpoints_paths[i], 225 | map_location=map_location) 226 | model.load_state_dict(checkpoint['model_state_dict']) 227 | model_list.append(model) 228 | 229 | model = EsembleCaptioningModel(model_list, rank).to(rank) 230 | ddp_model = DDP(model, device_ids=[rank]) 231 | return ddp_model 232 | 233 | 234 | def test(rank, world_size, 235 | is_end_to_end, 236 | model_args, 237 | is_ensemble, 238 | coco_dataset, 239 | eval_parallel_batch_size, 240 | eval_beam_sizes, 241 | show_predictions, 242 | array_of_init_seeds, 243 | model_max_len, 244 | save_model_path, 245 | ddp_sync_port): 246 | print("GPU: " + str(rank) + "] Process " + str(rank) + " working...") 247 | 248 | os.environ['MASTER_ADDR'] = 'localhost' 249 | os.environ['MASTER_PORT'] = ddp_sync_port 250 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 251 | 252 | img_size = 384 253 | if is_end_to_end: 254 | from models.End_ExpansionNet_v2 import End_ExpansionNet_v2 255 | model = End_ExpansionNet_v2(swin_img_size=img_size, swin_patch_size=4, swin_in_chans=3, 256 | swin_embed_dim=192, swin_depths=[2, 2, 18, 2], swin_num_heads=[6, 12, 24, 48], 257 | swin_window_size=12, swin_mlp_ratio=4., swin_qkv_bias=True, swin_qk_scale=None, 258 | swin_drop_rate=0.0, swin_attn_drop_rate=0.0, swin_drop_path_rate=0.1, 259 | swin_norm_layer=torch.nn.LayerNorm, swin_ape=False, swin_patch_norm=True, 260 | swin_use_checkpoint=False, 261 | final_swin_dim=1536, 262 | 263 | d_model=model_args.model_dim, N_enc=model_args.N_enc, 264 | N_dec=model_args.N_dec, num_heads=8, ff=2048, 265 | num_exp_enc_list=[32, 64, 128, 256, 512], 266 | 
num_exp_dec=16, 267 | output_word2idx=coco_dataset.caption_word2idx_dict, 268 | output_idx2word=coco_dataset.caption_idx2word_list, 269 | max_seq_len=model_max_len, drop_args=model_args.drop_args, 270 | rank=rank) 271 | else: 272 | from models.ExpansionNet_v2 import ExpansionNet_v2 273 | model = ExpansionNet_v2(d_model=model_args.model_dim, N_enc=model_args.N_enc, 274 | N_dec=model_args.N_dec, num_heads=8, ff=2048, 275 | num_exp_enc_list=[32, 64, 128, 256, 512], 276 | num_exp_dec=16, 277 | output_word2idx=coco_dataset.caption_word2idx_dict, 278 | output_idx2word=coco_dataset.caption_idx2word_list, 279 | max_seq_len=model_max_len, drop_args=model_args.drop_args, 280 | img_feature_dim=1536, 281 | rank=rank) 282 | 283 | model.to(rank) 284 | ddp_model = DDP(model, device_ids=[rank]) 285 | 286 | data_loader = CocoDataLoader(coco_dataset=coco_dataset, 287 | batch_size=1, 288 | num_procs=world_size, 289 | array_of_init_seeds=array_of_init_seeds, 290 | dataloader_mode='image_wise', 291 | resize_image_size=img_size if is_end_to_end else None, 292 | rank=rank, 293 | verbose=False) 294 | 295 | if not is_ensemble: 296 | print("Not ensemble") 297 | map_location = {'cuda:%d' % 0: 'cuda:%d' % rank} 298 | checkpoint = torch.load(save_model_path, map_location=map_location) 299 | model.load_state_dict(checkpoint['model_state_dict'], strict=is_end_to_end) 300 | else: 301 | print("Ensembling Evaluation") 302 | list_checkpoints = os.listdir(save_model_path) 303 | checkpoints_list = [save_model_path + elem for elem in list_checkpoints if elem.endswith('.pth')] 304 | print("Detected checkpoints: " + str(checkpoints_list)) 305 | 306 | if len(checkpoints_list) == 0: 307 | print("No checkpoints found") 308 | dist.destroy_process_group() 309 | exit(-1) 310 | ddp_model = get_ensemble_model(model, checkpoints_list, rank=rank) 311 | 312 | print("Evaluation on Validation Set") 313 | evaluate_model_on_set(ddp_model, coco_dataset.caption_idx2word_list, 314 | coco_dataset.get_sos_token_idx(), coco_dataset.get_eos_token_idx(), 315 | coco_dataset.val_num_images, 316 | data_loader, 317 | CocoDatasetKarpathy.ValidationSet_ID, model_max_len, 318 | rank, ddp_sync_port, 319 | parallel_batches=eval_parallel_batch_size, 320 | use_images_instead_of_features=is_end_to_end, 321 | beam_sizes=eval_beam_sizes) 322 | 323 | print("Evaluation on Test Set") 324 | pred_dict, gts_dict = evaluate_model_on_set(ddp_model, coco_dataset.caption_idx2word_list, 325 | coco_dataset.get_sos_token_idx(), coco_dataset.get_eos_token_idx(), 326 | coco_dataset.test_num_images, 327 | data_loader, 328 | CocoDatasetKarpathy.TestSet_ID, model_max_len, 329 | rank, ddp_sync_port, 330 | parallel_batches=eval_parallel_batch_size, 331 | use_images_instead_of_features=is_end_to_end, 332 | beam_sizes=eval_beam_sizes, 333 | get_predictions=show_predictions) 334 | 335 | if rank == 0 and show_predictions: 336 | with open("predictions.txt", 'w') as f: 337 | for i in range(len(pred_dict)): 338 | prediction = pred_dict[i][0]['caption'] 339 | ground_truth_list = [gts_dict[i][j]['caption'] for j in range(len(gts_dict[i]))] 340 | f.write(str(i) + '----------------------------------------------------------------------' + '\n') 341 | f.write('Pred: ' + str(prediction) + '\n') 342 | f.write('Gt: ' + str(ground_truth_list) + '\n') 343 | 344 | print("[GPU: " + str(rank) + " ] Closing...") 345 | dist.destroy_process_group() 346 | 347 | 348 | def spawn_train_processes(is_end_to_end, 349 | model_args, 350 | is_ensemble, 351 | coco_dataset, 352 | eval_parallel_batch_size, 353 | 
eval_beam_sizes, 354 | show_predictions, 355 | num_gpus, 356 | ddp_sync_port, 357 | save_model_path 358 | ): 359 | 360 | max_sequence_length = coco_dataset.max_seq_len + 20 361 | print("Max sequence length: " + str(max_sequence_length)) 362 | print("y vocabulary size: " + str(len(coco_dataset.caption_word2idx_dict))) 363 | 364 | world_size = torch.cuda.device_count() 365 | print("Using - ", world_size, " processes / GPUs!") 366 | assert(num_gpus <= world_size), "requested num gpus higher than the number of available gpus " 367 | print("Requested num GPUs: " + str(num_gpus)) 368 | 369 | array_of_init_seeds = [random.random() for _ in range(10)] 370 | mp.spawn(test, 371 | args=(num_gpus, 372 | is_end_to_end, 373 | model_args, 374 | is_ensemble, 375 | coco_dataset, 376 | eval_parallel_batch_size, 377 | eval_beam_sizes, 378 | show_predictions, 379 | array_of_init_seeds, 380 | max_sequence_length, 381 | save_model_path, 382 | ddp_sync_port), 383 | nprocs=num_gpus, 384 | join=True) 385 | 386 | 387 | if __name__ == "__main__": 388 | 389 | parser = argparse.ArgumentParser(description='Image Captioning') 390 | parser.add_argument('--model_dim', type=int, default=512) 391 | parser.add_argument('--N_enc', type=int, default=3) 392 | parser.add_argument('--N_dec', type=int, default=3) 393 | parser.add_argument('--show_predictions', type=str2bool, default=False) 394 | 395 | parser.add_argument('--is_end_to_end', type=str2bool, default=True) 396 | parser.add_argument('--is_ensemble', type=str2bool, default=False) 397 | 398 | parser.add_argument('--num_gpus', type=int, default=1) 399 | parser.add_argument('--ddp_sync_port', type=int, default=12354) 400 | parser.add_argument('--save_model_path', type=str, default='./github_ignore_material/saves/') 401 | 402 | parser.add_argument('--eval_parallel_batch_size', type=int, default=16) 403 | parser.add_argument('--eval_beam_sizes', type=str2list, default=[3]) 404 | 405 | parser.add_argument('--images_path', type=str, default="./github_ignore_material/raw_data/") 406 | parser.add_argument('--preproc_images_hdf5_filepath', type=str, default=None) 407 | parser.add_argument('--features_path', type=str, default='./github_ignore_material/raw_data/') 408 | parser.add_argument('--captions_path', type=str, default='./github_ignore_material/raw_data/') 409 | 410 | args = parser.parse_args() 411 | args.ddp_sync_port = str(args.ddp_sync_port) 412 | 413 | assert (args.eval_parallel_batch_size % args.num_gpus == 0), \ 414 | "num gpus must be multiple of the requested parallel batch size" 415 | 416 | print("is_ensemble: " + str(args.is_ensemble)) 417 | print("eval parallel batch_size: " + str(args.eval_parallel_batch_size)) 418 | print("ddp_sync_port: " + str(args.ddp_sync_port)) 419 | print("save_model_path: " + str(args.save_model_path)) 420 | 421 | drop_args = Namespace(enc=0.0, 422 | dec=0.0, 423 | enc_input=0.0, 424 | dec_input=0.0, 425 | other=0.0) 426 | 427 | model_args = Namespace(model_dim=args.model_dim, 428 | N_enc=args.N_enc, 429 | N_dec=args.N_dec, 430 | dropout=0.0, 431 | drop_args=drop_args 432 | ) 433 | 434 | coco_dataset = CocoDatasetKarpathy( 435 | images_path=args.images_path, 436 | coco_annotations_path=args.captions_path + "dataset_coco.json", 437 | preproc_images_hdf5_filepath=args.preproc_images_hdf5_filepath if args.is_end_to_end else None, 438 | precalc_features_hdf5_filepath=None if args.is_end_to_end else args.features_path, 439 | limited_num_train_images=None, 440 | limited_num_val_images=5000) 441 | 442 | 
spawn_train_processes(is_end_to_end=args.is_end_to_end, 443 | model_args=model_args, 444 | is_ensemble=args.is_ensemble, 445 | coco_dataset=coco_dataset, 446 | eval_parallel_batch_size=args.eval_parallel_batch_size, 447 | eval_beam_sizes=args.eval_beam_sizes, 448 | show_predictions=args.show_predictions, 449 | num_gpus=args.num_gpus, 450 | ddp_sync_port=args.ddp_sync_port, 451 | save_model_path=args.save_model_path 452 | ) 453 | -------------------------------------------------------------------------------- /utils/args_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | # thanks Maxim from: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 5 | def str2bool(v): 6 | if isinstance(v, bool): 7 | return v 8 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 9 | return True 10 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 11 | return False 12 | else: 13 | raise argparse.ArgumentTypeError('Boolean value expected.') 14 | 15 | 16 | def str2list(v): 17 | if '[' in v and ']' in v: 18 | return list(map(int, v.strip('[]').split(','))) 19 | else: 20 | raise argparse.ArgumentTypeError('Input expected in the form [b1,b2,b3,...]') 21 | 22 | 23 | def str2type(v): 24 | if v.lower() == 'fp32' or v.lower() == 'fp16': 25 | return v.lower() 26 | else: 27 | raise argparse.ArgumentTypeError('Invalid type, currently supported type: [fp32, fp16]') 28 | 29 | 30 | 31 | def scheduler_type_choice(v): 32 | if v == 'annealing' or v == 'custom_warmup_anneal': 33 | return v 34 | else: 35 | raise argparse.ArgumentTypeError('Argument must be either ' 36 | '\'annealing\', ' 37 | '\'custom_warmup_anneal\'') 38 | 39 | 40 | def optim_type_choice(v): 41 | if v == 'adam' or v == 'radam': 42 | return v 43 | else: 44 | raise argparse.ArgumentTypeError('Argument must be either \'adam\', ' 45 | '\'radam\'.') -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torchvision 3 | from PIL import Image as PIL_Image 4 | 5 | 6 | def preprocess_image(image_path, img_size): 7 | transf_1 = torchvision.transforms.Compose([torchvision.transforms.Resize((img_size, img_size))]) 8 | transf_2 = torchvision.transforms.Compose([torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], 9 | std=[0.229, 0.224, 0.225])]) 10 | 11 | pil_image = PIL_Image.open(image_path) 12 | if pil_image.mode != 'RGB': 13 | pil_image = pil_image.convert('RGB')  # convert non-RGB inputs (grayscale, RGBA, ...) instead of replacing them with a blank image 14 | preprocess_pil_image = transf_1(pil_image) 15 | image = torchvision.transforms.ToTensor()(preprocess_pil_image) 16 | image = transf_2(image) 17 | return image.unsqueeze(0) 18 | -------------------------------------------------------------------------------- /utils/language_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def compute_num_pads(list_vis_features): 5 | max_len = -1 6 | for vis_features in list_vis_features: 7 | num_vis_features = len(vis_features) 8 | if num_vis_features > max_len: 9 | max_len = num_vis_features 10 | num_pad_vector = [] 11 | for vis_features in list_vis_features: 12 | num_pad_vector.append(max_len - len(vis_features)) 13 | return num_pad_vector 14 | 15 | 16 | def remove_punctuations(sentences): 17 | punctuations = ["''", "'", "``", "`", ".", "?", "!", ",", ":", "-", "--", "...", ";"] 18 | res_sentences_list = [] 19 | for i in range(len(sentences)): 20 
| res_sentence = [] 21 | for word in sentences[i].split(' '): 22 | if word not in punctuations: 23 | res_sentence.append(word) 24 | res_sentences_list.append(' '.join(res_sentence)) 25 | return res_sentences_list 26 | 27 | 28 | def lowercase_and_clean_trailing_spaces(sentences): 29 | return [(sentences[i].lower()).rstrip() for i in range(len(sentences))] 30 | 31 | 32 | def add_space_between_non_alphanumeric_symbols(sentences): 33 | return [re.sub(r'([^\w0-9])', r" \1 ", sentences[i]) for i in range(len(sentences))] 34 | 35 | 36 | def tokenize(list_sentences): 37 | res_sentences_list = [] 38 | for i in range(len(list_sentences)): 39 | sentence = list_sentences[i].split(' ') 40 | while '' in sentence: 41 | sentence.remove('') 42 | res_sentences_list.append(sentence) 43 | return res_sentences_list 44 | 45 | 46 | def convert_vector_word2idx(sentence, word2idx_dict): 47 | return [word2idx_dict[word] for word in sentence] 48 | 49 | 50 | def convert_allsentences_word2idx(sentences, word2idx_dict): 51 | return [convert_vector_word2idx(sentences[i], word2idx_dict) for i in range(len(sentences))] 52 | 53 | 54 | def convert_vector_idx2word(sentence, idx2word_list): 55 | return [idx2word_list[idx] for idx in sentence] 56 | 57 | 58 | def convert_allsentences_idx2word(sentences, idx2word_list): 59 | return [convert_vector_idx2word(sentences[i], idx2word_list) for i in range(len(sentences))] 60 | 61 | 62 | def tokens2description(tokens, idx2word_list, sos_idx, eos_idx): 63 | desc = [] 64 | for tok in tokens: 65 | if tok == sos_idx: 66 | continue 67 | if tok == eos_idx: 68 | break 69 | desc.append(tok) 70 | desc = convert_vector_idx2word(desc, idx2word_list) 71 | desc[-1] = desc[-1] + '.' 72 | pred = ' '.join(desc).capitalize() 73 | return pred 74 | -------------------------------------------------------------------------------- /utils/masking.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def create_pad_mask(mask_size, pad_row, pad_column, rank=0): 5 | batch_size, output_seq_len, input_seq_len = mask_size 6 | mask = torch.ones(size=(batch_size, output_seq_len, input_seq_len), dtype=torch.int8).to(rank) 7 | 8 | for batch_idx in range(batch_size): 9 | mask[batch_idx, :, (input_seq_len - pad_column[batch_idx]):] = 0 10 | mask[batch_idx, (output_seq_len - pad_row[batch_idx]):, :] = 0 11 | return mask 12 | 13 | 14 | def create_no_peak_and_pad_mask(mask_size, num_pads, rank=0): 15 | batch_size, seq_len, seq_len = mask_size 16 | mask = torch.tril(torch.ones(size=(seq_len, seq_len), dtype=torch.int8), 17 | diagonal=0).unsqueeze(0).repeat(batch_size, 1, 1).to(rank) 18 | for batch_idx in range(batch_size): 19 | mask[batch_idx, :, seq_len - num_pads[batch_idx]:] = 0 20 | mask[batch_idx, (seq_len - num_pads[batch_idx]):, :] = 0 21 | return mask 22 | -------------------------------------------------------------------------------- /utils/saving_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | from datetime import datetime 5 | 6 | from torch.nn.parameter import Parameter 7 | 8 | def load_most_recent_checkpoint(model, 9 | optimizer=None, 10 | scheduler=None, 11 | data_loader=None, 12 | rank=0, 13 | save_model_path='./', datetime_format='%Y-%m-%d-%H:%M:%S', 14 | verbose=True): 15 | ls_files = os.listdir(save_model_path) 16 | most_recent_checkpoint_datetime = None 17 | most_recent_checkpoint_filename = None 18 | most_recent_checkpoint_info = 'no_additional_info' 19 | for 
file_name in ls_files: 20 | if file_name.startswith('checkpoint_'): 21 | _, datetime_str, _, info, _ = file_name.split('_') 22 | file_datetime = datetime.strptime(datetime_str, datetime_format) 23 | if (most_recent_checkpoint_datetime is None) or \ 24 | (most_recent_checkpoint_datetime is not None and 25 | file_datetime > most_recent_checkpoint_datetime): 26 | most_recent_checkpoint_datetime = file_datetime 27 | most_recent_checkpoint_filename = file_name 28 | most_recent_checkpoint_info = info 29 | 30 | if most_recent_checkpoint_filename is not None: 31 | if verbose: 32 | print("Loading: " + str(save_model_path + most_recent_checkpoint_filename)) 33 | map_location = {'cuda:%d' % 0: 'cuda:%d' % rank} 34 | checkpoint = torch.load(save_model_path + most_recent_checkpoint_filename, 35 | map_location=map_location) 36 | model.load_state_dict(checkpoint['model_state_dict']) 37 | if optimizer is not None: 38 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 39 | if scheduler is not None: 40 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 41 | if data_loader is not None: 42 | data_loader.load_state(checkpoint['data_loader_state_dict']) 43 | return True, most_recent_checkpoint_info 44 | else: 45 | if verbose: 46 | print("Loading: no checkpoint found in " + str(save_model_path)) 47 | return False, most_recent_checkpoint_info 48 | 49 | 50 | def save_last_checkpoint(model, 51 | optimizer, 52 | scheduler, 53 | data_loader, 54 | save_model_path='./', 55 | num_max_checkpoints=3, datetime_format='%Y-%m-%d-%H:%M:%S', 56 | additional_info='noinfo', 57 | verbose=True): 58 | 59 | checkpoint = { 60 | 'model_state_dict': model.state_dict(), 61 | 'optimizer_state_dict': optimizer.state_dict(), 62 | 'scheduler_state_dict': scheduler.state_dict(), 63 | 'data_loader_state_dict': data_loader.save_state(), 64 | } 65 | 66 | ls_files = os.listdir(save_model_path) 67 | oldest_checkpoint_datetime = None 68 | oldest_checkpoint_filename = None 69 | num_check_points = 0 70 | for file_name in ls_files: 71 | if file_name.startswith('checkpoint_'): 72 | num_check_points += 1 73 | _, datetime_str, _, _, _ = file_name.split('_') 74 | file_datetime = datetime.strptime(datetime_str, datetime_format) 75 | if (oldest_checkpoint_datetime is None) or \ 76 | (oldest_checkpoint_datetime is not None and file_datetime < oldest_checkpoint_datetime): 77 | oldest_checkpoint_datetime = file_datetime 78 | oldest_checkpoint_filename = file_name 79 | 80 | if oldest_checkpoint_filename is not None and num_check_points == num_max_checkpoints: 81 | os.remove(save_model_path + oldest_checkpoint_filename) 82 | 83 | new_checkpoint_filename = 'checkpoint_' + datetime.now().strftime(datetime_format) + \ 84 | '_epoch' + str(data_loader.get_epoch_it()) + \ 85 | 'it' + str(data_loader.get_batch_it()) + \ 86 | 'bs' + str(data_loader.get_batch_size()) + \ 87 | '_' + str(additional_info) + '_.pth' 88 | if verbose: 89 | print("Saved to " + str(new_checkpoint_filename)) 90 | torch.save(checkpoint, save_model_path + new_checkpoint_filename) 91 | 92 | 93 | def partially_load_state_dict(model, state_dict, verbose=False, max_num_print=5): 94 | own_state = model.state_dict() 95 | max_num_print = max_num_print 96 | count_print = 0 97 | for name, param in state_dict.items(): 98 | if name not in own_state: 99 | if verbose: 100 | print("Not found: " + str(name)) 101 | continue 102 | if isinstance(param, Parameter): 103 | param = param.data 104 | own_state[name].copy_(param) 105 | if verbose: 106 | if count_print < max_num_print: 107 | 
print("Found: " + str(name)) 108 | count_print += 1 109 | 110 | --------------------------------------------------------------------------------