├── .gitignore
├── README.md
├── data
│   ├── __init__.py
│   ├── coco_dataloader.py
│   ├── coco_dataset.py
│   └── transparent_data_loader.py
├── data_generator.py
├── demo.py
├── demo_material
│   ├── __init__.py
│   ├── cat_girl.jpg
│   ├── demo_coco_tokens.pickle
│   ├── micheal.jpg
│   ├── napoleon.jpg
│   └── tatin.jpg
├── demo_results.png
├── eval
│   ├── __init__.py
│   ├── bleu
│   │   ├── LICENSE
│   │   ├── __init__.py
│   │   ├── bleu.py
│   │   └── bleu_scorer.py
│   ├── cider
│   │   ├── __init__.py
│   │   ├── cider.py
│   │   ├── cider_scorer.py
│   │   ├── reinforce_cider.py
│   │   └── reinforce_cider_scorer.py
│   ├── eval.py
│   ├── get_stanford_models.sh
│   ├── meteor
│   │   ├── __init__.py
│   │   ├── data
│   │   │   └── paraphrase-en.gz
│   │   ├── meteor-1.5.jar
│   │   └── meteor.py
│   ├── rouge
│   │   ├── __init__.py
│   │   └── rouge.py
│   ├── spice
│   │   ├── __init__.py
│   │   ├── cache
│   │   │   └── .gitkeep
│   │   ├── lib
│   │   │   ├── Meteor-1.5.jar
│   │   │   ├── SceneGraphParser-1.0.jar
│   │   │   ├── ejml-0.23.jar
│   │   │   ├── fst-2.47.jar
│   │   │   ├── guava-19.0.jar
│   │   │   ├── hamcrest-core-1.3.jar
│   │   │   ├── jackson-core-2.5.3.jar
│   │   │   ├── javassist-3.19.0-GA.jar
│   │   │   ├── json-simple-1.1.1.jar
│   │   │   ├── junit-4.12.jar
│   │   │   ├── lmdbjni-0.4.6.jar
│   │   │   ├── lmdbjni-linux64-0.4.6.jar
│   │   │   ├── lmdbjni-osx64-0.4.6.jar
│   │   │   ├── lmdbjni-win64-0.4.6.jar
│   │   │   ├── objenesis-2.4.jar
│   │   │   ├── slf4j-api-1.7.12.jar
│   │   │   └── slf4j-simple-1.7.21.jar
│   │   ├── spice-1.0.jar
│   │   ├── spice.py
│   │   └── tmp
│   │       └── .gitkeep
│   └── tokenizer
│       ├── __init__.py
│       ├── ptbtokenizer.py
│       └── stanford-corenlp-3.4.1.jar
├── github_ignore_material
│   ├── raw_data
│   │   └── .gitkeep
│   └── saves
│       └── .gitkeep
├── license.txt
├── losses
│   ├── __init__.py
│   ├── loss.py
│   └── reward.py
├── models
│   ├── End_ExpansionNet_v2.py
│   ├── ExpansionNet_v2.py
│   ├── __init__.py
│   ├── captioning_model.py
│   ├── ensemble_captioning_model.py
│   ├── layers.py
│   └── swin_transformer_mod.py
├── onnx4tensorrt
│   ├── End_ExpansionNet_v2_onnx_tensorrt.py
│   ├── __init__.py
│   ├── convert2onnx.py
│   ├── onnx2tensorrt.py
│   └── swin_transformer_onnx_tensorrt.py
├── optims
│   └── radam.py
├── requirements.txt
├── requirements_wTensorRT.txt
├── results_image.png
├── test.py
├── train.py
└── utils
    ├── args_utils.py
    ├── image_utils.py
    ├── language_utils.py
    ├── masking.py
    └── saving_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /github_ignore_material/raw_data/*
2 | !/github_ignore_material/raw_data/.gitkeep
3 | /github_ignore_material/saves/*
4 | !/github_ignore_material/saves/.gitkeep
5 | eval/spice/lib/stanford-corenlp-3.6.0.jar
6 | eval/spice/lib/stanford-corenlp-3.6.0-models.jar
7 | .idea/*
8 | **/__pycache__/*
9 | **/tmp/*
10 | !/eval/spice/tmp/.gitkeep
11 | **/cache/*
12 | !/eval/spice/cache/.gitkeep
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## ExpansionNet v2: Exploiting Multiple Sequence Lengths in Fast End to End Training for Image Captioning
2 |
3 | Implementation code for "[Exploiting Multiple Sequence Lengths in Fast End to End Training for Image Captioning](https://www.computer.org/csdl/proceedings-article/bigdata/2023/10386812/1TUPyooQsnu)" [ [BigData2023](https://www.computer.org/csdl/proceedings-article/bigdata/2023/10386812/1TUPyooQsnu) ]
4 | [ [Arxiv](https://arxiv.org/abs/2208.06551) ], previously titled "ExpansionNet v2: Block Static Expansion
5 | in fast end to end training for Image Captioning".
6 |
7 | ## Demo
8 |
9 | You can test the model on generic images (not included in COCO) by downloading
10 | the checkpoint [here](https://drive.google.com/drive/folders/1bBMH4-Fw1LcQZmSzkMCqpEl0piIP88Y3?usp=sharing)
11 | and launching the script `demo.py`:
12 | ```
13 | python demo.py \
14 | --load_path your_download_folder/rf_model.pth \
15 | --image_paths your_image_path/image_1 your_image_path/image_2 ...
16 |
17 | ```
18 | Some examples:
19 |
20 |
21 |
22 |
23 |
24 | The images are available in `demo_material`.
25 |
26 | ## Results
27 |
28 | [SacreEOS](https://github.com/jchenghu/sacreeos) Signature: `STANDARDwInit+Cider-D[n4,s6.0]+average[nspi5]+1.0.0`.
29 | Results are artifact-free.
30 |
31 | Online evaluation server results:
32 |
33 | 
34 | | Captions | B1   | B2   | B3   | B4   | Meteor | Rouge-L | CIDEr-D |
35 | |:--------:|:----:|:----:|:----:|:----:|:------:|:-------:|:-------:|
36 | | c40      | 96.9 | 92.6 | 85.0 | 75.3 | 40.1   | 76.4    | 140.8   |
37 | | c5       | 83.3 | 68.8 | 54.4 | 42.1 | 30.4   | 60.8    | 138.5   |
38 | 
65 | Results on the Karpathy test split:
66 | 
67 | | Model    | B@1  | B@4  | Meteor | Rouge-L | CIDEr-D | Spice |
68 | |:--------:|:----:|:----:|:------:|:-------:|:-------:|:-----:|
69 | | Ensemble | 83.5 | 42.7 | 30.6   | 61.1    | 143.7   | 24.7  |
70 | | Single   | 82.8 | 41.5 | 30.3   | 60.5    | 140.4   | 24.5  |
71 | 
96 | Prediction examples:
97 |
98 |
99 |
100 |
101 |
102 |
103 | ## ONNX & TensorRT
104 |
105 | The model now supports ONNX conversion and deployment with TensorRT.
106 | The graph can be generated using `onnx4tensorrt/convert2onnx.py`.
107 | Its execution mainly requires the `onnx` package; the `onnx_runtime` and `onnx_tensorrt` packages are
108 | optionally used for testing purposes (see the `convert2onnx.py` arguments).
109 | 
110 | Generic conversion commands:
111 | ```
112 | python onnx4tensorrt/convert2onnx.py --onnx_simplify true --load_model_path &> output_onnx.txt &
113 | python onnx4tensorrt/onnx2tensorrt.py &> output_tensorrt.txt &
114 | ```
115 | Currently, only FP32 is supported.
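As an optional sanity check (a minimal sketch, not part of the repository scripts), the exported graph can be inspected with the `onnx` and `onnxruntime` packages; the file name `model.onnx` below is a placeholder for whatever output path you chose during conversion:
```
# a minimal sketch: inspect the exported graph
# "model.onnx" is a placeholder for the file produced by convert2onnx.py
import onnx
import onnxruntime as ort

onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)   # structural validity check

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
print("inputs: ", [(i.name, i.shape, i.type) for i in sess.get_inputs()])
print("outputs:", [(o.name, o.shape, o.type) for o in sess.get_outputs()])
```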
116 |
117 |
118 | ## Training
119 |
120 | In this guide we cover all the training steps reported in the paper and
121 | provide the commands to reproduce our work.
122 |
123 | #### Requirements
124 |
125 | * python >= 3.7
126 | * numpy
127 | * Java 1.8.0
128 | * torch
129 | * torchvision
130 | * h5py
131 |
132 | Installing whatever versions of `torch`, `torchvision`, `h5py` and `Pillow` fit your machine should
133 | work in most cases.
134 | 
135 | An example requirements file can be found in `requirements.txt`; in case TensorRT is also needed,
136 | use `requirements_wTensorRT.txt`. However, these files represent only one working configuration, and the
137 | specific versions of each package might not be required.
138 |
139 | #### Data preparation
140 |
141 | MS-COCO 2014 images can be downloaded [here](https://cocodataset.org/#download),
142 | the respective captions are available in our online [drive](https://drive.google.com/drive/folders/1bBMH4-Fw1LcQZmSzkMCqpEl0piIP88Y3?usp=sharing)
143 | and the backbone can be found [here](https://github.com/microsoft/Swin-Transformer). All files, in particular
144 | the `dataset_coco.json` file and the backbone, should be moved into `github_ignore_material/raw_data/`, since the commands provided
145 | in the following steps assume these files are placed in that directory.
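As a quick optional check (a minimal sketch, assuming the default paths used in the commands below), you can verify that the expected files are in place:
```
# a minimal sketch: check that the files assumed by the following commands are in place
import os

raw_data = "./github_ignore_material/raw_data/"
for name in ["dataset_coco.json",                        # Karpathy-split captions
             "swin_large_patch4_window12_384_22k.pth",   # Swin backbone weights
             "MS_COCO_2014/"]:                           # COCO 2014 images folder
    path = os.path.join(raw_data, name)
    print(("OK      " if os.path.exists(path) else "MISSING ") + path)
```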
146 |
147 |
148 | #### Premises
149 |
150 | For the sake of transparency (at the cost of possibly being overly verbose)
151 | the complete commands are shown below, but only a few arguments require particular
152 | care to reproduce our work, while most of them are handled automatically.
153 | 
154 | Logs are stored in `output_file.txt`, which is continuously updated
155 | until the process is complete (on Linux the command `watch -n 1 tail -n 30 output_file.txt` may be handy). It is overwritten in each training phase, so,
156 | before moving to the next one, make sure to save it or make a copy if needed.
157 |
158 | Lastly, in some configurations the batch size may look different from the one
159 | reported in the paper when the argument `num_accum` is specified (default is 1). This is only
160 | a presentation detail: gradient accumulation is performed so that the effective batch size matches the paper
161 | while respecting the 40GB memory constraint of a single GPU (see the sketch below).
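For clarity, here is a minimal, self-contained sketch of what gradient accumulation does (illustrative only; the actual loop lives in `train.py`):
```
# a minimal sketch of gradient accumulation: effective batch size = batch_size * num_accum
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = torch.nn.MSELoss()

batch_size, num_accum = 4, 3          # behaves like a batch of 12 samples per update
data = [(torch.randn(batch_size, 8), torch.randn(batch_size, 1)) for _ in range(6)]

optimizer.zero_grad()
for it, (x, y) in enumerate(data):
    loss = criterion(model(x), y) / num_accum   # scale so accumulated gradients average correctly
    loss.backward()                             # gradients sum up over num_accum mini-batches
    if (it + 1) % num_accum == 0:
        optimizer.step()                        # one optimizer update every num_accum mini-batches
        optimizer.zero_grad()
```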
162 |
163 | #### 1. Cross-Entropy Training: Features generation
164 |
165 | First we generate the features for the first training step:
166 | ```
167 | cd ExpansionNet_v2_src
168 | python data_generator.py \
169 | --save_model_path ./github_ignore_material/raw_data/swin_large_patch4_window12_384_22k.pth \
170 | --output_path ./github_ignore_material/raw_data/features.hdf5 \
171 | --images_path ./github_ignore_material/raw_data/MS_COCO_2014/ \
172 | --captions_path ./github_ignore_material/raw_data/ &> output_file.txt &
173 | ```
174 | Although it is suggested not to do so, the `output_path` argument can be replaced with the desired destination (this would require
175 | changing the argument `features_path` in the next commands as well). Since it is a pretty big
176 | file (102GB), once the first training is completed it will be automatically overwritten by
177 | the remaining operations if the default name is unchanged.
178 | 
179 | TIP: if 100GB is too much for your disk, add the option `--dtype fp16`,
180 | which saves the arrays in FP16 and requires only about 50GB. It shouldn't affect the results much.
181 | By default, we keep FP32 for conformity with the experimental setup of the paper.
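The generated file is a plain HDF5 archive with one dataset per image, named `<coco_image_id>_features` (see `data_generator.py`); a minimal sketch to inspect it:
```
# a minimal sketch: peek into the generated features file
import h5py

with h5py.File("./github_ignore_material/raw_data/features.hdf5", "r") as f:
    for key in list(f.keys())[:3]:               # dataset names look like "<coco_image_id>_features"
        print(key, f[key].shape, f[key].dtype)   # float32 by default, float16 with --dtype fp16
```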
182 |
183 |
184 | #### 2. Cross-Entropy Training: Partial Training
185 |
186 | In this step the model is trained using the Cross Entropy loss and the features generated
187 | in the previous step:
188 | ```
189 | python train.py --N_enc 3 --N_dec 3 \
190 | --model_dim 512 --seed 775533 --optim_type radam --sched_type custom_warmup_anneal \
191 | --warmup 10000 --lr 2e-4 --anneal_coeff 0.8 --anneal_every_epoch 2 --enc_drop 0.3 \
192 | --dec_drop 0.3 --enc_input_drop 0.3 --dec_input_drop 0.3 --drop_other 0.3 \
193 | --batch_size 48 --num_accum 1 --num_gpus 1 --ddp_sync_port 11317 --eval_beam_sizes [3] \
194 | --save_path ./github_ignore_material/saves/ --save_every_minutes 60 --how_many_checkpoints 1 \
195 | --is_end_to_end False --features_path ./github_ignore_material/raw_data/features.hdf5 --partial_load False \
196 | --print_every_iter 11807 --eval_every_iter 999999 \
197 | --reinforce False --num_epochs 8 &> output_file.txt &
198 | ```
199 |
200 | #### 3. Cross-Entropy Training: End to End Training
201 |
202 | The following command trains the entire network in end-to-end mode. However,
203 | one argument needs to be changed according to the previous result: the
204 | checkpoint file name. Weights are stored in the directory `github_ignore_material/saves/`
205 | with a name of the form `checkpoint_ ... _xe.pth`; we will refer to it as `phase2_checkpoint` below and in
206 | the later steps:
207 | ```
208 | python train.py --N_enc 3 --N_dec 3 \
209 | --model_dim 512 --optim_type radam --seed 775533 --sched_type custom_warmup_anneal \
210 | --warmup 1 --lr 3e-5 --anneal_coeff 0.55 --anneal_every_epoch 1 --enc_drop 0.3 \
211 | --dec_drop 0.3 --enc_input_drop 0.3 --dec_input_drop 0.3 --drop_other 0.3 \
212 | --batch_size 16 --num_accum 3 --num_gpus 1 --ddp_sync_port 11317 --eval_beam_sizes [3] \
213 | --save_path ./github_ignore_material/saves/ --save_every_minutes 60 --how_many_checkpoints 1 \
214 | --is_end_to_end True --images_path ./github_ignore_material/raw_data/MS_COCO_2014/ --partial_load True \
215 | --backbone_save_path ./github_ignore_material/raw_data/swin_large_patch4_window12_384_22k.pth \
216 | --body_save_path ./github_ignore_material/saves/phase2_checkpoint \
217 | --print_every_iter 15000 --eval_every_iter 999999 \
218 | --reinforce False --num_epochs 2 &> output_file.txt &
219 | ```
220 | In case you are interested in the network's weights at the end of this stage,
221 | before moving to the self-critical learning, rename the checkpoint file from `checkpoint_ ... _xe.pth` to something
222 | else, such as `phase3_checkpoint` (make sure to change the prefix; see the sketch below), otherwise it will
223 | be overwritten during step 5.
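A minimal sketch of the renaming (the glob pattern reflects the `checkpoint_ ... _xe.pth` naming mentioned above; adjust the destination name as you prefer):
```
# a minimal sketch: rename the latest cross-entropy checkpoint so step 5 does not overwrite it
import glob
import os

saves_dir = "./github_ignore_material/saves/"
candidates = sorted(glob.glob(os.path.join(saves_dir, "checkpoint_*_xe.pth")), key=os.path.getmtime)
assert candidates, "no checkpoint_..._xe.pth file found in " + saves_dir
os.rename(candidates[-1], os.path.join(saves_dir, "phase3_checkpoint"))
```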
224 |
225 | #### 4. CIDEr optimization: Features generation
226 |
227 | This step generates the features for the reinforcement step:
228 | ```
229 | python data_generator.py \
230 | --save_model_path ./github_ignore_material/saves/phase3_checkpoint \
231 | --output_path ./github_ignore_material/raw_data/features.hdf5 \
232 | --images_path ./github_ignore_material/raw_data/MS_COCO_2014/ \
233 | --captions_path ./github_ignore_material/raw_data/ &> output_file.txt &
234 | ```
235 |
236 | #### 5. CIDEr optimization: Partial Training
237 |
238 | The following command performs the partial training using the self-critical learning:
239 | ```
240 | python train.py --N_enc 3 --N_dec 3 \
241 | --model_dim 512 --optim_type radam --seed 775533 --sched_type custom_warmup_anneal \
242 | --warmup 1 --lr 1e-4 --anneal_coeff 0.8 --anneal_every_epoch 1 --enc_drop 0.1 \
243 | --dec_drop 0.1 --enc_input_drop 0.1 --dec_input_drop 0.1 --drop_other 0.1 \
244 | --batch_size 24 --num_accum 2 --num_gpus 1 --ddp_sync_port 11317 --eval_beam_sizes [5] \
245 | --save_path ./github_ignore_material/saves/ --save_every_minutes 60 --how_many_checkpoints 1 \
246 | --is_end_to_end False --partial_load True \
247 | --features_path ./github_ignore_material/raw_data/features.hdf5 \
248 | --body_save_path ./github_ignore_material/saves/phase3_checkpoint.pth \
249 | --print_every_iter 4000 --eval_every_iter 99999 \
250 | --reinforce True --num_epochs 9 &> output_file.txt &
251 | ```
252 | We refer to the last checkpoint produced in this step as `phase5_checkpoint`.
253 | It should already achieve around 139.5 CIDEr-D on both the validation and test sets; however,
254 | it can still be improved by a small margin with the following optional step.
255 |
256 |
257 | #### 6. CIDEr optimization: End to End Training
258 |
259 | This last step trains the model again in an end-to-end fashion. It is optional, since it only slightly improves performance:
260 | ```
261 | python train.py --N_enc 3 --N_dec 3 \
262 | --model_dim 512 --optim_type radam --seed 775533 --sched_type custom_warmup_anneal \
263 | --warmup 1 --anneal_coeff 1.0 --lr 2e-6 --enc_drop 0.1 \
264 | --dec_drop 0.1 --enc_input_drop 0.1 --dec_input_drop 0.1 --drop_other 0.1 \
265 | --batch_size 24 --num_accum 2 --num_gpus 1 --ddp_sync_port 11317 --eval_beam_sizes [5] \
266 | --save_path ./github_ignore_material/saves/ --save_every_minutes 60 --how_many_checkpoints 1 \
267 | --is_end_to_end True --images_path ./github_ignore_material/raw_data/MS_COCO_2014/ --partial_load True \
268 | --backbone_save_path ./github_ignore_material/raw_data/phase3_checkpoint \
269 | --body_save_path ./github_ignore_material/saves/phase5_checkpoint \
270 | --print_every_iter 15000 --eval_every_iter 999999 \
271 | --reinforce True --num_epochs 1 &> output_file.txt &
272 | ```
273 |
274 | ## Evaluation
275 |
276 | In this section we provide the evaluation scripts. We refer to the
277 | last checkpoint as `phase6_checkpoint`. In case the previous training
278 | procedures have been skipped,
279 | the weights of one of the ensemble's models can be found [here](https://drive.google.com/drive/folders/1bBMH4-Fw1LcQZmSzkMCqpEl0piIP88Y3?usp=sharing).
280 | ```
281 | python test.py --N_enc 3 --N_dec 3 --model_dim 512 \
282 | --num_gpus 1 --eval_beam_sizes [5] --is_end_to_end True \
283 | --eval_parallel_batch_size 4 \
284 | --images_path ./github_ignore_material/raw_data/ \
285 | --save_model_path ./github_ignore_material/saves/phase6_checkpoint
286 | ```
287 | The option `is_end_to_end` can be toggled according to the model's type.
288 | It might be necessary to grant execute permission to the file `./eval/get_stanford_models.sh` (e.g. `chmod a+x -R ./eval/` on Linux).
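Besides `test.py`, the scorers bundled in `eval/` can also be used directly; a minimal sketch on made-up captions (run it from the repository root so the `eval` package is importable):
```
# a minimal sketch: using the bundled scorers directly on toy data
from eval.bleu.bleu import Bleu
from eval.cider.cider import Cider

# keys are image ids; res holds one hypothesis per image, gts the reference captions
res = {0: ["a dog runs on the grass"],
       1: ["a man riding a bike on the road"]}
gts = {0: ["a dog is running on the grass", "a brown dog runs across a lawn"],
       1: ["a man rides a bicycle down the road", "a cyclist on a street"]}

bleu_score, _ = Bleu(n=4).compute_score(gts, res)
cider_score, _ = Cider().compute_score(gts, res)
print("BLEU-1..4:", bleu_score)
print("CIDEr:", cider_score)
```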
289 |
290 |
291 | ## Citation
292 |
293 | If you find this repository useful, please consider citing (no obligation):
294 |
295 | ```
296 | @inproceedings{hu2023exploiting,
297 | title={Exploiting Multiple Sequence Lengths in Fast End to End Training for Image Captioning},
298 | author={Hu, Jia Cheng and Cavicchioli, Roberto and Capotondi, Alessandro},
299 | booktitle={2023 IEEE International Conference on Big Data (BigData)},
300 | pages={2173--2182},
301 | year={2023},
302 | organization={IEEE Computer Society}
303 | }
304 | ```
305 |
306 | ## Acknowledgements
307 |
308 | We thank the PyTorch team and the following repositories:
309 | * https://github.com/microsoft/Swin-Transformer
310 | * https://github.com/ruotianluo/ImageCaptioning.pytorch
311 | * https://github.com/tylin/coco-caption
312 |
313 | Special thanks to the work of [Yiyu Wang et al.](https://arxiv.org/abs/2203.15350).
314 |
315 | We thank the user [@shahizat](https://github.com/shahizat) for suggesting the ONNX and TensorRT conversions.
316 | We also thank the GitHub users in the Issues section who provided valuable feedback and suggestions,
317 | and even found some very insidious bugs.
318 |
319 |
320 |
321 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/data/__init__.py
--------------------------------------------------------------------------------
/data/coco_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | from time import time
3 | from utils import language_utils
4 |
5 | import functools
6 | print = functools.partial(print, flush=True)
7 |
8 |
9 | class CocoDatasetKarpathy:
10 |
11 | TrainSet_ID = 1
12 | ValidationSet_ID = 2
13 | TestSet_ID = 3
14 |
15 | def __init__(self,
16 | images_path,
17 | coco_annotations_path,
18 | precalc_features_hdf5_filepath,
19 | preproc_images_hdf5_filepath=None,
20 | limited_num_train_images=None,
21 | limited_num_val_images=None,
22 | limited_num_test_images=None,
23 | dict_min_occurrences=5,
24 | verbose=True
25 | ):
26 | super(CocoDatasetKarpathy, self).__init__()
27 |
28 | self.use_images_instead_of_features = False
29 | if precalc_features_hdf5_filepath is None or precalc_features_hdf5_filepath == 'None' or \
30 | precalc_features_hdf5_filepath == 'none' or precalc_features_hdf5_filepath == '':
31 | self.use_images_instead_of_features = True
32 | print("Warning: since no hdf5 path is provided using images instead of pre-calculated features.")
33 | print("Features path: " + str(precalc_features_hdf5_filepath))
34 |
35 | self.preproc_images_hdf5_filepath = None
36 | if preproc_images_hdf5_filepath is not None:
37 | print("Preprocessed hdf5 file path not None: " + str(preproc_images_hdf5_filepath))
38 | print("Using preprocessed hdf5 file instead.")
39 | self.preproc_images_hdf5_filepath = preproc_images_hdf5_filepath
40 |
41 | else:
42 | self.precalc_features_hdf5_filepath = precalc_features_hdf5_filepath
43 | print("Features path: " + str(self.precalc_features_hdf5_filepath))
44 | print("Features path provided, images are provided in form of features.")
45 |
46 | if images_path is None:
47 | self.images_path = ""
48 | else:
49 | self.images_path = images_path
50 |
51 | self.karpathy_train_dict = dict()
52 | self.karpathy_val_dict = dict()
53 | self.karpathy_test_dict = dict()
54 |
55 | with open(coco_annotations_path, 'r') as f:
56 | json_file = json.load(f)['images']
57 |
58 | if verbose:
59 | print("Initializing dataset... ", end=" ")
60 | for json_item in json_file:
61 | new_item = dict()
62 |
63 | new_item['img_path'] = self.images_path + json_item['filepath'] + '/img/' + json_item['filename']
64 |
65 | new_item_captions = [item['raw'] for item in json_item['sentences']]
66 | new_item['img_id'] = json_item['cocoid']
67 | new_item['captions'] = new_item_captions
68 |
69 | if json_item['split'] == 'train' or json_item['split'] == 'restval':
70 | self.karpathy_train_dict[json_item['cocoid']] = new_item
71 | elif json_item['split'] == 'test':
72 | self.karpathy_test_dict[json_item['cocoid']] = new_item
73 | elif json_item['split'] == 'val':
74 | self.karpathy_val_dict[json_item['cocoid']] = new_item
75 |
76 | self.karpathy_train_list = []
77 | self.karpathy_val_list = []
78 | self.karpathy_test_list = []
79 | for key in self.karpathy_train_dict.keys():
80 | self.karpathy_train_list.append(self.karpathy_train_dict[key])
81 | for key in self.karpathy_val_dict.keys():
82 | self.karpathy_val_list.append(self.karpathy_val_dict[key])
83 | for key in self.karpathy_test_dict.keys():
84 | self.karpathy_test_list.append(self.karpathy_test_dict[key])
85 |
86 | self.train_num_images = len(self.karpathy_train_list)
87 | self.val_num_images = len(self.karpathy_val_list)
88 | self.test_num_images = len(self.karpathy_test_list)
89 |
90 | if limited_num_train_images is not None:
91 | self.karpathy_train_list = self.karpathy_train_list[:limited_num_train_images]
92 | self.train_num_images = limited_num_train_images
93 | if limited_num_val_images is not None:
94 | self.karpathy_val_list = self.karpathy_val_list[:limited_num_val_images]
95 | self.val_num_images = limited_num_val_images
96 | if limited_num_test_images is not None:
97 | self.karpathy_test_list = self.karpathy_test_list[:limited_num_test_images]
98 | self.test_num_images = limited_num_test_images
99 |
100 | if verbose:
101 | print("Num train images: " + str(self.train_num_images))
102 | print("Num val images: " + str(self.val_num_images))
103 | print("Num test images: " + str(self.test_num_images))
104 |
105 | tokenized_captions_list = []
106 | for i in range(self.train_num_images):
107 | for caption in self.karpathy_train_list[i]['captions']:
108 | tmp = language_utils.lowercase_and_clean_trailing_spaces([caption])
109 | tmp = language_utils.add_space_between_non_alphanumeric_symbols(tmp)
110 | tmp = language_utils.remove_punctuations(tmp)
111 | tokenized_caption = ['SOS'] + language_utils.tokenize(tmp)[0] + ['EOS']
112 | tokenized_captions_list.append(tokenized_caption)
113 |
114 | counter_dict = dict()
115 | for i in range(len(tokenized_captions_list)):
116 | for word in tokenized_captions_list[i]:
117 | if word not in counter_dict:
118 | counter_dict[word] = 1
119 | else:
120 | counter_dict[word] += 1
121 |
122 | less_than_min_occurrences_set = set()
123 | for k, v in counter_dict.items():
124 | if v < dict_min_occurrences:
125 | less_than_min_occurrences_set.add(k)
126 | if verbose:
127 | print("tot tokens " + str(len(counter_dict)) +
128 | " less than " + str(dict_min_occurrences) + ": " + str(len(less_than_min_occurrences_set)) +
129 | " remaining: " + str(len(counter_dict) - len(less_than_min_occurrences_set)))
130 |
131 | self.num_caption_vocab = 4
132 | self.max_seq_len = 0
133 | discovered_words = ['PAD', 'SOS', 'EOS', 'UNK']
134 | for i in range(len(tokenized_captions_list)):
135 | caption = tokenized_captions_list[i]
136 | if len(caption) > self.max_seq_len:
137 | self.max_seq_len = len(caption)
138 | for word in caption:
139 | if (word not in discovered_words) and (not word in less_than_min_occurrences_set):
140 | discovered_words.append(word)
141 | self.num_caption_vocab += 1
142 |
143 | discovered_words.sort()
144 | self.caption_word2idx_dict = dict()
145 | self.caption_idx2word_list = []
146 | for i in range(len(discovered_words)):
147 | self.caption_word2idx_dict[discovered_words[i]] = i
148 | self.caption_idx2word_list.append(discovered_words[i])
149 | if verbose:
150 | print("There are " + str(self.num_caption_vocab) + " vocabs in dict")
151 |
152 | def get_image_path(self, img_idx, dataset_split):
153 |
154 | if dataset_split == CocoDatasetKarpathy.TestSet_ID:
155 | img_path = self.karpathy_test_list[img_idx]['img_path']
156 | img_id = self.karpathy_test_list[img_idx]['img_id']
157 | elif dataset_split == CocoDatasetKarpathy.ValidationSet_ID:
158 | img_path = self.karpathy_val_list[img_idx]['img_path']
159 | img_id = self.karpathy_val_list[img_idx]['img_id']
160 | else:
161 | img_path = self.karpathy_train_list[img_idx]['img_path']
162 | img_id = self.karpathy_train_list[img_idx]['img_id']
163 |
164 | return img_path, img_id
165 |
166 | def get_all_images_captions(self, dataset_split):
167 | all_image_references = []
168 |
169 | if dataset_split == CocoDatasetKarpathy.TestSet_ID:
170 | dataset = self.karpathy_test_list
171 | elif dataset_split == CocoDatasetKarpathy.ValidationSet_ID:
172 | dataset = self.karpathy_val_list
173 | else:
174 | dataset = self.karpathy_train_list
175 |
176 | for img_idx in range(len(dataset)):
177 | all_image_references.append(dataset[img_idx]['captions'])
178 | return all_image_references
179 |
180 | def get_eos_token_idx(self):
181 | return self.caption_word2idx_dict['EOS']
182 |
183 | def get_sos_token_idx(self):
184 | return self.caption_word2idx_dict['SOS']
185 |
186 | def get_pad_token_idx(self):
187 | return self.caption_word2idx_dict['PAD']
188 |
189 | def get_unk_token_idx(self):
190 | return self.caption_word2idx_dict['UNK']
191 |
192 | def get_eos_token_str(self):
193 | return 'EOS'
194 |
195 | def get_sos_token_str(self):
196 | return 'SOS'
197 |
198 | def get_pad_token_str(self):
199 | return 'PAD'
200 |
201 | def get_unk_token_str(self):
202 | return 'UNK'
203 |
--------------------------------------------------------------------------------
/data/transparent_data_loader.py:
--------------------------------------------------------------------------------
1 | """
2 | This dataloader inherently represents also a training session that can be saved and loaded
3 | """
4 |
5 |
6 | class TransparentDataLoader:
7 | def __init__(self):
8 | super(TransparentDataLoader, self).__init__()
9 |
10 | # initialize the training on a specific epoch
11 | def init_epoch(self, epoch, batch_size):
12 | raise NotImplementedError
13 |
14 | def get_next_batch(self):
15 | raise NotImplementedError
16 |
17 | # methods for progress saving and loading progress of the data loader
18 | def set_epoch_it(self, epoch):
19 | raise NotImplementedError
20 |
21 | def get_epoch_it(self):
22 | raise NotImplementedError
23 |
24 | def get_num_epoch(self):
25 | raise NotImplementedError
26 |
27 | def get_num_batches(self):
28 | raise NotImplementedError
29 |
30 | def set_batch_it(self, batch_it):
31 | raise NotImplementedError
32 |
33 | def get_batch_it(self):
34 | raise NotImplementedError
35 |
36 | def get_batch_size(self):
37 | raise NotImplementedError
38 |
39 | def save_state(self):
40 | raise NotImplementedError
41 |
42 | def load_state(self, state_dict):
43 | raise NotImplementedError
44 |
--------------------------------------------------------------------------------
/data_generator.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | from PIL import Image as PIL_Image
4 | import torchvision
5 | import torch
6 | import argparse
7 | from argparse import Namespace
8 | from torch.nn.parameter import Parameter
9 | from time import time
10 |
11 | from data.coco_dataset import CocoDatasetKarpathy
12 |
13 | torch.autograd.set_detect_anomaly(False)
14 | torch.set_num_threads(1)
15 | torch.set_num_interop_threads(1)
16 | import functools
17 | print = functools.partial(print, flush=True)
18 | DEFAULT_RANK = 0
19 |
20 |
21 | def convert_time_as_hhmmss(ticks):
22 | return str(int(ticks / 60)) + " m " + \
23 | str(int(ticks) % 60) + " s"
24 |
25 |
26 | def generate_data(path_args):
27 |
28 | coco_dataset = CocoDatasetKarpathy(images_path=path_args.images_path,
29 | coco_annotations_path=path_args.captions_path + "dataset_coco.json",
30 | preproc_images_hdf5_filepath=None,
31 | precalc_features_hdf5_filepath=None,
32 | limited_num_train_images=None,
33 | limited_num_val_images=5000)
34 |
35 | from models.swin_transformer_mod import SwinTransformer
36 | model = SwinTransformer(
37 | img_size=384, patch_size=4, in_chans=3,
38 | embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48],
39 | window_size=12, mlp_ratio=4., qkv_bias=True, qk_scale=None,
40 | drop_rate=0.0, attn_drop_rate=0.0,
41 | drop_path_rate=0.0,
42 | norm_layer=torch.nn.LayerNorm, ape=False, patch_norm=True,
43 | use_checkpoint=False)
44 |
45 | def load_backbone_only_from_save(model, state_dict, prefix=None):
46 | own_state = model.state_dict()
47 | for name, param in state_dict.items():
48 | if prefix is not None and name.startswith(prefix):
49 | name = name[len(prefix):]
50 | if name not in own_state:
51 | print("Not found: " + str(name))
52 | continue
53 | if isinstance(param, Parameter):
54 | param = param.data
55 | own_state[name].copy_(param)
56 | print("Found: " + str(name))
57 |
58 | save_model_path = path_args.save_model_path
59 | map_location = {'cuda:%d' % DEFAULT_RANK: 'cuda:%d' % DEFAULT_RANK}
60 | checkpoint = torch.load(save_model_path, map_location=map_location)
61 | if 'model_state_dict' in checkpoint.keys():
62 | print("Custom save point found")
63 | load_backbone_only_from_save(model, checkpoint['model_state_dict'], prefix='swin_transf.')
64 | else:
65 | print("Custom save point not found")
66 | load_backbone_only_from_save(model, checkpoint['model'], prefix=None)
67 | print("Loading phase ended")
68 |
69 | model = model.to(DEFAULT_RANK)
70 |
71 | test_preprocess_layers_1 = [torchvision.transforms.Resize((384, 384))]
72 | test_preprocess_layers_2 = [torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
73 | test_preprocess_1 = torchvision.transforms.Compose(test_preprocess_layers_1)
74 | test_preprocess_2 = torchvision.transforms.Compose(test_preprocess_layers_2)
75 |
76 | model.eval()
77 | with torch.no_grad():
78 | """
79 | TIP: if you don't have 100GB of memory is too much for features allocation,
80 | try saving arrays into FP16. It shouldn't change affect much the result.
81 | Replace each line into:
82 | hdf5_file.create_dataset(str(img_id) + '_features', data=np.array(output.cpu()))
83 | into:
84 | hdf5_file.create_dataset(str(img_id) + '_features', data=np.array(output.cpu()), dtype=np.float16)
85 | we kept it FP32 for coherence with the experimental setup of the paper.
86 | """
87 | hdf5_file = h5py.File(path_args.output_path, 'w')
88 |
89 | def apply_model(model, file_path):
90 | pil_image = PIL_Image.open(file_path)
91 | if pil_image.mode != 'RGB':
92 | pil_image = pil_image.convert('RGB')  # convert instead of creating a blank image
93 | preprocess_pil_image = test_preprocess_1(pil_image)
94 | tens_image = torchvision.transforms.ToTensor()(preprocess_pil_image)
95 | tens_image = test_preprocess_2(tens_image).to(DEFAULT_RANK)
96 | output = model(tens_image.unsqueeze(0))
97 | return output.squeeze(0)
98 |
99 | for i in range(coco_dataset.train_num_images):
100 | img_path, img_id = coco_dataset.get_image_path(coco_dataset.train_num_images - i - 1,
101 | CocoDatasetKarpathy.TrainSet_ID)
102 | output = apply_model(model, img_path)
103 | if path_args.dtype == 'fp16':
104 | hdf5_file.create_dataset(str(img_id) + '_features', data=np.array(output.cpu()), dtype=np.float16)
105 | else:
106 | hdf5_file.create_dataset(str(img_id) + '_features', data=np.array(output.cpu()))
107 | if (i+1) % 5000 == 0 or (i+1) == coco_dataset.train_num_images:
108 | print("Train " + str(i+1) + " / " + str(coco_dataset.train_num_images) + " completed")
109 |
110 | for i in range(coco_dataset.test_num_images):
111 | img_path, img_id = coco_dataset.get_image_path(i, CocoDatasetKarpathy.TestSet_ID)
112 | output = apply_model(model, img_path)
113 | if path_args.dtype == 'fp16':
114 | hdf5_file.create_dataset(str(img_id) + '_features', data=np.array(output.cpu()), dtype=np.float16)
115 | else:
116 | hdf5_file.create_dataset(str(img_id) + '_features', data=np.array(output.cpu()))
117 | if (i+1) % 2500 == 0 or (i+1) == coco_dataset.test_num_images:
118 | print("Test " + str(i+1) + " / " + str(coco_dataset.test_num_images) + " completed")
119 |
120 | for i in range(coco_dataset.val_num_images):
121 | img_path, img_id = coco_dataset.get_image_path(i, CocoDatasetKarpathy.ValidationSet_ID)
122 | output = apply_model(model, img_path)
123 | if path_args.dtype == 'fp16':
124 | hdf5_file.create_dataset(str(img_id) + '_features', data=np.array(output.cpu()), dtype=np.float16)
125 | else:
126 | hdf5_file.create_dataset(str(img_id) + '_features', data=np.array(output.cpu()))
127 | if (i+1) % 2500 == 0 or (i+1) == coco_dataset.val_num_images:
128 | print("Val " + str(i+1) + " / " + str(coco_dataset.val_num_images) + " completed")
129 |
130 | print("[GPU: " + str(DEFAULT_RANK) + " ] Closing...")
131 |
132 |
133 | if __name__ == "__main__":
134 |
135 | parser = argparse.ArgumentParser(description='Image Captioning')
136 | parser.add_argument('--save_model_path', type=str, default='./github_ignore_material/saves/')
137 | parser.add_argument('--output_path', type=str, default='./github_ignore_material/raw_data/precalc_features.hdf5')
138 | parser.add_argument('--images_path', type=str, default='/tmp/images/')
139 | parser.add_argument('--captions_path', type=str, default='./github_ignore_material/raw_data/')
140 | parser.add_argument('--dtype', type=str, default='fp32', help='Decide data type of saved features')
141 |
142 | args = parser.parse_args()
143 |
144 | path_args = Namespace(save_model_path=args.save_model_path,
145 | output_path=args.output_path,
146 | images_path=args.images_path,
147 | captions_path=args.captions_path,
148 | dtype=args.dtype)
149 | generate_data(path_args=path_args)
150 |
151 |
152 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | """
2 | Small script for testing on few generic images given the model weights.
3 | In order to minimize the requirements, it runs only on CPU and images are
4 | processed one by one.
5 | """
6 |
7 | import torch
8 | import argparse
9 | import pickle
10 | from argparse import Namespace
11 |
12 | from models.End_ExpansionNet_v2 import End_ExpansionNet_v2
13 | from utils.image_utils import preprocess_image
14 | from utils.language_utils import tokens2description
15 |
16 |
17 | if __name__ == "__main__":
18 |
19 | parser = argparse.ArgumentParser(description='Demo')
20 | parser.add_argument('--model_dim', type=int, default=512)
21 | parser.add_argument('--N_enc', type=int, default=3)
22 | parser.add_argument('--N_dec', type=int, default=3)
23 | parser.add_argument('--max_seq_len', type=int, default=74)
24 | parser.add_argument('--device', type=str, default='cpu')
25 | parser.add_argument('--load_path', type=str, default='./rf_model.pth')
26 | parser.add_argument('--image_paths', type=str,
27 | default=['./demo_material/tatin.jpg',
28 | './demo_material/micheal.jpg',
29 | './demo_material/napoleon.jpg',
30 | './demo_material/cat_girl.jpg'],
31 | nargs='+')
32 | parser.add_argument('--beam_size', type=int, default=5)
33 |
34 | args = parser.parse_args()
35 |
36 | drop_args = Namespace(enc=0.0,
37 | dec=0.0,
38 | enc_input=0.0,
39 | dec_input=0.0,
40 | other=0.0)
41 | model_args = Namespace(model_dim=args.model_dim,
42 | N_enc=args.N_enc,
43 | N_dec=args.N_dec,
44 | drop_args=drop_args)
45 |
46 | with open('./demo_material/demo_coco_tokens.pickle', 'rb') as f:
47 | coco_tokens = pickle.load(f)
48 | sos_idx = coco_tokens['word2idx_dict'][coco_tokens['sos_str']]
49 | eos_idx = coco_tokens['word2idx_dict'][coco_tokens['eos_str']]
50 |
51 | print("Dictionary loaded ...")
52 |
53 | img_size = 384
54 | model = End_ExpansionNet_v2(swin_img_size=img_size, swin_patch_size=4, swin_in_chans=3,
55 | swin_embed_dim=192, swin_depths=[2, 2, 18, 2], swin_num_heads=[6, 12, 24, 48],
56 | swin_window_size=12, swin_mlp_ratio=4., swin_qkv_bias=True, swin_qk_scale=None,
57 | swin_drop_rate=0.0, swin_attn_drop_rate=0.0, swin_drop_path_rate=0.0,
58 | swin_norm_layer=torch.nn.LayerNorm, swin_ape=False, swin_patch_norm=True,
59 | swin_use_checkpoint=False,
60 | final_swin_dim=1536,
61 |
62 | d_model=model_args.model_dim, N_enc=model_args.N_enc,
63 | N_dec=model_args.N_dec, num_heads=8, ff=2048,
64 | num_exp_enc_list=[32, 64, 128, 256, 512],
65 | num_exp_dec=16,
66 | output_word2idx=coco_tokens['word2idx_dict'],
67 | output_idx2word=coco_tokens['idx2word_list'],
68 | max_seq_len=args.max_seq_len, drop_args=model_args.drop_args,
69 | rank=args.device)
70 | checkpoint = torch.load(args.load_path, map_location=torch.device(args.device))
71 | model.load_state_dict(checkpoint['model_state_dict'])
72 | print("Model loaded ...")
73 |
74 | input_images = []
75 | for path in args.image_paths:
76 | input_images.append(preprocess_image(path, img_size))
77 |
78 | print("Generating captions ...\n")
79 | for i in range(len(input_images)):
80 | path = args.image_paths[i]
81 | image = input_images[i]
82 | beam_search_kwargs = {'beam_size': args.beam_size,
83 | 'beam_max_seq_len': args.max_seq_len,
84 | 'sample_or_max': 'max',
85 | 'how_many_outputs': 1,
86 | 'sos_idx': sos_idx,
87 | 'eos_idx': eos_idx}
88 | with torch.no_grad():
89 | pred, _ = model(enc_x=image,
90 | enc_x_num_pads=[0],
91 | mode='beam_search', **beam_search_kwargs)
92 | pred = tokens2description(pred[0][0], coco_tokens['idx2word_list'], sos_idx, eos_idx)
93 | print(path + ' \n\tDescription: ' + pred + '\n')
94 |
95 | print("Closed.")
96 |
--------------------------------------------------------------------------------
/demo_material/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/demo_material/__init__.py
--------------------------------------------------------------------------------
/demo_material/cat_girl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/demo_material/cat_girl.jpg
--------------------------------------------------------------------------------
/demo_material/demo_coco_tokens.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/demo_material/demo_coco_tokens.pickle
--------------------------------------------------------------------------------
/demo_material/micheal.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/demo_material/micheal.jpg
--------------------------------------------------------------------------------
/demo_material/napoleon.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/demo_material/napoleon.jpg
--------------------------------------------------------------------------------
/demo_material/tatin.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/demo_material/tatin.jpg
--------------------------------------------------------------------------------
/demo_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/demo_results.png
--------------------------------------------------------------------------------
/eval/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/eval/bleu/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/eval/bleu/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/eval/bleu/bleu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : bleu.py
4 | #
5 | # Description : Wrapper for BLEU scorer.
6 | #
7 | # Creation Date : 06-01-2015
8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | from eval.bleu.bleu_scorer import BleuScorer
12 |
13 | class Bleu:
14 | def __init__(self, n=4):
15 | # default: compute BLEU score up to 4-grams
16 | self._n = n
17 | self._hypo_for_image = {}
18 | self.ref_for_image = {}
19 |
20 | def compute_score(self, gts, res):
21 |
22 | assert(gts.keys() == res.keys())
23 | imgIds = gts.keys()
24 |
25 | bleu_scorer = BleuScorer(n=self._n)
26 | for id in imgIds:
27 | hypo = res[id]
28 | ref = gts[id]
29 |
30 | # Sanity check.
31 | assert(type(hypo) is list)
32 | assert(len(hypo) == 1)
33 | assert(type(ref) is list)
34 | assert(len(ref) >= 1)
35 |
36 | bleu_scorer += (hypo[0], ref)
37 |
38 | #score, scores = bleu_scorer.compute_score(option='shortest')
39 | score, scores = bleu_scorer.compute_score(option='closest', verbose=1)
40 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1)
41 |
42 | # return (bleu, bleu_info)
43 | return score, scores
44 |
45 | def method(self):
46 | return "Bleu"
47 |
--------------------------------------------------------------------------------
/eval/bleu/bleu_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # bleu_scorer.py
4 | # David Chiang
5 |
6 | # Copyright (c) 2004-2006 University of Maryland. All rights
7 | # reserved. Do not redistribute without permission from the
8 | # author. Not for commercial use.
9 |
10 | # Modified by:
11 | # Hao Fang
12 | # Tsung-Yi Lin
13 |
14 | '''Provides:
15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
17 | '''
18 |
19 | import copy
20 | import sys, math, re
21 | from collections import defaultdict
22 |
23 | def precook(s, n=4, out=False):
24 | """Takes a string as input and returns an object that can be given to
25 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
26 | can take string arguments as well."""
27 | words = s.split()
28 | counts = defaultdict(int)
29 | for k in range(1,n+1):
30 | for i in range(len(words)-k+1):
31 | ngram = tuple(words[i:i+k])
32 | counts[ngram] += 1
33 | return (len(words), counts)
34 |
35 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
36 | '''Takes a list of reference sentences for a single segment
37 | and returns an object that encapsulates everything that BLEU
38 | needs to know about them.'''
39 |
40 | reflen = []
41 | maxcounts = {}
42 | for ref in refs:
43 | rl, counts = precook(ref, n)
44 | reflen.append(rl)
45 | for (ngram,count) in counts.items():
46 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
47 |
48 | # Calculate effective reference sentence length.
49 | if eff == "shortest":
50 | reflen = min(reflen)
51 | elif eff == "average":
52 | reflen = float(sum(reflen))/len(reflen)
53 |
54 | ## lhuang: N.B.: leave reflen computation to the very end!!
55 |
56 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
57 |
58 | return (reflen, maxcounts)
59 |
60 | def cook_test(test, reflen, refmaxcounts, eff=None, n=4):
61 | '''Takes a test sentence and returns an object that
62 | encapsulates everything that BLEU needs to know about it.'''
63 |
64 | testlen, counts = precook(test, n, True)
65 |
66 | result = {}
67 |
68 | # Calculate effective reference sentence length.
69 |
70 | if eff == "closest":
71 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
72 | else: ## i.e., "average" or "shortest" or None
73 | result["reflen"] = reflen
74 |
75 | result["testlen"] = testlen
76 |
77 | result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)]
78 |
79 | result['correct'] = [0]*n
80 | for (ngram, count) in counts.items():
81 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)
82 |
83 | return result
84 |
85 | class BleuScorer(object):
86 | """Bleu scorer.
87 | """
88 |
89 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
90 | # special_reflen is used in oracle (proportional effective ref len for a node).
91 |
92 | def copy(self):
93 | ''' copy the refs.'''
94 | new = BleuScorer(n=self.n)
95 | new.ctest = copy.copy(self.ctest)
96 | new.crefs = copy.copy(self.crefs)
97 | new._score = None
98 | return new
99 |
100 | def __init__(self, test=None, refs=None, n=4, special_reflen=None):
101 | ''' singular instance '''
102 |
103 | self.n = n
104 | self.crefs = []
105 | self.ctest = []
106 | self.cook_append(test, refs)
107 | self.special_reflen = special_reflen
108 |
109 | def cook_append(self, test, refs):
110 | '''called by constructor and __iadd__ to avoid creating new instances.'''
111 |
112 | if refs is not None:
113 | self.crefs.append(cook_refs(refs))
114 | if test is not None:
115 | reflen, refmaxcounts = self.crefs[-1] # python2.7 to python3.5 ADAPTATION
116 | cooked_test = cook_test(test, reflen, refmaxcounts)
117 | self.ctest.append(cooked_test) ## N.B.: -1
118 | else:
119 | self.ctest.append(None) # lens of crefs and ctest have to match
120 |
121 | self._score = None ## need to recompute
122 |
123 | def ratio(self, option=None):
124 | self.compute_score(option=option)
125 | return self._ratio
126 |
127 | def score_ratio(self, option=None):
128 | '''return (bleu, len_ratio) pair'''
129 | return (self.fscore(option=option), self.ratio(option=option))
130 |
131 | def score_ratio_str(self, option=None):
132 | return "%.4f (%.2f)" % self.score_ratio(option)
133 |
134 | def reflen(self, option=None):
135 | self.compute_score(option=option)
136 | return self._reflen
137 |
138 | def testlen(self, option=None):
139 | self.compute_score(option=option)
140 | return self._testlen
141 |
142 | def retest(self, new_test):
143 | if type(new_test) is str:
144 | new_test = [new_test]
145 | assert len(new_test) == len(self.crefs), new_test
146 | self.ctest = []
147 | for t, rs in zip(new_test, self.crefs):
148 | self.ctest.append(cook_test(t, rs))
149 | self._score = None
150 |
151 | return self
152 |
153 | def rescore(self, new_test):
154 | ''' replace test(s) with new test(s), and returns the new score.'''
155 |
156 | return self.retest(new_test).compute_score()
157 |
158 | def size(self):
159 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
160 | return len(self.crefs)
161 |
162 | def __iadd__(self, other):
163 | '''add an instance (e.g., from another sentence).'''
164 |
165 | if type(other) is tuple:
166 | ## avoid creating new BleuScorer instances
167 | self.cook_append(other[0], other[1])
168 | else:
169 | assert self.compatible(other), "incompatible BLEUs."
170 | self.ctest.extend(other.ctest)
171 | self.crefs.extend(other.crefs)
172 | self._score = None ## need to recompute
173 |
174 | return self
175 |
176 | def compatible(self, other):
177 | return isinstance(other, BleuScorer) and self.n == other.n
178 |
179 | def single_reflen(self, option="average"):
180 | return self._single_reflen(self.crefs[0][0], option)
181 |
182 | def _single_reflen(self, reflens, option=None, testlen=None):
183 |
184 | if option == "shortest":
185 | reflen = min(reflens)
186 | elif option == "average":
187 | reflen = float(sum(reflens))/len(reflens)
188 | elif option == "closest":
189 | reflen = min((abs(l-testlen), l) for l in reflens)[1]
190 | else:
191 | assert False, "unsupported reflen option %s" % option
192 |
193 | return reflen
194 |
195 | def recompute_score(self, option=None, verbose=0):
196 | self._score = None
197 | return self.compute_score(option, verbose)
198 |
199 | def compute_score(self, option=None, verbose=0):
200 | n = self.n
201 | small = 1e-9
202 | tiny = 1e-15 ## so that if guess is 0 still return 0
203 | bleu_list = [[] for _ in range(n)]
204 |
205 | if self._score is not None:
206 | return self._score
207 |
208 | if option is None:
209 | option = "average" if len(self.crefs) == 1 else "closest"
210 |
211 | self._testlen = 0
212 | self._reflen = 0
213 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}
214 |
215 | # for each sentence
216 | for comps in self.ctest:
217 | testlen = comps['testlen']
218 | self._testlen += testlen
219 |
220 | if self.special_reflen is None: ## need computation
221 | reflen = self._single_reflen(comps['reflen'], option, testlen)
222 | else:
223 | reflen = self.special_reflen
224 |
225 | self._reflen += reflen
226 |
227 | for key in ['guess','correct']:
228 | for k in range(n):
229 | totalcomps[key][k] += comps[key][k]
230 |
231 | # append per image bleu score
232 | bleu = 1.
233 | for k in range(n):
234 | bleu *= (float(comps['correct'][k]) + tiny) \
235 | /(float(comps['guess'][k]) + small)
236 | bleu_list[k].append(bleu ** (1./(k+1)))
237 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
238 | if ratio < 1:
239 | for k in range(n):
240 | bleu_list[k][-1] *= math.exp(1 - 1/ratio)
241 |
242 | if verbose > 1:
243 | print(comps, reflen)
244 |
245 | totalcomps['reflen'] = self._reflen
246 | totalcomps['testlen'] = self._testlen
247 |
248 | bleus = []
249 | bleu = 1.
250 | for k in range(n):
251 | bleu *= float(totalcomps['correct'][k] + tiny) \
252 | / (totalcomps['guess'][k] + small)
253 | bleus.append(bleu ** (1./(k+1)))
254 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
255 | if ratio < 1:
256 | for k in range(n):
257 | bleus[k] *= math.exp(1 - 1/ratio)
258 |
259 | if verbose > 0:
260 | pass
261 | #print(totalcomps)
262 | #print("ratio:", ratio)
263 |
264 | self._score = bleus
265 | return self._score, bleu_list
266 |
--------------------------------------------------------------------------------
/eval/cider/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/eval/cider/cider.py:
--------------------------------------------------------------------------------
1 | # Filename: cider.py
2 | #
3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
5 | #
6 | # Creation Date: Sun Feb 8 14:16:54 2015
7 | #
8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin
9 |
10 | from eval.cider.cider_scorer import CiderScorer
11 |
12 |
13 | class Cider:
14 | """
15 | Main Class to compute the CIDEr metric
16 |
17 | """
18 | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
19 | # set cider to sum over 1 to 4-grams
20 | self._n = n
21 | # set the standard deviation parameter for gaussian penalty
22 | self._sigma = sigma
23 |
24 | def compute_score(self, gts, res):
25 | """
26 | Main function to compute CIDEr score
27 | :param hypo_for_image (dict) : dictionary with key and value
28 | ref_for_image (dict) : dictionary with key and value
29 | :return: cider (float) : computed CIDEr score for the corpus
30 | """
31 |
32 | assert(gts.keys() == res.keys())
33 | imgIds = gts.keys()
34 |
35 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)
36 |
37 | for id in imgIds:
38 | hypo = res[id]
39 | ref = gts[id]
40 |
41 | # Sanity check.
42 | assert(type(hypo) is list)
43 | assert(len(hypo) == 1)
44 | assert(type(ref) is list)
45 | assert(len(ref) > 0)
46 |
47 | cider_scorer += (hypo[0], ref)
48 |
49 | (score, scores) = cider_scorer.compute_score()
50 |
51 | return score, scores
52 |
53 | def method(self):
54 | return "CIDEr"
--------------------------------------------------------------------------------
/eval/cider/cider_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Tsung-Yi Lin
3 | # Ramakrishna Vedantam
4 |
5 | import copy
6 | from collections import defaultdict
7 | import numpy as np
8 | import pdb
9 | import math
10 |
11 | def precook(s, n=4, out=False):
12 | """
13 | Takes a string as input and returns an object that can be given to
14 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
15 | can take string arguments as well.
16 | :param s: string : sentence to be converted into ngrams
17 | :param n: int : number of ngrams for which representation is calculated
18 | :return: term frequency vector for occurring ngrams
19 | """
20 | words = s.split()
21 | counts = defaultdict(int)
22 | for k in range(1,n+1):
23 | for i in range(len(words)-k+1):
24 | ngram = tuple(words[i:i+k])
25 | counts[ngram] += 1
26 | return counts
27 |
28 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
29 | '''Takes a list of reference sentences for a single segment
30 | and returns an object that encapsulates everything that BLEU
31 | needs to know about them.
32 | :param refs: list of string : reference sentences for some image
33 | :param n: int : number of ngrams for which (ngram) representation is calculated
34 | :return: result (list of dict)
35 | '''
36 | return [precook(ref, n) for ref in refs]
37 |
38 | def cook_test(test, n=4):
39 | '''Takes a test sentence and returns an object that
40 | encapsulates everything that BLEU needs to know about it.
41 | :param test: list of string : hypothesis sentence for some image
42 | :param n: int : number of ngrams for which (ngram) representation is calculated
43 | :return: result (dict)
44 | '''
45 | return precook(test, n, True)
46 |
47 | class CiderScorer(object):
48 | """CIDEr scorer.
49 | """
50 |
51 | def copy(self):
52 | ''' copy the refs.'''
53 | new = CiderScorer(n=self.n)
54 | new.ctest = copy.copy(self.ctest)
55 | new.crefs = copy.copy(self.crefs)
56 | return new
57 |
58 | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
59 | ''' singular instance '''
60 | self.n = n
61 | self.sigma = sigma
62 | self.crefs = []
63 | self.ctest = []
64 | self.document_frequency = defaultdict(float)
65 | self.cook_append(test, refs)
66 | self.ref_len = None
67 |
68 | def cook_append(self, test, refs):
69 | '''called by constructor and __iadd__ to avoid creating new instances.'''
70 |
71 | if refs is not None:
72 | self.crefs.append(cook_refs(refs))
73 | if test is not None:
74 | self.ctest.append(cook_test(test)) ## N.B.: -1
75 | else:
76 | self.ctest.append(None) # lens of crefs and ctest have to match
77 |
78 | def size(self):
79 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
80 | return len(self.crefs)
81 |
82 | def __iadd__(self, other):
83 | '''add an instance (e.g., from another sentence).'''
84 |
85 | if type(other) is tuple:
86 | ## avoid creating new CiderScorer instances
87 | self.cook_append(other[0], other[1])
88 | else:
89 | self.ctest.extend(other.ctest)
90 | self.crefs.extend(other.crefs)
91 |
92 | return self
93 | def compute_doc_freq(self):
94 | '''
95 | Compute term frequency for reference data.
96 | This will be used to compute idf (inverse document frequency later)
97 | The term frequency is stored in the object
98 | :return: None
99 | '''
100 | for refs in self.crefs:
101 | # refs, k ref captions of one image
102 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
103 | self.document_frequency[ngram] += 1
104 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
105 |
106 | def compute_cider(self):
107 | def counts2vec(cnts):
108 | """
109 | Maps n-gram counts to a vector of tf-idf weights.
110 | Returns vec, an array of dictionaries that map each n-gram to its tf-idf weight.
111 | The n-th entry of the array holds the weights for (n+1)-grams.
112 | :param cnts: dict : term frequency of each n-gram
113 | :return: vec (array of dict), norm (array of float), length (int)
114 | """
115 | vec = [defaultdict(float) for _ in range(self.n)]
116 | length = 0
117 | norm = [0.0 for _ in range(self.n)]
118 | for (ngram,term_freq) in cnts.items():
119 | # give the ngram a document frequency of 1 if it doesn't appear in the reference corpus
120 | df = np.log(max(1.0, self.document_frequency[ngram]))
121 | # ngram index
122 | n = len(ngram)-1
123 | # tf (term_freq) * idf (precomputed idf) for n-grams
124 | vec[n][ngram] = float(term_freq)*(self.ref_len - df)
125 | # compute norm for the vector. the norm will be used for computing similarity
126 | norm[n] += pow(vec[n][ngram], 2)
127 |
128 | if n == 1:
129 | length += term_freq
130 | norm = [np.sqrt(n) for n in norm]
131 | return vec, norm, length
132 |
133 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
134 | '''
135 | Compute the cosine similarity of two vectors.
136 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis
137 | :param vec_ref: array of dictionary for vector corresponding to reference
138 | :param norm_hyp: array of float for vector corresponding to hypothesis
139 | :param norm_ref: array of float for vector corresponding to reference
140 | :param length_hyp: int containing length of hypothesis
141 | :param length_ref: int containing length of reference
142 | :return: array of score for each n-grams cosine similarity
143 | '''
144 | delta = float(length_hyp - length_ref)
145 | # measure cosine similarity
146 | val = np.array([0.0 for _ in range(self.n)])
147 | for n in range(self.n):
148 | # ngram
149 | for (ngram,count) in vec_hyp[n].items():
150 | # vrama91 : added clipping
151 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
152 |
153 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
154 | val[n] /= (norm_hyp[n]*norm_ref[n])
155 |
156 | assert(not math.isnan(val[n]))
157 | # vrama91: added a length based gaussian penalty
158 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
159 | return val
160 |
161 | # compute log reference length
162 | self.ref_len = np.log(float(len(self.crefs)))
163 |
164 | scores = []
165 | for test, refs in zip(self.ctest, self.crefs):
166 | # compute vector for test captions
167 | vec, norm, length = counts2vec(test)
168 | # compute vector for ref captions
169 | score = np.array([0.0 for _ in range(self.n)])
170 | for ref in refs:
171 | vec_ref, norm_ref, length_ref = counts2vec(ref)
172 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
173 | # change by vrama91 - mean of ngram scores, instead of sum
174 | score_avg = np.mean(score)
175 | # divide by number of references
176 | score_avg /= len(refs)
177 | # multiply score by 10
178 | score_avg *= 10.0
179 | # append score of an image to the score list
180 | scores.append(score_avg)
181 | return scores
182 |
183 | def compute_score(self, option=None, verbose=0):
184 | # compute idf
185 | self.compute_doc_freq()
186 | # assert to check document frequency
187 | assert(len(self.ctest) >= max(self.document_frequency.values()))
188 | # compute cider score
189 | score = self.compute_cider()
190 | # debug
191 | # print score
192 | return np.mean(np.array(score)), np.array(score)
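193 |
194 | # A minimal usage sketch of CiderScorer (illustrative only; the sentences below are made up):
195 | #
196 | #   scorer = CiderScorer(n=4, sigma=6.0)
197 | #   scorer += ("a dog runs in the park",
198 | #              ["a dog is running on the grass", "a brown dog runs outside"])
199 | #   scorer += ("a man rides a bike",
200 | #              ["a person rides a bicycle", "a man on a bike"])
201 | #   mean_score, per_image_scores = scorer.compute_score()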
--------------------------------------------------------------------------------
/eval/cider/reinforce_cider.py:
--------------------------------------------------------------------------------
1 | # Filename: reinforce_cider.py
2 | #
3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
5 | #
6 | # Creation Date: Sun Feb 8 14:16:54 2015
7 | #
8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin
9 |
10 | # ReinforceCIDEr is an alternative implementation of CIDEr where the corpus is initialized in the constructor
11 | # so it doesn't need to be processed again every time we need to compute the cider score
12 | # in the Self Critical Learning Process. --- Jia Cheng
13 |
14 | from eval.cider.reinforce_cider_scorer import ReinforceCiderScorer
15 | import pdb
16 |
17 |
18 | class ReinforceCider:
19 |
20 | # The batch_ref_sentences will be a small sample of the original corpus. Note, however, that the img_ids in
21 | # batch_ref_sentences do not need to match those in the corpus: img_id consistency is required only
22 | # between batch_ref_sentences and batch_test_sentences.
23 | def __init__(self, corpus, n=4, sigma=6.0):
24 | '''
25 | Corpus represents the collection of reference sentences for each image: an iterable in which each
26 | element is the list of reference sentences of one image.
27 |
28 | :param corpus: iterable of lists of reference sentences, one list per image
29 | :param n: number of n-grams
30 | :param sigma: length penalty coefficient
31 | '''
32 | # set cider to sum over 1 to 4-grams
33 | self._n = n
34 | # set the standard deviation parameter for gaussian penalty
35 | self._sigma = sigma
36 | self.cider_scorer = ReinforceCiderScorer(corpus, n=self._n, sigma=self._sigma)
37 |
38 | def compute_score(self, hypo, refs):
39 | """
40 | Main function to compute CIDEr score
41 | :param hypo: list of str : one hypothesis sentence per image
42 | :param refs: list of list of str : reference sentences aligned element-wise with hypo
43 | :return: (float, np.ndarray) : mean CIDEr score and the per-sentence scores
44 | """
45 |
46 | # assert(hypo.keys() == refs.keys())
47 |
48 | (score, scores) = self.cider_scorer.compute_score(refs, hypo)
49 |
50 | return score, scores
51 |
52 | def method(self):
53 | return "Reinforce CIDEr"
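54 |
55 |
56 | # A minimal usage sketch (illustrative only; the sentences and variable names below are made up):
57 | #
58 | #   corpus = [["a dog is running on the grass", "a brown dog runs outside"],
59 | #             ["a person rides a bicycle", "a man on a bike"]]
60 | #   cider = ReinforceCider(corpus)
61 | #   hypos = ["a dog runs in the park", "a man rides a bike"]
62 | #   refs = [corpus[0], corpus[1]]   # references aligned element-wise with hypos
63 | #   mean_score, per_sentence_scores = cider.compute_score(hypo=hypos, refs=refs)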
--------------------------------------------------------------------------------
/eval/cider/reinforce_cider_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Tsung-Yi Lin
3 | # Ramakrishna Vedantam
4 |
5 | # ReinforceCiderScorer is an alternative implementation of the CIDEr scorer adapted to my personal
6 | # input/output format "tastes" and needs --- Jia Cheng
7 |
8 | import copy
9 | from collections import defaultdict
10 | import numpy as np
11 | import pdb
12 | import math
13 |
14 |
15 | def precook(s, n=4, out=False):
16 | """
17 | Takes a string as input and returns an object that can be given to
18 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
19 | can take string arguments as well.
20 | :param s: string : sentence to be converted into ngrams
21 | :param n: int : number of ngrams for which representation is calculated
22 | :return: term frequency vector for occurring ngrams
23 | """
24 | words = s.split()
25 | counts = defaultdict(int)
26 | for k in range(1, n + 1):
27 | for i in range(len(words) - k + 1):
28 | ngram = tuple(words[i:i + k])
29 | counts[ngram] += 1
30 | return counts
31 |
32 |
33 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
34 | '''Takes a list of reference sentences for a single segment
35 | and returns an object that encapsulates everything that BLEU
36 | needs to know about them.
37 | :param refs: list of string : reference sentences for some image
38 | :param n: int : number of ngrams for which (ngram) representation is calculated
39 | :return: result (list of dict)
40 | '''
41 | return [precook(ref, n) for ref in refs]
42 |
43 |
44 | def cook_test(test, n=4):
45 | '''Takes a test sentence and returns an object that
46 | encapsulates everything that BLEU needs to know about it.
47 | :param test: string : hypothesis sentence for some image
48 | :param n: int : number of ngrams for which (ngram) representation is calculated
49 | :return: result (dict)
50 | '''
51 | return precook(test, n, True)
52 |
53 |
54 | class ReinforceCiderScorer(object):
55 | """CIDEr scorer.
56 | """
57 |
58 | def __init__(self, corpus_crefs, n=4, sigma=6.0):
59 | ''' singular instance '''
60 | self.n = n
61 | self.sigma = sigma
62 |
63 | df_crefs = []
64 | for refs in corpus_crefs:
65 | df_crefs.append(cook_refs(refs))
66 |
67 | self.document_frequency = self.compute_doc_freq(df_crefs)
68 | self.corpus_ref_len = np.log(float(len(df_crefs)))
69 |
70 |
71 | def compute_doc_freq(self, df_crefs):
72 | '''
73 | Compute document frequency for the reference data.
74 | This will be used to compute idf (inverse document frequency) later.
75 | :param df_crefs: list : cooked reference captions of the whole corpus
76 | :return: document_frequency (defaultdict mapping each n-gram to its document frequency)
77 | '''
78 | document_frequency = defaultdict(float)
79 | for refs in df_crefs:
80 | # refs, k ref captions of one image
81 | for ngram in set([ngram for ref in refs for (ngram, count) in ref.items()]):
82 | document_frequency[ngram] += 1
83 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
84 | return document_frequency
85 |
86 | def compute_cider(self, refs, tests):
87 |
88 | ctest = []
89 | crefs = []
90 | for idx in range(len(tests)):
91 | test = tests[idx]
92 | ref = refs[idx]
93 | ctest.append(cook_test(test))
94 | crefs.append(cook_refs(ref))
95 |
96 | def counts2vec(cnts):
97 | """
98 | Maps n-gram counts to a vector of tf-idf weights.
99 | Returns vec, an array of dictionaries that map each n-gram to its tf-idf weight.
100 | The n-th entry of the array holds the weights for (n+1)-grams.
101 | :param cnts: dict : term frequency of each n-gram
102 | (the idf uses the precomputed self.document_frequency and self.corpus_ref_len)
103 | :return: vec (array of dict), norm (array of float), length (int)
104 | """
105 | vec = [defaultdict(float) for _ in range(self.n)]
106 | length = 0
107 | norm = [0.0 for _ in range(self.n)]
108 | for (ngram, term_freq) in cnts.items():
109 | # give the ngram a document frequency of 1 if it doesn't appear in the reference corpus
110 | df = np.log(max(1.0, self.document_frequency[ngram]))
111 | # ngram index
112 | n = len(ngram) - 1
113 | # tf (term_freq) * idf (precomputed idf) for n-grams
114 | vec[n][ngram] = float(term_freq) * (self.corpus_ref_len - df)
115 | # compute norm for the vector. the norm will be used for computing similarity
116 | norm[n] += pow(vec[n][ngram], 2)
117 |
118 | if n == 1:
119 | length += term_freq
120 | norm = [np.sqrt(n) for n in norm]
121 | return vec, norm, length
122 |
123 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
124 | '''
125 | Compute the cosine similarity of two vectors.
126 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis
127 | :param vec_ref: array of dictionary for vector corresponding to reference
128 | :param norm_hyp: array of float for vector corresponding to hypothesis
129 | :param norm_ref: array of float for vector corresponding to reference
130 | :param length_hyp: int containing length of hypothesis
131 | :param length_ref: int containing length of reference
132 | :return: array of score for each n-grams cosine similarity
133 | '''
134 | delta = float(length_hyp - length_ref)
135 | # measure cosine similarity
136 | val = np.array([0.0 for _ in range(self.n)])
137 | for n in range(self.n):
138 | # ngram
139 | for (ngram, count) in vec_hyp[n].items():
140 | # vrama91 : added clipping
141 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
142 |
143 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
144 | val[n] /= (norm_hyp[n] * norm_ref[n])
145 |
146 | assert (not math.isnan(val[n]))
147 | # vrama91: added a length based gaussian penalty
148 | val[n] *= np.e ** (-(delta ** 2) / (2 * self.sigma ** 2))
149 | return val
150 |
151 | # the log reference length (corpus_ref_len) is precomputed once in the constructor
152 |
153 | scores = []
154 | for test, refs in zip(ctest, crefs):
155 | # compute vector for test captions
156 | vec, norm, length = counts2vec(test)
157 | # compute vector for ref captions
158 | score = np.array([0.0 for _ in range(self.n)])
159 | for ref in refs:
160 | vec_ref, norm_ref, length_ref = counts2vec(ref)
161 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
162 | # change by vrama91 - mean of ngram scores, instead of sum
163 | score_avg = np.mean(score)
164 | # divide by number of references
165 | score_avg /= len(refs)
166 | # multiply score by 10
167 | score_avg *= 10.0
168 | # append score of an image to the score list
169 | scores.append(score_avg)
170 | return scores
171 |
172 | """
173 | refs are expected to belong to the corpus used to precompute the document frequencies
174 | """
175 | def compute_score(self, refs, tests, option=None, verbose=0):
176 | # compute idf
177 | # self.compute_doc_freq()
178 | # assert to check document frequency
179 | # assert(len(self.ctest) >= max(self.document_frequency.values()))
180 | # compute cider score
181 | score = self.compute_cider(refs, tests)
182 | # debug
183 | # print score
184 | return np.mean(np.array(score)), np.array(score)
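185 |
186 | # Summary of the computation above (informal notation, matching the code):
187 | #   for each n in 1..self.n, an n-gram g of a sentence receives the weight
188 | #       w(g) = tf(g) * (log(N) - log(max(1, df(g))))
189 | #   where N is the number of images in the corpus and df(g) the precomputed document frequency;
190 | #   CIDEr_n is the clipped cosine similarity between hypothesis and reference weight vectors,
191 | #   scaled by the Gaussian length penalty exp(-(len_hyp - len_ref)^2 / (2 * sigma^2));
192 | #   the final score is 10 times the mean over n of the average CIDEr_n across references.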
--------------------------------------------------------------------------------
/eval/eval.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 | from eval.tokenizer.ptbtokenizer import PTBTokenizer
3 | from eval.bleu.bleu import Bleu
4 | from eval.meteor.meteor import Meteor
5 | from eval.rouge.rouge import Rouge
6 | from eval.cider.cider import Cider
7 | from eval.spice.spice import Spice
8 |
9 |
10 | """
11 | I do not own the rights of this code, I just modified it according to my needs.
12 | The original version can be found in:
13 | https://github.com/cocodataset/cocoapi
14 | """
15 | class COCOEvalCap:
16 | def __init__(self, dataset_gts_anns, pred_anns, pred_img_ids, get_stanford_models_path=None):
17 | self.evalImgs = []
18 | self.eval = {}
19 | self.imgToEval = {}
20 | self.dataset_gts_anns = dataset_gts_anns
21 | self.pred_anns = pred_anns
22 | self.pred_img_ids = pred_img_ids
23 |
24 | import subprocess
25 | # print("invoking " + str(get_stanford_models_path))
26 | rc = subprocess.call(get_stanford_models_path)
27 |
28 | def evaluate(self, bleu=True, rouge=True, cider=True, spice=True, meteor=True, verbose=True):
29 | # imgIds = self.coco.getImgIds()
30 | gts = {}
31 | res = {}
32 | for imgId in self.pred_img_ids:
33 | gts[imgId] = self.dataset_gts_anns[imgId]
34 | res[imgId] = self.pred_anns[imgId]
35 |
36 | # =================================================
37 | # Tokenization
38 | # =================================================
39 | #if verbose:
40 | # print('tokenization...')
41 | tokenizer = PTBTokenizer()
42 | gts = tokenizer.tokenize(gts)
43 | res = tokenizer.tokenize(res)
44 |
45 | # =================================================
46 | # Set up scorers
47 | # =================================================
48 | #if verbose:
49 | # print('setting up scorers...')
50 | scorers = []
51 | if cider:
52 | scorers.append((Cider(), "CIDEr"))
53 | if bleu:
54 | scorers.append((Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]))
55 | if rouge:
56 | scorers.append((Rouge(), "ROUGE_L"))
57 | if spice:
58 | scorers.append((Spice(), "SPICE"))
59 | if meteor:
60 | scorers.append((Meteor(), "METEOR"))
61 | """
62 | scorers = [
63 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
64 | (Rouge(), "ROUGE_L"),
65 | (Cider(), "CIDEr"),
66 | (Spice(), "SPICE"),
67 | (Meteor(), "METEOR"),
68 | ]
69 | """
70 |
71 | # =================================================
72 | # Compute scores
73 | # =================================================
74 | return_scores = []
75 | for scorer, method in scorers:
76 | if verbose:
77 | # print('computing %s score...'%(scorer.method()))
78 | pass
79 | score, scores = scorer.compute_score(gts, res)
80 | if type(method) == list:
81 | for sc, scs, m in zip(score, scores, method):
82 | self.setEval(sc, m)
83 | self.setImgToEvalImgs(scs, gts.keys(), m)
84 | if verbose:
85 | # print("%s: %0.3f"%(m, sc))
86 | pass
87 | return_scores.append((m, round(sc, 4)))
88 | else:
89 | self.setEval(score, method)
90 | self.setImgToEvalImgs(scores, gts.keys(), method)
91 | if verbose:
92 | # print("%s: %0.3f"%(method, score))
93 | pass
94 | return_scores.append((method, round(score, 4)))
95 | self.setEvalImgs()
96 |
97 | return return_scores
98 |
99 | def setEval(self, score, method):
100 | self.eval[method] = score
101 |
102 | def setImgToEvalImgs(self, scores, imgIds, method):
103 | for imgId, score in zip(imgIds, scores):
104 | if not imgId in self.imgToEval:
105 | self.imgToEval[imgId] = {}
106 | self.imgToEval[imgId]["image_id"] = imgId
107 | self.imgToEval[imgId][method] = score
108 |
109 | def setEvalImgs(self):
110 | self.evalImgs = [img_eval for imgId, img_eval in self.imgToEval.items()]
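111 |
112 |
113 | # A minimal usage sketch (illustrative only; ids, captions and the script path below are made up,
114 | # and Java is required by the PTB tokenizer):
115 | #
116 | #   gts = {1: [{'caption': 'a dog is running on the grass'}, {'caption': 'a brown dog runs outside'}]}
117 | #   preds = {1: [{'caption': 'a dog runs in the park'}]}
118 | #   evaluator = COCOEvalCap(gts, preds, pred_img_ids=[1],
119 | #                           get_stanford_models_path='./eval/get_stanford_models.sh')
120 | #   scores = evaluator.evaluate(spice=False, meteor=False)   # list of (metric_name, value) pairs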
--------------------------------------------------------------------------------
/eval/get_stanford_models.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 | # This script downloads the Stanford CoreNLP models.
3 |
4 | CORENLP=stanford-corenlp-full-2015-12-09
5 | SPICELIB=./spice/lib
6 | JAR=stanford-corenlp-3.6.0
7 |
8 | DIR="$( cd "$(dirname "$0")" ; pwd -P )"
9 | cd $DIR
10 |
11 | if [ -f $SPICELIB/$JAR.jar ]; then
12 | # echo "Found Stanford CoreNLP."
13 | :
14 | else
15 | echo "Downloading..."
16 | wget http://nlp.stanford.edu/software/$CORENLP.zip
17 | echo "Unzipping..."
18 | unzip $CORENLP.zip -d $SPICELIB/
19 | mv $SPICELIB/$CORENLP/$JAR.jar $SPICELIB/
20 | mv $SPICELIB/$CORENLP/$JAR-models.jar $SPICELIB/
21 | rm -f $CORENLP.zip
22 | rm -rf $SPICELIB/$CORENLP/
23 | echo "Done."
24 | fi
25 |
--------------------------------------------------------------------------------
/eval/meteor/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/eval/meteor/data/paraphrase-en.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/meteor/data/paraphrase-en.gz
--------------------------------------------------------------------------------
/eval/meteor/meteor-1.5.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/meteor/meteor-1.5.jar
--------------------------------------------------------------------------------
/eval/meteor/meteor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Python wrapper for METEOR implementation, by Xinlei Chen
4 | # Acknowledge Michael Denkowski for the generous discussion and help
5 |
6 | # modified according to: https://github.com/tylin/coco-caption/issues/27
7 | # to support python3.5
8 |
9 | from __future__ import absolute_import
10 | from __future__ import division
11 | from __future__ import print_function
12 |
13 | import os
14 | import sys
15 | import subprocess
16 | import threading
17 |
18 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed.
19 | METEOR_JAR = 'meteor-1.5.jar'
20 |
21 |
22 | # print METEOR_JAR
23 |
24 | class Meteor:
25 |
26 | def __init__(self):
27 | self.env = os.environ
28 | self.env['LC_ALL'] = 'en_US.UTF_8'
29 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR,
30 | '-', '-', '-stdio', '-l', 'en', '-norm']
31 | self.meteor_p = subprocess.Popen(self.meteor_cmd,
32 | cwd=os.path.dirname(os.path.abspath(__file__)),
33 | stdin=subprocess.PIPE,
34 | stdout=subprocess.PIPE,
35 | stderr=subprocess.PIPE,
36 | env=self.env, universal_newlines=True, bufsize=1)
37 | # Used to guarantee thread safety
38 | self.lock = threading.Lock()
39 |
40 | def compute_score(self, gts, res):
41 | assert (gts.keys() == res.keys())
42 | imgIds = sorted(list(gts.keys()))
43 | scores = []
44 |
45 | eval_line = 'EVAL'
46 | self.lock.acquire()
47 | for i in imgIds:
48 | assert (len(res[i]) == 1)
49 | # There is a situation where the prediction consists entirely of punctuation
50 | # (see the definition of PUNCTUATIONS in eval/tokenizer/ptbtokenizer.py);
51 | # the prediction then becomes [''] after tokenization,
52 | # which means res[i][0] == '' and self._stat would fail with this input
53 | if len(res[i][0]) == 0:
54 | res[i][0] = 'a'
55 | stat = self._stat(res[i][0], gts[i])
56 | eval_line += ' ||| {}'.format(stat)
57 |
58 | # Send to METEOR
59 | self.meteor_p.stdin.write(eval_line + '\n')
60 |
61 | # Collect segment scores
62 | for i in range(len(imgIds)):
63 | score = float(self.meteor_p.stdout.readline().strip())
64 | scores.append(score)
65 |
66 | # Final score
67 | final_score = float(self.meteor_p.stdout.readline().strip())
68 | self.lock.release()
69 |
70 | return final_score, scores
71 |
72 | def method(self):
73 | return "METEOR"
74 |
75 | def _stat(self, hypothesis_str, reference_list):
76 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
77 | hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
78 | if sys.version_info[0] == 2: # python2
79 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)).encode('utf-8').strip()
80 | self.meteor_p.stdin.write(str(score_line + b'\n'))
81 | else: # assume python3+
82 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)).strip()
83 | self.meteor_p.stdin.write(score_line + '\n')
84 | return self.meteor_p.stdout.readline().strip()
85 |
86 | def __del__(self):
87 | self.lock.acquire()
88 | self.meteor_p.stdin.close()
89 | self.meteor_p.kill()
90 | self.meteor_p.wait()
91 | self.lock.release()
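92 |
93 |
94 | # A minimal usage sketch (illustrative only; requires Java and meteor-1.5.jar next to this file,
95 | # and the example sentences are made up):
96 | #
97 | #   gts = {1: ['a dog is running on the grass', 'a brown dog runs outside']}
98 | #   res = {1: ['a dog runs in the park']}
99 | #   final_score, per_image_scores = Meteor().compute_score(gts, res)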
--------------------------------------------------------------------------------
/eval/rouge/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'vrama91'
2 |
--------------------------------------------------------------------------------
/eval/rouge/rouge.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : rouge.py
4 | #
5 | # Description : Computes ROUGE-L metric as described by Lin and Hovy (2004)
6 | #
7 | # Creation Date : 2015-01-07 06:03
8 | # Author : Ramakrishna Vedantam
9 |
10 | import numpy as np
11 | import pdb
12 |
13 | def my_lcs(string, sub):
14 | """
15 | Calculates longest common subsequence for a pair of tokenized strings
16 | :param string : list of str : tokens from a string split using whitespace
17 | :param sub : list of str : shorter string, also split using whitespace
18 | :returns: length (int): length of the longest common subsequence between the two strings
19 |
20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
21 | """
22 | if(len(string)< len(sub)):
23 | sub, string = string, sub
24 |
25 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)]
26 |
27 | for j in range(1,len(sub)+1):
28 | for i in range(1,len(string)+1):
29 | if(string[i-1] == sub[j-1]):
30 | lengths[i][j] = lengths[i-1][j-1] + 1
31 | else:
32 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1])
33 |
34 | return lengths[len(string)][len(sub)]
35 |
36 | class Rouge():
37 | '''
38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
39 |
40 | '''
41 | def __init__(self):
42 | # vrama91: updated the value below based on discussion with Hovy
43 | self.beta = 1.2
44 |
45 | def calc_score(self, candidate, refs):
46 | """
47 | Compute ROUGE-L score given one candidate and references for an image
48 | :param candidate: str : candidate sentence to be evaluated
49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated
50 | :returns score: float (ROUGE-L score for the candidate evaluated against references)
51 | """
52 | assert(len(candidate)==1)
53 | assert(len(refs)>0)
54 | prec = []
55 | rec = []
56 |
57 | # split into tokens
58 | token_c = candidate[0].split(" ")
59 |
60 | for reference in refs:
61 | # split into tokens
62 | token_r = reference.split(" ")
63 | # compute the longest common subsequence
64 | lcs = my_lcs(token_r, token_c)
65 | prec.append(lcs/float(len(token_c)))
66 | rec.append(lcs/float(len(token_r)))
67 |
68 | prec_max = max(prec)
69 | rec_max = max(rec)
70 |
71 | if(prec_max!=0 and rec_max !=0):
72 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max)
73 | else:
74 | score = 0.0
75 | return score
76 |
77 | def compute_score(self, gts, res):
78 | """
79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset
80 | Invoked by eval.py in this repository
81 | :param gts: dict : reference MS-COCO sentences with "image id" keys and lists of tokenized sentences as values
82 | :param res: dict : candidate / test sentences with "image id" keys and single-element lists of tokenized sentences as values
83 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
84 | """
85 | assert(gts.keys() == res.keys())
86 | imgIds = gts.keys()
87 |
88 | score = []
89 | for id in imgIds:
90 | hypo = res[id]
91 | ref = gts[id]
92 |
93 | score.append(self.calc_score(hypo, ref))
94 |
95 | # Sanity check.
96 | assert(type(hypo) is list)
97 | assert(len(hypo) == 1)
98 | assert(type(ref) is list)
99 | assert(len(ref) > 0)
100 |
101 | average_score = np.mean(np.array(score))
102 | return average_score, np.array(score)
103 |
104 | def method(self):
105 | return "Rouge"
106 |
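107 | # A minimal usage sketch (illustrative only; the sentences below are made up):
108 | #
109 | #   gts = {1: ['a dog is running on the grass', 'a brown dog runs outside']}
110 | #   res = {1: ['a dog runs in the park']}
111 | #   mean_rouge_l, per_image_scores = Rouge().compute_score(gts, res)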
--------------------------------------------------------------------------------
/eval/spice/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/__init__.py
--------------------------------------------------------------------------------
/eval/spice/cache/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/cache/.gitkeep
--------------------------------------------------------------------------------
/eval/spice/lib/Meteor-1.5.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/Meteor-1.5.jar
--------------------------------------------------------------------------------
/eval/spice/lib/SceneGraphParser-1.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/SceneGraphParser-1.0.jar
--------------------------------------------------------------------------------
/eval/spice/lib/ejml-0.23.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/ejml-0.23.jar
--------------------------------------------------------------------------------
/eval/spice/lib/fst-2.47.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/fst-2.47.jar
--------------------------------------------------------------------------------
/eval/spice/lib/guava-19.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/guava-19.0.jar
--------------------------------------------------------------------------------
/eval/spice/lib/hamcrest-core-1.3.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/hamcrest-core-1.3.jar
--------------------------------------------------------------------------------
/eval/spice/lib/jackson-core-2.5.3.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/jackson-core-2.5.3.jar
--------------------------------------------------------------------------------
/eval/spice/lib/javassist-3.19.0-GA.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/javassist-3.19.0-GA.jar
--------------------------------------------------------------------------------
/eval/spice/lib/json-simple-1.1.1.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/json-simple-1.1.1.jar
--------------------------------------------------------------------------------
/eval/spice/lib/junit-4.12.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/junit-4.12.jar
--------------------------------------------------------------------------------
/eval/spice/lib/lmdbjni-0.4.6.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/lmdbjni-0.4.6.jar
--------------------------------------------------------------------------------
/eval/spice/lib/lmdbjni-linux64-0.4.6.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/lmdbjni-linux64-0.4.6.jar
--------------------------------------------------------------------------------
/eval/spice/lib/lmdbjni-osx64-0.4.6.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/lmdbjni-osx64-0.4.6.jar
--------------------------------------------------------------------------------
/eval/spice/lib/lmdbjni-win64-0.4.6.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/lmdbjni-win64-0.4.6.jar
--------------------------------------------------------------------------------
/eval/spice/lib/objenesis-2.4.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/objenesis-2.4.jar
--------------------------------------------------------------------------------
/eval/spice/lib/slf4j-api-1.7.12.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/slf4j-api-1.7.12.jar
--------------------------------------------------------------------------------
/eval/spice/lib/slf4j-simple-1.7.21.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/lib/slf4j-simple-1.7.21.jar
--------------------------------------------------------------------------------
/eval/spice/spice-1.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/spice-1.0.jar
--------------------------------------------------------------------------------
/eval/spice/spice.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import os
3 | import sys
4 | import subprocess
5 | import threading
6 | import json
7 | import numpy as np
8 | import ast
9 | import tempfile
10 | import random
11 |
12 | # Assumes spice.jar is in the same directory as spice.py. Change as needed.
13 | SPICE_JAR = 'spice-1.0.jar'
14 | TEMP_DIR = 'tmp'
15 | CACHE_DIR = 'cache'
16 |
17 |
18 | class Spice:
19 | """
20 | Main Class to compute the SPICE metric
21 | """
22 |
23 | def float_convert(self, obj):
24 | try:
25 | return float(obj)
26 | except:
27 | return np.nan
28 |
29 | def compute_score(self, gts, res):
30 | assert (sorted(gts.keys()) == sorted(res.keys()))
31 | imgIds = sorted(gts.keys())
32 |
33 | # Prepare temp input file for the SPICE scorer
34 | input_data = []
35 | for id in imgIds:
36 | hypo = res[id]
37 | ref = gts[id]
38 |
39 | # Sanity check.
40 | assert (type(hypo) is list)
41 | assert (len(hypo) == 1)
42 | assert (type(ref) is list)
43 | assert (len(ref) >= 1)
44 |
45 | input_data.append({
46 | "image_id": id,
47 | "test": hypo[0],
48 | "refs": ref
49 | })
50 |
51 | cwd = os.path.dirname(os.path.abspath(__file__))
52 | temp_dir = os.path.join(cwd, TEMP_DIR)
53 | if not os.path.exists(temp_dir):
54 | os.makedirs(temp_dir)
55 |
56 | # generating random file names avoids very unlucky synchronization issues (name collisions between
57 | # concurrent processes) that could occur with the original implementation
58 |
59 | # python2.7 to python3.5 adaptation
60 | # in_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
61 | import time
62 | random.seed(time.time())
63 | random_int_str = str(random.randint(0, 9999999))
64 | in_file_name = random_int_str + '_pid' + str(os.getpid()) + '_in_tmp_file.json'
65 | # print(temp_dir + '/' + in_file_name)
66 | with open(temp_dir + '/' + in_file_name, 'w') as in_file:
67 | json.dump(input_data, in_file, indent=2)
68 |
69 | # Start job
70 | # out_file_name = 'out_tmp_file.tmp'
71 | # with open(temp_dir + '/' + out_file_name, 'w') as out_file:
72 |
73 | out_file_name = random_int_str + '_pid' + str(os.getpid()) + '_out_tmp_file.json'
74 | out_file_path = temp_dir + '/' + out_file_name
75 | # create file
76 | with open(out_file_path, 'w') as f:
77 | f.write('')
78 |
79 | cache_dir = os.path.join(cwd, CACHE_DIR)
80 | if not os.path.exists(cache_dir):
81 | os.makedirs(cache_dir)
82 | spice_cmd = ['java', '-jar', '-Xmx8G', SPICE_JAR,
83 | temp_dir + '/' + in_file_name,
84 | '-cache', cache_dir,
85 | '-out', out_file_path,
86 | '-subset',
87 | '-silent'
88 | ]
89 | subprocess.check_call(spice_cmd, cwd=os.path.dirname(os.path.abspath(__file__)),
90 | stdout=subprocess.DEVNULL,
91 | stderr=subprocess.DEVNULL)
92 |
93 | # Read and process results
94 | with open(temp_dir + '/' + out_file_name, 'r') as data_file:
95 | results = json.load(data_file)
96 |
97 | os.remove(temp_dir + '/' + in_file_name)
98 | os.remove(temp_dir + '/' + out_file_name)
99 |
100 | imgId_to_scores = {}
101 | spice_scores = []
102 | for item in results:
103 | imgId_to_scores[item['image_id']] = item['scores']
104 | spice_scores.append(self.float_convert(item['scores']['All']['f']))
105 | average_score = np.mean(np.array(spice_scores))
106 | scores = []
107 | for image_id in imgIds:
108 | # Convert none to NaN before saving scores over subcategories
109 | score_set = {}
110 | for category, score_tuple in imgId_to_scores[image_id].items():
111 | score_set[category] = {k: self.float_convert(v) for k, v in score_tuple.items()}
112 | scores.append(score_set)
113 | return average_score, scores
114 |
115 | def method(self):
116 | return "SPICE"
117 |
118 |
119 |
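120 | # A minimal usage sketch (illustrative only; requires Java, spice-1.0.jar and the CoreNLP jars
121 | # downloaded by eval/get_stanford_models.sh; the sentences below are made up):
122 | #
123 | #   gts = {1: ['a dog is running on the grass', 'a brown dog runs outside']}
124 | #   res = {1: ['a dog runs in the park']}
125 | #   mean_spice, per_image_score_sets = Spice().compute_score(gts, res)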
--------------------------------------------------------------------------------
/eval/spice/tmp/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/spice/tmp/.gitkeep
--------------------------------------------------------------------------------
/eval/tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'hfang'
2 |
--------------------------------------------------------------------------------
/eval/tokenizer/ptbtokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : ptbtokenizer.py
4 | #
5 | # Description : Perform PTB tokenization and remove punctuation.
6 | #
7 | # Creation Date : 29-12-2014
8 | # Last Modified : Thu Mar 19 09:53:35 2015
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | import os
12 | import sys
13 | import subprocess
14 | import random
15 | import tempfile
16 | import itertools
17 |
18 | # path to the stanford corenlp jar
19 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar'
20 |
21 | # punctuations to be removed from the sentences
22 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
23 | ".", "?", "!", ",", ":", "-", "--", "...", ";"]
24 |
25 | class PTBTokenizer:
26 | """Python wrapper of Stanford PTBTokenizer"""
27 |
28 | def tokenize(self, captions_for_image):
29 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \
30 | 'edu.stanford.nlp.process.PTBTokenizer', \
31 | '-preserveLines', '-lowerCase']
32 |
33 | # ======================================================
34 | # prepare data for PTB Tokenizer
35 | # ======================================================
36 | final_tokenized_captions_for_image = {}
37 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))]
38 | sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v])
39 |
40 | # ======================================================
41 | # save sentences to temporary file
42 | # ======================================================
43 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__))
44 | tmp_file_name = path_to_jar_dirname + '/temp_file' + str(random.randint(0, 999999)) + '_pid_' + str(os.getpid()) + '.tmp'
45 | with open(tmp_file_name, 'w') as tmp_file:
46 | tmp_file.write(sentences)
47 |
48 | # ======================================================
49 | # tokenize sentence
50 | # ======================================================
51 | cmd.append(tmp_file_name)
52 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \
53 | stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
54 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0]
55 | # print(token_lines)
56 | token_lines = token_lines.decode("utf-8")
57 | lines = token_lines.split('\n')
58 | # remove temp file
59 | if os.path.isfile(tmp_file_name):
60 | os.remove(tmp_file_name)
61 |
62 | # ======================================================
63 | # create dictionary for tokenized captions
64 | # ======================================================
65 | for k, line in zip(image_id, lines):
66 | if not k in final_tokenized_captions_for_image:
67 | final_tokenized_captions_for_image[k] = []
68 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \
69 | if w not in PUNCTUATIONS])
70 | final_tokenized_captions_for_image[k].append(tokenized_caption)
71 |
72 | return final_tokenized_captions_for_image
73 |
--------------------------------------------------------------------------------
/eval/tokenizer/stanford-corenlp-3.4.1.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/eval/tokenizer/stanford-corenlp-3.4.1.jar
--------------------------------------------------------------------------------
/github_ignore_material/raw_data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/github_ignore_material/raw_data/.gitkeep
--------------------------------------------------------------------------------
/github_ignore_material/saves/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/github_ignore_material/saves/.gitkeep
--------------------------------------------------------------------------------
/license.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Jia Cheng Hu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/losses/__init__.py
--------------------------------------------------------------------------------
/losses/loss.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class LabelSmoothingLoss(nn.Module):
7 | def __init__(self, smoothing_coeff, rank='cuda:0'):
8 | assert 0.0 <= smoothing_coeff <= 1.0
9 | super().__init__()
10 | self.smoothing_coeff = smoothing_coeff
11 | self.kl_div = nn.KLDivLoss(reduction='none')
12 | self.log_softmax = nn.LogSoftmax(dim=-1)
13 |
14 | self.rank = rank
15 |
16 | def forward(self, pred, target, ignore_index, divide_by_non_zeros=True):
17 | pred = self.log_softmax(pred)
18 |
19 | batch_size, seq_len, num_classes = pred.shape
20 | uniform_confidence = self.smoothing_coeff / (num_classes - 1) # minus one because of the PAD token
21 | confidence = 1 - self.smoothing_coeff
22 | one_hot = torch.full((num_classes,), uniform_confidence).to(self.rank)
23 | model_prob = one_hot.repeat(batch_size, seq_len, 1)
24 | model_prob.scatter_(2, target.unsqueeze(2), confidence)
25 | model_prob.masked_fill_((target == ignore_index).unsqueeze(2), 0)
26 |
27 | tot_loss_tensor = self.kl_div(pred, model_prob)
28 |
29 | # divide the loss of each sequence by the number of non pads
30 | pads_matrix = torch.as_tensor(target == ignore_index)
31 | tot_loss_tensor.masked_fill_(pads_matrix.unsqueeze(2), 0.0)
32 | if divide_by_non_zeros:
33 | num_non_pads = (~pads_matrix).sum().type(torch.cuda.FloatTensor)
34 | tot_loss = tot_loss_tensor.sum() / num_non_pads
35 | else:
36 | tot_loss = tot_loss_tensor.sum()
37 |
38 | return tot_loss
39 |
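40 |
41 | # A minimal usage sketch (illustrative only; assumes a CUDA device is available, since the
42 | # normalization path uses torch.cuda.FloatTensor):
43 | #
44 | #   criterion = LabelSmoothingLoss(smoothing_coeff=0.1, rank='cuda:0')
45 | #   pred = torch.randn(2, 7, 1000).to('cuda:0')            # (batch, seq_len, vocab_size) logits
46 | #   target = torch.randint(1, 1000, (2, 7)).to('cuda:0')   # (batch, seq_len) token indices
47 | #   loss = criterion(pred, target, ignore_index=0)          # positions equal to PAD (index 0) are ignored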
--------------------------------------------------------------------------------
/losses/reward.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from eval.cider.reinforce_cider import ReinforceCider
3 |
4 | import itertools
5 | from utils import language_utils
6 |
7 |
8 | class ReinforceCiderReward:
9 | def __init__(self, training_references, eos_token, num_sampled_captions, rank):
10 | super(ReinforceCiderReward).__init__()
11 | self.rank = rank
12 | self.num_sampled_captions = num_sampled_captions
13 |
14 | preprocessed_training_references = []
15 | for i in range(len(training_references)):
16 | preprocessed_captions = []
17 | for caption in training_references[i]:
18 | # it's faster than invoking the tokenizer
19 | caption = language_utils.lowercase_and_clean_trailing_spaces([caption])
20 | caption = language_utils.add_space_between_non_alphanumeric_symbols(caption)
21 | caption = language_utils.remove_punctuations(caption)
22 | caption = " ".join(caption[0].split() + [eos_token])
23 | preprocessed_captions.append(caption)
24 | preprocessed_training_references.append(preprocessed_captions)
25 | self.training_references = preprocessed_training_references
26 | self.reinforce_cider = ReinforceCider(self.training_references)
27 |
28 | def compute_reward(self, all_images_pred_caption, all_images_logprob, all_images_idx, all_images_base_caption=None):
29 |
30 | batch_size = len(all_images_pred_caption)
31 | num_sampled_captions = len(all_images_pred_caption[0])
32 |
33 | # Important for correct and fair results: keep the EOS token in the log-loss computation
34 | all_images_pred_caption = [' '.join(caption[1:])
35 | for pred_one_image in all_images_pred_caption
36 | for caption in pred_one_image]
37 |
38 | # repeat the references for the number of outputs
39 | all_images_ref_caption = [self.training_references[idx] for idx in all_images_idx]
40 | all_images_ref_caption = list(itertools.chain.from_iterable(itertools.repeat(ref, self.num_sampled_captions)
41 | for ref in all_images_ref_caption))
42 |
43 | cider_result = self.reinforce_cider.compute_score(hypo=all_images_pred_caption,
44 | refs=all_images_ref_caption)
45 | reward = torch.tensor(cider_result[1]).to(self.rank).view(batch_size, num_sampled_captions)
46 |
47 |
48 | if all_images_base_caption is None: # Mean base
49 | reward_base = (reward.sum(dim=-1, keepdim=True) - reward) / (num_sampled_captions - 1)
50 | else: # anything else like Greedy
51 | all_images_base_caption = [' '.join(caption[1:])
52 | for pred_one_image in all_images_base_caption
53 | for caption in pred_one_image]
54 |
55 | base_cider_result = self.reinforce_cider.compute_score(hypo=all_images_base_caption,
56 | refs=all_images_ref_caption)
57 | reward_base = torch.tensor(base_cider_result[1]).to(self.rank).view(batch_size, num_sampled_captions)
58 |
59 | reward_loss = (reward - reward_base) * torch.sum(-all_images_logprob, dim=-1)
60 | reward_loss = reward_loss.mean()
61 | return reward_loss, reward, reward_base
62 |
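63 | # Baseline note (informal, matching the code above): with the default "mean" baseline, the baseline of
64 | # sampled caption i of an image is the mean CIDEr reward of the other (num_sampled_captions - 1) samples,
65 | #   base_i = (sum_j r_j - r_i) / (K - 1),  with K = num_sampled_captions,
66 | # and the self-critical loss is the mean over samples of (r_i - base_i) * sum_t (-log p_t).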
--------------------------------------------------------------------------------
/models/End_ExpansionNet_v2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from models.layers import EmbeddingLayer, DecoderLayer, EncoderLayer
5 | from utils.masking import create_pad_mask, create_no_peak_and_pad_mask
6 | from models.captioning_model import CaptioningModel
7 | from models.swin_transformer_mod import SwinTransformer
8 |
9 |
10 | class End_ExpansionNet_v2(CaptioningModel):
11 | def __init__(self,
12 |
13 | # swin transf
14 | swin_img_size, swin_patch_size, swin_in_chans,
15 | swin_embed_dim, swin_depths, swin_num_heads,
16 | swin_window_size, swin_mlp_ratio, swin_qkv_bias, swin_qk_scale,
17 | swin_drop_rate, swin_attn_drop_rate, swin_drop_path_rate,
18 | swin_norm_layer, swin_ape, swin_patch_norm,
19 | swin_use_checkpoint,
20 |
21 | # linear_size,
22 | final_swin_dim,
23 |
24 | # captioning
25 | d_model, N_enc, N_dec, ff, num_heads, num_exp_enc_list, num_exp_dec,
26 | output_word2idx, output_idx2word, max_seq_len, drop_args, rank=0):
27 | super(End_ExpansionNet_v2, self).__init__()
28 |
29 | self.swin_transf = SwinTransformer(
30 | img_size=swin_img_size, patch_size=swin_patch_size, in_chans=swin_in_chans,
31 | embed_dim=swin_embed_dim, depths=swin_depths, num_heads=swin_num_heads,
32 | window_size=swin_window_size, mlp_ratio=swin_mlp_ratio, qkv_bias=swin_qkv_bias, qk_scale=swin_qk_scale,
33 | drop_rate=swin_drop_rate, attn_drop_rate=swin_attn_drop_rate, drop_path_rate=swin_drop_path_rate,
34 | norm_layer=swin_norm_layer, ape=swin_ape, patch_norm=swin_patch_norm,
35 | use_checkpoint=swin_use_checkpoint)
36 |
37 | self.output_word2idx = output_word2idx
38 | self.output_idx2word = output_idx2word
39 | self.max_seq_len = max_seq_len
40 |
41 | self.num_exp_dec = num_exp_dec
42 | self.num_exp_enc_list = num_exp_enc_list
43 |
44 | self.N_enc = N_enc
45 | self.N_dec = N_dec
46 | self.d_model = d_model
47 |
48 | self.encoders = nn.ModuleList([EncoderLayer(d_model, ff, num_exp_enc_list, drop_args.enc) for _ in range(N_enc)])
49 | self.decoders = nn.ModuleList([DecoderLayer(d_model, num_heads, ff, num_exp_dec, drop_args.dec) for _ in range(N_dec)])
50 |
51 | self.input_embedder_dropout = nn.Dropout(drop_args.enc_input)
52 | self.input_linear = torch.nn.Linear(final_swin_dim, d_model)
53 | self.vocab_linear = torch.nn.Linear(d_model, len(output_word2idx))
54 | self.log_softmax = nn.LogSoftmax(dim=-1)
55 |
56 | self.out_enc_dropout = nn.Dropout(drop_args.other)
57 | self.out_dec_dropout = nn.Dropout(drop_args.other)
58 |
59 | self.out_embedder = EmbeddingLayer(len(output_word2idx), d_model, drop_args.dec_input)
60 | self.pos_encoder = nn.Embedding(max_seq_len, d_model)
61 |
62 | self.enc_reduce_group = nn.Linear(d_model * self.N_enc, d_model)
63 | self.enc_reduce_norm = nn.LayerNorm(d_model)
64 | self.dec_reduce_group = nn.Linear(d_model * self.N_dec, d_model)
65 | self.dec_reduce_norm = nn.LayerNorm(d_model)
66 |
67 | for p in self.parameters():
68 | if p.dim() > 1:
69 | nn.init.xavier_uniform_(p)
70 |
71 | self.trained_steps = 0
72 | self.rank = rank
73 |
74 | self.check_required_attributes()
75 |
76 | def forward_enc(self, enc_input, enc_input_num_pads):
77 |
78 | assert (enc_input_num_pads is None or enc_input_num_pads == ([0] * enc_input.size(0))), \
79 | "End to End case have no padding"
80 | x = self.swin_transf(enc_input)
81 |
82 | enc_input = self.input_embedder_dropout(self.input_linear(x))
83 | x = enc_input
84 | enc_input_num_pads = [0] * enc_input.size(0)
85 |
86 | max_num_enc = sum(self.num_exp_enc_list)
87 | pos_x = torch.arange(max_num_enc).unsqueeze(0).expand(enc_input.size(0), max_num_enc).to(self.rank)
88 | pad_mask = create_pad_mask(mask_size=(enc_input.size(0), max_num_enc, enc_input.size(1)),
89 | pad_row=[0] * enc_input.size(0),
90 | pad_column=enc_input_num_pads,
91 | rank=self.rank)
92 |
93 | x_list = []
94 | for i in range(self.N_enc):
95 | x = self.encoders[i](x=x, n_indexes=pos_x, mask=pad_mask)
96 | x_list.append(x)
97 | x_list = torch.cat(x_list, dim=-1)
98 | x = x + self.out_enc_dropout(self.enc_reduce_group(x_list))
99 | x = self.enc_reduce_norm(x)
100 |
101 | return x
102 |
103 | def forward_dec(self, cross_input, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False):
104 | assert (enc_input_num_pads is None or enc_input_num_pads == ([0] * cross_input.size(0))), \
105 | "enc_input_num_pads should be no None"
106 |
107 | enc_input_num_pads = [0] * dec_input.size(0)
108 | no_peak_and_pad_mask = create_no_peak_and_pad_mask(
109 | mask_size=(dec_input.size(0), dec_input.size(1), dec_input.size(1)),
110 | num_pads=dec_input_num_pads,
111 | rank=self.rank)
112 | pad_mask = create_pad_mask(mask_size=(dec_input.size(0), dec_input.size(1), cross_input.size(1)),
113 | pad_row=dec_input_num_pads,
114 | pad_column=enc_input_num_pads,
115 | rank=self.rank)
116 |
117 | y = self.out_embedder(dec_input)
118 | pos_x = torch.arange(self.num_exp_dec).unsqueeze(0).expand(dec_input.size(0), self.num_exp_dec).to(self.rank)
119 | pos_y = torch.arange(dec_input.size(1)).unsqueeze(0).expand(dec_input.size(0), dec_input.size(1)).to(self.rank)
120 | y = y + self.pos_encoder(pos_y)
121 | y_list = []
122 | for i in range(self.N_dec):
123 | y = self.decoders[i](x=y,
124 | n_indexes=pos_x,
125 | cross_connection_x=cross_input,
126 | input_attention_mask=no_peak_and_pad_mask,
127 | cross_attention_mask=pad_mask)
128 | y_list.append(y)
129 | y_list = torch.cat(y_list, dim=-1)
130 | y = y + self.out_dec_dropout(self.dec_reduce_group(y_list))
131 | y = self.dec_reduce_norm(y)
132 |
133 | y = self.vocab_linear(y)
134 |
135 | if apply_log_softmax:
136 | y = self.log_softmax(y)
137 |
138 | return y
139 |
140 |
141 | def get_batch_multiple_sampled_prediction(self, enc_input, enc_input_num_pads, num_outputs,
142 | sos_idx, eos_idx, max_seq_len):
143 |
144 | bs = enc_input.size(0)
145 | x = self.forward_enc(enc_input=enc_input, enc_input_num_pads=enc_input_num_pads)
146 | enc_seq_len = x.size(1)
147 | x = x.unsqueeze(1).expand(-1, num_outputs, -1, -1).reshape(bs * num_outputs, enc_seq_len, x.shape[-1])
148 |
149 | upperbound_vector = torch.tensor([max_seq_len] * bs * num_outputs, dtype=torch.int).to(self.rank)
150 | where_is_eos_vector = upperbound_vector.clone()
151 | eos_vector = torch.tensor([eos_idx] * bs * num_outputs, dtype=torch.long).to(self.rank)
152 | finished_flag_vector = torch.zeros(bs * num_outputs).type(torch.int)
153 |
154 | predicted_caption = torch.tensor([sos_idx] * (bs * num_outputs), dtype=torch.long).to(self.rank).unsqueeze(-1)
155 | predicted_caption_prob = torch.zeros(bs * num_outputs).to(self.rank).unsqueeze(-1)
156 |
157 | dec_input_num_pads = [0]*(bs*num_outputs)
158 | time_step = 0
159 | while (finished_flag_vector.sum() != bs * num_outputs) and time_step < max_seq_len:
160 | dec_input = predicted_caption
161 | log_probs = self.forward_dec(x, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=True)
162 |
163 | prob_dist = torch.distributions.Categorical(torch.exp(log_probs[:, time_step]))
164 | sampled_word_indexes = prob_dist.sample()
165 |
166 | predicted_caption = torch.cat((predicted_caption, sampled_word_indexes.unsqueeze(-1)), dim=-1)
167 | predicted_caption_prob = torch.cat((predicted_caption_prob,
168 | log_probs[:, time_step].gather(index=sampled_word_indexes.unsqueeze(-1), dim=-1)), dim=-1)
169 | time_step += 1
170 |
171 | where_is_eos_vector = torch.min(where_is_eos_vector,
172 | upperbound_vector.masked_fill(sampled_word_indexes == eos_vector, time_step))
173 | finished_flag_vector = torch.max(finished_flag_vector,
174 | (sampled_word_indexes == eos_vector).type(torch.IntTensor))
175 |
176 | res_predicted_caption = []
177 | for i in range(bs):
178 | res_predicted_caption.append([])
179 | for j in range(num_outputs):
180 | index = i*num_outputs + j
181 | res_predicted_caption[i].append(
182 | predicted_caption[index, :where_is_eos_vector[index].item()+1].tolist())
183 |
184 | where_is_eos_vector = where_is_eos_vector.unsqueeze(-1).expand(-1, time_step+1)
185 | arange_tensor = torch.arange(time_step+1).unsqueeze(0).expand(bs * num_outputs, -1).to(self.rank)
186 | predicted_caption_prob.masked_fill_(arange_tensor > where_is_eos_vector, 0.0)
187 | res_predicted_caption_prob = predicted_caption_prob.reshape(bs, num_outputs, -1)
188 |
189 | return res_predicted_caption, res_predicted_caption_prob
190 |
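191 | # Output note: res_predicted_caption is a Python list of shape [bs][num_outputs], each entry being the
192 | # sampled token-index sequence from the <sos> token up to (and including) the sampled <eos> or the
193 | # max_seq_len cut-off; res_predicted_caption_prob is a (bs, num_outputs, time_step + 1) tensor of
194 | # per-token log-probabilities, zeroed for positions after the sampled <eos>.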
--------------------------------------------------------------------------------
/models/ExpansionNet_v2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models.layers import EmbeddingLayer, EncoderLayer, DecoderLayer
3 | from utils.masking import create_pad_mask, create_no_peak_and_pad_mask
4 | from models.captioning_model import CaptioningModel
5 |
6 | import torch.nn as nn
7 |
8 |
9 | class ExpansionNet_v2(CaptioningModel):
10 | def __init__(self, d_model, N_enc, N_dec, ff, num_heads, num_exp_enc_list, num_exp_dec,
11 | output_word2idx, output_idx2word, max_seq_len, drop_args, img_feature_dim=2048, rank=0):
12 | super().__init__()
13 | self.output_word2idx = output_word2idx
14 | self.output_idx2word = output_idx2word
15 | self.max_seq_len = max_seq_len
16 |
17 | self.num_exp_dec = num_exp_dec
18 | self.num_exp_enc_list = num_exp_enc_list
19 |
20 | self.N_enc = N_enc
21 | self.N_dec = N_dec
22 | self.d_model = d_model
23 |
24 | self.encoders = nn.ModuleList([EncoderLayer(d_model, ff, num_exp_enc_list, drop_args.enc) for _ in range(N_enc)])
25 | self.decoders = nn.ModuleList([DecoderLayer(d_model, num_heads, ff, num_exp_dec, drop_args.dec) for _ in range(N_dec)])
26 |
27 | self.input_embedder_dropout = nn.Dropout(drop_args.enc_input)
28 | self.input_linear = torch.nn.Linear(img_feature_dim, d_model)
29 | self.vocab_linear = torch.nn.Linear(d_model, len(output_word2idx))
30 | self.log_softmax = nn.LogSoftmax(dim=-1)
31 |
32 | self.out_enc_dropout = nn.Dropout(drop_args.other)
33 | self.out_dec_dropout = nn.Dropout(drop_args.other)
34 |
35 | self.out_embedder = EmbeddingLayer(len(output_word2idx), d_model, drop_args.dec_input)
36 | self.pos_encoder = nn.Embedding(max_seq_len, d_model)
37 |
38 | self.enc_reduce_group = nn.Linear(d_model * self.N_enc, d_model)
39 | self.enc_reduce_norm = nn.LayerNorm(d_model)
40 | self.dec_reduce_group = nn.Linear(d_model * self.N_dec, d_model)
41 | self.dec_reduce_norm = nn.LayerNorm(d_model)
42 |
43 | for p in self.parameters():
44 | if p.dim() > 1:
45 | nn.init.xavier_uniform_(p)
46 |
47 | self.trained_steps = 0
48 | self.rank = rank
49 |
50 | def forward_enc(self, enc_input, enc_input_num_pads):
51 |
52 | x = self.input_embedder_dropout(self.input_linear(enc_input.float()))
53 |
54 | sum_num_enc = sum(self.num_exp_enc_list)
55 | pos_x = torch.arange(sum_num_enc).unsqueeze(0).expand(enc_input.size(0), sum_num_enc).to(self.rank)
56 | pad_mask = create_pad_mask(mask_size=(enc_input.size(0), sum_num_enc, enc_input.size(1)),
57 | pad_row=[0] * enc_input.size(0),
58 | pad_column=enc_input_num_pads,
59 | rank=self.rank)
60 |
61 | x_list = []
62 | for i in range(self.N_enc):
63 | x = self.encoders[i](x=x, n_indexes=pos_x, mask=pad_mask)
64 | x_list.append(x)
65 | x_list = torch.cat(x_list, dim=-1)
66 | x = x + self.out_enc_dropout(self.enc_reduce_group(x_list))
67 | x = self.enc_reduce_norm(x)
68 | return x
69 |
70 | def forward_dec(self, cross_input, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False):
71 |
72 | no_peak_and_pad_mask = create_no_peak_and_pad_mask(
73 | mask_size=(dec_input.size(0), dec_input.size(1), dec_input.size(1)),
74 | num_pads=torch.tensor(dec_input_num_pads),
75 | rank=self.rank)
76 |
77 | pad_mask = create_pad_mask(mask_size=(dec_input.size(0), dec_input.size(1), cross_input.size(1)),
78 | pad_row=torch.tensor(dec_input_num_pads),
79 | pad_column=torch.tensor(enc_input_num_pads),
80 | rank=self.rank)
81 |
82 | y = self.out_embedder(dec_input)
83 | pos_x = torch.arange(self.num_exp_dec).unsqueeze(0).expand(dec_input.size(0), self.num_exp_dec).to(self.rank)
84 | pos_y = torch.arange(dec_input.size(1)).unsqueeze(0).expand(dec_input.size(0), dec_input.size(1)).to(self.rank)
85 | y = y + self.pos_encoder(pos_y)
86 | y_list = []
87 | for i in range(self.N_dec):
88 | y = self.decoders[i](x=y,
89 | n_indexes=pos_x,
90 | cross_connection_x=cross_input,
91 | input_attention_mask=no_peak_and_pad_mask,
92 | cross_attention_mask=pad_mask)
93 | y_list.append(y)
94 | y_list = torch.cat(y_list, dim=-1)
95 | y = y + self.out_dec_dropout(self.dec_reduce_group(y_list))
96 | y = self.dec_reduce_norm(y)
97 |
98 | y = self.vocab_linear(y)
99 |
100 | if apply_log_softmax:
101 | y = self.log_softmax(y)
102 |
103 | return y
104 |
--------------------------------------------------------------------------------
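
Below is a hedged sketch of how ExpansionNet_v2 could be instantiated on raw region features. The hyper-parameters mirror the values used by the conversion script later in this repository; the toy vocabulary and rank='cpu' are assumptions made purely for illustration (rank is only used as a .to() target, so a CPU device string is assumed to be acceptable here).

import torch
from argparse import Namespace
from models.ExpansionNet_v2 import ExpansionNet_v2

drop_args = Namespace(enc=0.0, dec=0.0, enc_input=0.0, dec_input=0.0, other=0.0)
word2idx = {'<sos>': 0, '<eos>': 1, 'a': 2, 'dog': 3}   # toy vocabulary (assumption)
idx2word = ['<sos>', '<eos>', 'a', 'dog']

model = ExpansionNet_v2(d_model=512, N_enc=3, N_dec=3, ff=2048, num_heads=8,
                        num_exp_enc_list=[32, 64, 128, 256, 512], num_exp_dec=16,
                        output_word2idx=word2idx, output_idx2word=idx2word,
                        max_seq_len=74, drop_args=drop_args,
                        img_feature_dim=2048, rank='cpu')

enc_x = torch.rand(1, 10, 2048)          # 10 dummy region features of dimension 2048
dec_x = torch.tensor([[word2idx['<sos>'], word2idx['a'], word2idx['dog']]])
logits = model(enc_x=enc_x, dec_x=dec_x, enc_x_num_pads=[0], dec_x_num_pads=[0])
print(logits.shape)                      # (batch, dec_len, vocab_size) = (1, 3, 4)
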
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/models/__init__.py
--------------------------------------------------------------------------------
/models/captioning_model.py:
--------------------------------------------------------------------------------
1 |

2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class CaptioningModel(nn.Module):
8 | def __init__(self):
9 | super(CaptioningModel, self).__init__()
10 | # mandatory attributes
11 | # rank: to enable multiprocessing
12 | self.rank = None
13 |
14 | def check_required_attributes(self):
15 | if self.rank is None:
16 |             raise NotImplementedError("Subclasses must assign an integer rank corresponding to their GPU / process group")
17 |
18 | def forward_enc(self, enc_input, enc_input_num_pads):
19 | raise NotImplementedError
20 |
21 | def forward_dec(self, cross_input, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False):
22 | raise NotImplementedError
23 |
24 | def forward(self, enc_x, dec_x=None,
25 | enc_x_num_pads=[0], dec_x_num_pads=[0], apply_log_softmax=False,
26 | mode='forward', **kwargs):
27 | if mode == 'forward':
28 | x = self.forward_enc(enc_x, enc_x_num_pads)
29 | y = self.forward_dec(x, enc_x_num_pads, dec_x, dec_x_num_pads, apply_log_softmax)
30 | return y
31 | else:
32 | assert ('sos_idx' in kwargs.keys() or 'eos_idx' in kwargs.keys()), \
33 | 'sos and eos must be provided in case of batch sampling or beam search'
34 | sos_idx = kwargs.get('sos_idx', -999)
35 | eos_idx = kwargs.get('eos_idx', -999)
36 | if mode == 'beam_search':
37 | beam_size_arg = kwargs.get('beam_size', 5)
38 | how_many_outputs_per_beam = kwargs.get('how_many_outputs', 1)
39 | beam_max_seq_len = kwargs.get('beam_max_seq_len', 20)
40 | sample_or_max = kwargs.get('sample_or_max', 'max')
41 | out_classes, out_logprobs = self.beam_search(
42 | enc_x, enc_x_num_pads,
43 | beam_size=beam_size_arg,
44 | sos_idx=sos_idx,
45 | eos_idx=eos_idx,
46 | how_many_outputs=how_many_outputs_per_beam,
47 | max_seq_len=beam_max_seq_len,
48 | sample_or_max=sample_or_max)
49 | return out_classes, out_logprobs
50 | if mode == 'sampling':
51 | how_many_outputs = kwargs.get('how_many_outputs', 1)
52 | sample_max_seq_len = kwargs.get('sample_max_seq_len', 20)
53 | out_classes, out_logprobs = self.get_batch_multiple_sampled_prediction(
54 | enc_x, enc_x_num_pads, num_outputs=how_many_outputs,
55 | sos_idx=sos_idx, eos_idx=eos_idx,
56 | max_seq_len=sample_max_seq_len)
57 | return out_classes, out_logprobs
58 |
59 | def get_batch_multiple_sampled_prediction(self, enc_input, enc_input_num_pads, num_outputs,
60 | sos_idx, eos_idx, max_seq_len):
61 | bs, enc_seq_len, _ = enc_input.shape
62 |
63 | enc_input_num_pads = [enc_input_num_pads[i] for i in range(bs) for _ in range(num_outputs)]
64 |
65 | x = self.forward_enc(enc_input=enc_input, enc_input_num_pads=enc_input_num_pads)
66 | x = x.unsqueeze(1).expand(-1, num_outputs, -1, -1).reshape(bs * num_outputs, enc_seq_len, x.shape[-1])
67 |
68 | upperbound_vector = torch.tensor([max_seq_len] * bs * num_outputs, dtype=torch.int).to(self.rank)
69 | where_is_eos_vector = upperbound_vector.clone()
70 | eos_vector = torch.tensor([eos_idx] * bs * num_outputs, dtype=torch.long).to(self.rank)
71 | finished_flag_vector = torch.zeros(bs * num_outputs).type(torch.int)
72 |
73 | predicted_caption = torch.tensor([sos_idx] * (bs * num_outputs), dtype=torch.long).to(self.rank).unsqueeze(-1)
74 | predicted_caption_prob = torch.zeros(bs * num_outputs).to(self.rank).unsqueeze(-1)
75 |
76 | dec_input_num_pads = [0]*(bs*num_outputs)
77 | time_step = 0
78 | while (finished_flag_vector.sum() != bs * num_outputs) and time_step < max_seq_len:
79 | dec_input = predicted_caption
80 | log_probs = self.forward_dec(x, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=True)
81 |
82 | prob_dist = torch.distributions.Categorical(torch.exp(log_probs[:, time_step]))
83 | sampled_word_indexes = prob_dist.sample()
84 |
85 | predicted_caption = torch.cat((predicted_caption, sampled_word_indexes.unsqueeze(-1)), dim=-1)
86 | predicted_caption_prob = torch.cat((predicted_caption_prob,
87 | log_probs[:, time_step].gather(index=sampled_word_indexes.unsqueeze(-1), dim=-1)), dim=-1)
88 | time_step += 1
89 |
90 | where_is_eos_vector = torch.min(where_is_eos_vector,
91 | upperbound_vector.masked_fill(sampled_word_indexes == eos_vector, time_step))
92 | finished_flag_vector = torch.max(finished_flag_vector,
93 | (sampled_word_indexes == eos_vector).type(torch.IntTensor))
94 |
95 | # remove the elements that come after the first eos from the sequence
96 | res_predicted_caption = []
97 | for i in range(bs):
98 | res_predicted_caption.append([])
99 | for j in range(num_outputs):
100 | index = i*num_outputs + j
101 | res_predicted_caption[i].append(
102 | predicted_caption[index, :where_is_eos_vector[index].item()+1].tolist())
103 |
104 | where_is_eos_vector = where_is_eos_vector.unsqueeze(-1).expand(-1, time_step+1)
105 | arange_tensor = torch.arange(time_step+1).unsqueeze(0).expand(bs * num_outputs, -1).to(self.rank)
106 | predicted_caption_prob.masked_fill_(arange_tensor > where_is_eos_vector, 0.0)
107 | res_predicted_caption_prob = predicted_caption_prob.reshape(bs, num_outputs, -1)
108 |
109 | return res_predicted_caption, res_predicted_caption_prob
110 |
111 | def beam_search(self, enc_input, enc_input_num_pads, sos_idx, eos_idx,
112 | beam_size=3, how_many_outputs=1, max_seq_len=20, sample_or_max='max',):
113 | """
114 |         TO-DO (Maybe?): The code is not very elegant (it could be shorter) nor optimized.
115 |         E.g. caching could be implemented for a slight inference-time improvement.
116 |         However, I think it would become less readable / friendly, so I am not sure it is worth it.
117 | """
118 |         assert (how_many_outputs <= beam_size), "the number of requested outputs per sequence must not exceed the beam width"
119 |         assert (sample_or_max == 'max' or sample_or_max == 'sample'), "argument must be either 'max' or 'sample'"
120 | bs = enc_input.shape[0]
121 |
122 | cross_enc_output = self.forward_enc(enc_input, enc_input_num_pads)
123 |
124 | # init: ------------------------------------------------------------------
125 | init_dec_class = torch.tensor([sos_idx] * bs).unsqueeze(1).type(torch.long).to(self.rank)
126 | init_dec_logprob = torch.tensor([0.0] * bs).unsqueeze(1).type(torch.float).to(self.rank)
127 | log_probs = self.forward_dec(cross_input=cross_enc_output, enc_input_num_pads=enc_input_num_pads,
128 | dec_input=init_dec_class, dec_input_num_pads=[0] * bs,
129 | apply_log_softmax=True)
130 | if sample_or_max == 'max':
131 | _, topi = torch.topk(log_probs, k=beam_size, sorted=True)
132 | else: # sample
133 | topi = torch.exp(log_probs[:, 0, :]).multinomial(num_samples=beam_size, replacement=False)
134 | topi = topi.unsqueeze(1)
135 |
136 | init_dec_class = init_dec_class.repeat(1, beam_size)
137 | init_dec_class = init_dec_class.unsqueeze(-1)
138 | top_beam_size_class = topi.transpose(-2, -1)
139 | init_dec_class = torch.cat((init_dec_class, top_beam_size_class), dim=-1)
140 |
141 | init_dec_logprob = init_dec_logprob.repeat(1, beam_size)
142 | init_dec_logprob = init_dec_logprob.unsqueeze(-1)
143 | top_beam_size_logprob = log_probs.gather(dim=-1, index=topi)
144 | top_beam_size_logprob = top_beam_size_logprob.transpose(-2, -1)
145 | init_dec_logprob = torch.cat((init_dec_logprob, top_beam_size_logprob), dim=-1)
146 |
147 | bs, enc_seq_len, d_model = cross_enc_output.shape
148 | cross_enc_output = cross_enc_output.unsqueeze(1)
149 | cross_enc_output = cross_enc_output.expand(-1, beam_size, -1, -1)
150 | cross_enc_output = cross_enc_output.reshape(bs * beam_size, enc_seq_len, d_model).contiguous()
151 | enc_input_num_pads = [enc_input_num_pads[i] for i in range(bs) for _ in range(beam_size)]
152 |
153 | # loop: -----------------------------------------------------------------
154 | loop_dec_classes = init_dec_class
155 | loop_dec_logprobs = init_dec_logprob
156 | loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True)
157 |
158 | loop_num_elem_vector = torch.tensor([2] * (bs * beam_size)).to(self.rank)
159 |
160 | for time_step in range(2, max_seq_len):
161 | loop_dec_classes = loop_dec_classes.reshape(bs * beam_size, time_step).contiguous()
162 |
163 | log_probs = self.forward_dec(cross_input=cross_enc_output, enc_input_num_pads=enc_input_num_pads,
164 | dec_input=loop_dec_classes,
165 | dec_input_num_pads=(time_step-loop_num_elem_vector).tolist(),
166 | apply_log_softmax=True)
167 | if sample_or_max == 'max':
168 | _, topi = torch.topk(log_probs[:, time_step-1, :], k=beam_size, sorted=True)
169 | else: # sample
170 | topi = torch.exp(log_probs[:, time_step-1, :]).multinomial(num_samples=beam_size,
171 | replacement=False)
172 |
173 | top_beam_size_word_classes = topi.reshape(bs, beam_size, beam_size)
174 |
175 | top_beam_size_word_logprobs = log_probs[:, time_step-1, :].gather(dim=-1, index=topi)
176 | top_beam_size_word_logprobs = top_beam_size_word_logprobs.reshape(bs, beam_size, beam_size)
177 |
178 |             # each sequence now has its best candidates, but some sequences may have already terminated with EOS;
179 |             # in that case their candidates are simply ignored and do not add up in "loop_dec_logprobs":
180 |             # their values are set to zero
181 | there_is_eos_mask = (loop_dec_classes.view(bs, beam_size, time_step) == eos_idx). \
182 | sum(dim=-1, keepdims=True).type(torch.bool)
183 |
184 |             # if we padded its candidate log-probabilities with -999, a sequence that already contains EOS would be
185 |             # discarded right away, whereas we want to keep it in the exploration. Therefore we mask with 0.0
186 |             # one arbitrary candidate word probability, so the sequence probability is unchanged but the sequence
187 |             # can still be discarded when a better candidate sequence is found
188 | top_beam_size_word_logprobs[:, :, 0:1].masked_fill_(there_is_eos_mask, 0.0)
189 | top_beam_size_word_logprobs[:, :, 1:].masked_fill_(there_is_eos_mask, -999.0)
190 |
191 | comparison_logprobs = loop_cumul_logprobs + top_beam_size_word_logprobs
192 |
193 | comparison_logprobs = comparison_logprobs.contiguous().view(bs, beam_size * beam_size)
194 | _, topi = torch.topk(comparison_logprobs, k=beam_size, sorted=True)
195 | which_sequence = topi // beam_size
196 | which_word = topi % beam_size
197 |
198 | loop_dec_classes = loop_dec_classes.view(bs, beam_size, -1)
199 | loop_dec_logprobs = loop_dec_logprobs.view(bs, beam_size, -1)
200 |
201 | bs_idxes = torch.arange(bs).unsqueeze(-1)
202 | new_loop_dec_classes = loop_dec_classes[[bs_idxes, which_sequence]]
203 | new_loop_dec_logprobs = loop_dec_logprobs[[bs_idxes, which_sequence]]
204 |
205 | which_sequence_top_beam_size_word_classes = top_beam_size_word_classes[[bs_idxes, which_sequence]]
206 | which_sequence_top_beam_size_word_logprobs = top_beam_size_word_logprobs[
207 | [bs_idxes, which_sequence]]
208 | which_word = which_word.unsqueeze(-1)
209 |
210 | lastword_top_beam_size_classes = which_sequence_top_beam_size_word_classes.gather(dim=-1,
211 | index=which_word)
212 | lastword_top_beam_size_logprobs = which_sequence_top_beam_size_word_logprobs.gather(dim=-1, index=which_word)
213 |
214 | new_loop_dec_classes = torch.cat((new_loop_dec_classes, lastword_top_beam_size_classes), dim=-1)
215 | new_loop_dec_logprobs = torch.cat((new_loop_dec_logprobs, lastword_top_beam_size_logprobs), dim=-1)
216 | loop_dec_classes = new_loop_dec_classes
217 | loop_dec_logprobs = new_loop_dec_logprobs
218 |
219 | loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True)
220 |
221 | # -----------------------update loop_num_elem_vector ----------------------------
222 | loop_num_elem_vector = loop_num_elem_vector.view(bs, beam_size)[[bs_idxes, which_sequence]].view(bs * beam_size)
223 | there_was_eos_mask = (loop_dec_classes[:, :, :-1].view(bs, beam_size, time_step) == eos_idx). \
224 | sum(dim=-1).type(torch.bool).view(bs * beam_size)
225 | loop_num_elem_vector = loop_num_elem_vector + (1 * (1 - there_was_eos_mask.type(torch.int)))
226 |
227 | if (loop_num_elem_vector != time_step + 1).sum() == (bs * beam_size):
228 | break
229 |
230 | # sort out the best result
231 | loop_cumul_logprobs /= loop_num_elem_vector.reshape(bs, beam_size, 1)
232 | _, topi = torch.topk(loop_cumul_logprobs.squeeze(-1), k=beam_size)
233 | res_caption_pred = [[] for _ in range(bs)]
234 | res_caption_logprob = [[] for _ in range(bs)]
235 | for i in range(bs):
236 | for j in range(how_many_outputs):
237 | idx = topi[i, j].item()
238 | res_caption_pred[i].append(
239 | loop_dec_classes[i, idx, :loop_num_elem_vector[i * beam_size + idx]].tolist())
240 | res_caption_logprob[i].append(loop_dec_logprobs[i, idx, :loop_num_elem_vector[i * beam_size + idx]])
241 |
242 | flatted_res_caption_logprob = [logprobs for i in range(bs) for logprobs in res_caption_logprob[i]]
243 | flatted_res_caption_logprob = torch.nn.utils.rnn.pad_sequence(flatted_res_caption_logprob, batch_first=True)
244 | res_caption_logprob = flatted_res_caption_logprob.view(bs, how_many_outputs, -1)
245 |
246 | return res_caption_pred, res_caption_logprob
247 |
--------------------------------------------------------------------------------
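
Since the concrete models in this repository inherit the dispatch above, generation is normally invoked through forward() with mode='beam_search' or mode='sampling', both of which require sos_idx and eos_idx among the kwargs. A usage sketch, assuming `model` is a CaptioningModel subclass already built (e.g. as in the earlier ExpansionNet_v2 sketch) and `sos_idx` / `eos_idx` come from the training vocabulary:

import torch

enc_x = torch.rand(2, 10, 2048)              # dummy image features, batch of 2

with torch.no_grad():
    # beam search: returns the best captions (token ids) and their per-token log-probs
    captions, logprobs = model(enc_x=enc_x, enc_x_num_pads=[0, 0],
                               mode='beam_search',
                               sos_idx=sos_idx, eos_idx=eos_idx,
                               beam_size=5, how_many_outputs=1,
                               beam_max_seq_len=20, sample_or_max='max')

    # multinomial sampling: `how_many_outputs` stochastic captions per image
    samples, sample_logprobs = model(enc_x=enc_x, enc_x_num_pads=[0, 0],
                                     mode='sampling',
                                     sos_idx=sos_idx, eos_idx=eos_idx,
                                     how_many_outputs=5, sample_max_seq_len=20)
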
/models/ensemble_captioning_model.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | from models.captioning_model import CaptioningModel
5 |
6 |
7 | class EsembleCaptioningModel(CaptioningModel):
8 | def __init__(self, models_list, rank):
9 | super().__init__()
10 | self.num_models = len(models_list)
11 | self.models_list = models_list
12 | self.rank = rank
13 |
14 | self.dummy_linear = nn.Linear(1, 1)
15 |
16 | for model in self.models_list:
17 | model.eval()
18 |
19 | def forward(self, enc_x, dec_x=None,
20 | enc_x_num_pads=[0], dec_x_num_pads=[0], apply_log_softmax=False,
21 | mode='beam_search', **kwargs):
22 | assert (mode == 'beam_search'), "this class supports only beam search."
23 | sos_idx = kwargs.get('sos_idx', -999)
24 | eos_idx = kwargs.get('eos_idx', -999)
25 | if mode == 'beam_search':
26 | beam_size_arg = kwargs.get('beam_size', 5)
27 | how_many_outputs_per_beam = kwargs.get('how_many_outputs', 1)
28 | beam_max_seq_len = kwargs.get('beam_max_seq_len', 20)
29 | sample_or_max = kwargs.get('sample_or_max', 'max')
30 | out_classes, out_logprobs = self.ensemble_beam_search(
31 | enc_x, enc_x_num_pads,
32 | beam_size=beam_size_arg,
33 | sos_idx=sos_idx,
34 | eos_idx=eos_idx,
35 | how_many_outputs=how_many_outputs_per_beam,
36 | max_seq_len=beam_max_seq_len,
37 | sample_or_max=sample_or_max)
38 | return out_classes, out_logprobs
39 |
40 | def forward_enc(self, enc_input, enc_input_num_pads):
41 | x_outputs_list = []
42 | for i in range(self.num_models):
43 | x_outputs = self.models_list[i].forward_enc(enc_input, enc_input_num_pads)
44 | x_outputs_list.append(x_outputs)
45 | return x_outputs_list
46 |
47 | def forward_dec(self, cross_input_list, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False):
48 |
49 | import torch.nn.functional as F
50 | y_outputs = []
51 | for i in range(self.num_models):
52 | y_outputs.append(
53 | F.softmax(self.models_list[i].forward_dec(
54 | cross_input_list[i], enc_input_num_pads,
55 | dec_input, dec_input_num_pads, False).unsqueeze(0), dim=-1))
56 | avg = torch.cat(y_outputs, dim=0).mean(dim=0).log()
57 |
58 | return avg
59 |
60 |     # quite unclean code, to be refactored in the future
61 |     # since it is very similar to the single-model case
62 | def ensemble_beam_search(self, enc_input, enc_input_num_pads, sos_idx, eos_idx,
63 | beam_size=3, how_many_outputs=1, max_seq_len=20, sample_or_max='max',):
64 |         assert (how_many_outputs <= beam_size), "the number of requested outputs per sequence must not exceed the beam width"
65 |         assert (sample_or_max == 'max' or sample_or_max == 'sample'), "argument must be either 'max' or 'sample'"
66 | bs = enc_input.shape[0]
67 |
68 | # the cross_dec_input is computed once
69 | cross_enc_output_list = self.forward_enc(enc_input, enc_input_num_pads)
70 |
71 | # init: ------------------------------------------------------------------
72 | init_dec_class = torch.tensor([sos_idx] * bs).unsqueeze(1).type(torch.long).to(self.rank)
73 | init_dec_logprob = torch.tensor([0.0] * bs).unsqueeze(1).type(torch.float).to(self.rank)
74 | log_probs = self.forward_dec(cross_input_list=cross_enc_output_list, enc_input_num_pads=enc_input_num_pads,
75 | dec_input=init_dec_class, dec_input_num_pads=[0] * bs,
76 | apply_log_softmax=True)
77 | if sample_or_max == 'max':
78 | _, topi = torch.topk(log_probs, k=beam_size, sorted=True)
79 | else: # sample
80 | topi = torch.exp(log_probs[:, 0, :]).multinomial(num_samples=beam_size, replacement=False)
81 | topi = topi.unsqueeze(1)
82 |
83 | init_dec_class = init_dec_class.repeat(1, beam_size)
84 | init_dec_class = init_dec_class.unsqueeze(-1)
85 | top_beam_size_class = topi.transpose(-2, -1)
86 | init_dec_class = torch.cat((init_dec_class, top_beam_size_class), dim=-1)
87 |
88 | init_dec_logprob = init_dec_logprob.repeat(1, beam_size)
89 | init_dec_logprob = init_dec_logprob.unsqueeze(-1)
90 | top_beam_size_logprob = log_probs.gather(dim=-1, index=topi)
91 | top_beam_size_logprob = top_beam_size_logprob.transpose(-2, -1)
92 | init_dec_logprob = torch.cat((init_dec_logprob, top_beam_size_logprob), dim=-1)
93 |
94 | tmp_cross_enc_output_list = []
95 | for cross_enc_output in cross_enc_output_list:
96 | bs, enc_seq_len, d_model = cross_enc_output.shape
97 | cross_enc_output = cross_enc_output.unsqueeze(1)
98 | cross_enc_output = cross_enc_output.expand(-1, beam_size, -1, -1)
99 | cross_enc_output = cross_enc_output.reshape(bs * beam_size, enc_seq_len, d_model).contiguous()
100 | tmp_cross_enc_output_list.append(cross_enc_output)
101 | cross_enc_output_list = tmp_cross_enc_output_list
102 | enc_input_num_pads = [enc_input_num_pads[i] for i in range(bs) for _ in range(beam_size)]
103 |
104 | loop_dec_classes = init_dec_class
105 | loop_dec_logprobs = init_dec_logprob
106 | loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True)
107 |
108 | loop_num_elem_vector = torch.tensor([2] * (bs * beam_size)).to(self.rank)
109 |
110 | for time_step in range(2, max_seq_len):
111 | loop_dec_classes = loop_dec_classes.reshape(bs * beam_size, time_step).contiguous()
112 |
113 | log_probs = self.forward_dec(cross_input_list=cross_enc_output_list, enc_input_num_pads=enc_input_num_pads,
114 | dec_input=loop_dec_classes,
115 | dec_input_num_pads=(time_step-loop_num_elem_vector).tolist(),
116 | apply_log_softmax=True)
117 | if sample_or_max == 'max':
118 | _, topi = torch.topk(log_probs[:, time_step-1, :], k=beam_size, sorted=True)
119 | else: # sample
120 | topi = torch.exp(log_probs[:, time_step-1, :]).multinomial(num_samples=beam_size,
121 | replacement=False)
122 |
123 | top_beam_size_word_classes = topi.reshape(bs, beam_size, beam_size)
124 |
125 | top_beam_size_word_logprobs = log_probs[:, time_step-1, :].gather(dim=-1, index=topi)
126 | top_beam_size_word_logprobs = top_beam_size_word_logprobs.reshape(bs, beam_size, beam_size)
127 |
128 | there_is_eos_mask = (loop_dec_classes.view(bs, beam_size, time_step) == eos_idx). \
129 | sum(dim=-1, keepdims=True).type(torch.bool)
130 |
131 | top_beam_size_word_logprobs[:, :, 0:1].masked_fill_(there_is_eos_mask, 0.0)
132 | top_beam_size_word_logprobs[:, :, 1:].masked_fill_(there_is_eos_mask, -999.0)
133 |
134 | comparison_logprobs = loop_cumul_logprobs + top_beam_size_word_logprobs
135 |
136 | comparison_logprobs = comparison_logprobs.contiguous().view(bs, beam_size * beam_size)
137 | _, topi = torch.topk(comparison_logprobs, k=beam_size, sorted=True)
138 | which_sequence = topi // beam_size
139 | which_word = topi % beam_size
140 |
141 | loop_dec_classes = loop_dec_classes.view(bs, beam_size, -1)
142 | loop_dec_logprobs = loop_dec_logprobs.view(bs, beam_size, -1)
143 |
144 | bs_idxes = torch.arange(bs).unsqueeze(-1)
145 | new_loop_dec_classes = loop_dec_classes[[bs_idxes, which_sequence]]
146 | new_loop_dec_logprobs = loop_dec_logprobs[[bs_idxes, which_sequence]]
147 |
148 | which_sequence_top_beam_size_word_classes = top_beam_size_word_classes[[bs_idxes, which_sequence]]
149 | which_sequence_top_beam_size_word_logprobs = top_beam_size_word_logprobs[
150 | [bs_idxes, which_sequence]]
151 | which_word = which_word.unsqueeze(-1)
152 |
153 | lastword_top_beam_size_classes = which_sequence_top_beam_size_word_classes.gather(dim=-1,
154 | index=which_word)
155 | lastword_top_beam_size_logprobs = which_sequence_top_beam_size_word_logprobs.gather(dim=-1, index=which_word)
156 |
157 | new_loop_dec_classes = torch.cat((new_loop_dec_classes, lastword_top_beam_size_classes), dim=-1)
158 | new_loop_dec_logprobs = torch.cat((new_loop_dec_logprobs, lastword_top_beam_size_logprobs), dim=-1)
159 | loop_dec_classes = new_loop_dec_classes
160 | loop_dec_logprobs = new_loop_dec_logprobs
161 |
162 | loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True)
163 |
164 | loop_num_elem_vector = loop_num_elem_vector.view(bs, beam_size)[[bs_idxes, which_sequence]].view(bs * beam_size)
165 | there_was_eos_mask = (loop_dec_classes[:, :, :-1].view(bs, beam_size, time_step) == eos_idx). \
166 | sum(dim=-1).type(torch.bool).view(bs * beam_size)
167 | loop_num_elem_vector = loop_num_elem_vector + (1 * (1 - there_was_eos_mask.type(torch.int)))
168 |
169 | if (loop_num_elem_vector != time_step + 1).sum() == (bs * beam_size):
170 | break
171 |
172 | loop_cumul_logprobs /= loop_num_elem_vector.reshape(bs, beam_size, 1)
173 | _, topi = torch.topk(loop_cumul_logprobs.squeeze(-1), k=beam_size)
174 | res_caption_pred = [[] for _ in range(bs)]
175 | res_caption_logprob = [[] for _ in range(bs)]
176 | for i in range(bs):
177 | for j in range(how_many_outputs):
178 | idx = topi[i, j].item()
179 | res_caption_pred[i].append(
180 | loop_dec_classes[i, idx, :loop_num_elem_vector[i * beam_size + idx]].tolist())
181 | res_caption_logprob[i].append(loop_dec_logprobs[i, idx, :loop_num_elem_vector[i * beam_size + idx]])
182 |
183 | flatted_res_caption_logprob = [logprobs for i in range(bs) for logprobs in res_caption_logprob[i]]
184 | flatted_res_caption_logprob = torch.nn.utils.rnn.pad_sequence(flatted_res_caption_logprob, batch_first=True)
185 | res_caption_logprob = flatted_res_caption_logprob.view(bs, how_many_outputs, -1)
186 |
187 | return res_caption_pred, res_caption_logprob
188 |
--------------------------------------------------------------------------------
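
The ensemble above simply averages the per-step softmax of every wrapped model and only supports beam search. A hedged usage sketch, where `model_a`, `model_b`, `rank`, `sos_idx` and `eos_idx` are placeholders for already-loaded ExpansionNet_v2 instances (e.g. restored from different checkpoints on the same device) and their vocabulary indices:

import torch
from models.ensemble_captioning_model import EsembleCaptioningModel

ensemble = EsembleCaptioningModel(models_list=[model_a, model_b], rank=rank)

enc_x = torch.rand(1, 10, 2048)              # dummy features
with torch.no_grad():
    captions, logprobs = ensemble(enc_x=enc_x, enc_x_num_pads=[0],
                                  mode='beam_search',
                                  sos_idx=sos_idx, eos_idx=eos_idx,
                                  beam_size=5, how_many_outputs=1,
                                  beam_max_seq_len=20)
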
/models/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 |
5 | import numpy as np
6 | import torch.nn.functional as F
7 |
8 |
9 | class EmbeddingLayer(nn.Module):
10 | def __init__(self, vocab_size, d_model, dropout_perc):
11 | super(EmbeddingLayer, self).__init__()
12 | self.dropout = nn.Dropout(dropout_perc)
13 | self.embed = nn.Embedding(vocab_size, d_model)
14 | self.d_model = d_model
15 |
16 | def forward(self, x):
17 | return self.dropout(self.embed(x)) * math.sqrt(float(self.d_model))
18 |
19 |
20 | class StaticExpansionBlock(nn.Module):
21 | def __init__(self, d_model, num_enc_exp_list, dropout_perc, eps):
22 | super().__init__()
23 | self.d_model = d_model
24 | self.num_enc_exp_list = num_enc_exp_list
25 |
26 | self.query_exp_vectors = nn.Embedding(sum(num_enc_exp_list), d_model)
27 | self.bias_exp_vectors = nn.Embedding(sum(num_enc_exp_list), d_model)
28 |
29 | self.key_embed = nn.Linear(d_model, d_model)
30 | self.class_a_embed = nn.Linear(d_model, d_model)
31 | self.class_b_embed = nn.Linear(d_model, d_model)
32 |
33 | self.selector_embed = nn.Linear(d_model, d_model)
34 |
35 | self.dropout_class_a_fw = nn.Dropout(dropout_perc)
36 | self.dropout_class_b_fw = nn.Dropout(dropout_perc)
37 |
38 | self.dropout_class_a_bw = nn.Dropout(dropout_perc)
39 | self.dropout_class_b_bw = nn.Dropout(dropout_perc)
40 |
41 | self.Z_dropout = nn.Dropout(dropout_perc)
42 |
43 | self.eps = eps
44 |
45 | def forward(self, x, n_indexes, mask):
46 | bs, enc_len, _ = x.shape
47 |
48 | query_exp = self.query_exp_vectors(n_indexes)
49 | bias_exp = self.bias_exp_vectors(n_indexes)
50 | x_key = self.key_embed(x)
51 |
52 | z = torch.matmul(query_exp, x_key.transpose(-1, -2)) / (self.d_model ** 0.5)
53 | z = self.Z_dropout(z)
54 |
55 | class_a_fw = F.relu(z)
56 | class_b_fw = F.relu(-z)
57 | class_a_fw = class_a_fw.masked_fill(mask == 0, 0.0)
58 | class_b_fw = class_b_fw.masked_fill(mask == 0, 0.0)
59 | class_a_fw = class_a_fw / (class_a_fw.sum(dim=-1, keepdim=True) + self.eps)
60 | class_b_fw = class_b_fw / (class_b_fw.sum(dim=-1, keepdim=True) + self.eps)
61 |
62 | class_a = torch.matmul(class_a_fw, self.class_a_embed(x)) + bias_exp
63 | class_b = torch.matmul(class_b_fw, self.class_b_embed(x)) + bias_exp
64 | class_a = self.dropout_class_a_fw(class_a)
65 | class_b = self.dropout_class_b_fw(class_b)
66 |
67 | class_a_bw = F.relu(z.transpose(-2, -1))
68 | class_b_bw = F.relu(-z.transpose(-2, -1))
69 |
70 | accum = 0
71 | class_a_bw_list = []
72 | class_b_bw_list = []
73 | for j in range(len(self.num_enc_exp_list)):
74 | from_idx = accum
75 | to_idx = accum + self.num_enc_exp_list[j]
76 | accum += self.num_enc_exp_list[j]
77 | class_a_bw_list.append(class_a_bw[:, :, from_idx:to_idx] / (class_a_bw[:, :, from_idx:to_idx].sum(dim=-1, keepdim=True) + self.eps))
78 | class_b_bw_list.append(class_b_bw[:, :, from_idx:to_idx] / (class_b_bw[:, :, from_idx:to_idx].sum(dim=-1, keepdim=True) + self.eps))
79 | class_a_bw = torch.cat(class_a_bw_list, dim=-1)
80 | class_b_bw = torch.cat(class_b_bw_list, dim=-1)
81 |
82 | class_a = torch.matmul(class_a_bw, class_a) / len(self.num_enc_exp_list)
83 | class_b = torch.matmul(class_b_bw, class_b) / len(self.num_enc_exp_list)
84 | class_a = self.dropout_class_a_bw(class_a)
85 | class_b = self.dropout_class_b_bw(class_b)
86 |
87 | selector = torch.sigmoid(self.selector_embed(x))
88 | x_result = selector * class_a + (1 - selector) * class_b
89 |
90 | return x_result
91 |
92 |
93 | class EncoderLayer(nn.Module):
94 | def __init__(self, d_model, d_ff, num_enc_exp_list, dropout_perc, eps=1e-9):
95 | super().__init__()
96 | self.norm_1 = nn.LayerNorm(d_model)
97 | self.norm_2 = nn.LayerNorm(d_model)
98 | self.dropout_1 = nn.Dropout(dropout_perc)
99 | self.dropout_2 = nn.Dropout(dropout_perc)
100 |
101 | self.stc_exp = StaticExpansionBlock(d_model, num_enc_exp_list, dropout_perc, eps)
102 | self.ff = FeedForward(d_model, d_ff, dropout_perc)
103 |
104 | def forward(self, x, n_indexes, mask):
105 | x2 = self.norm_1(x)
106 | x = x + self.dropout_1(self.stc_exp(x=x2, n_indexes=n_indexes, mask=mask))
107 | x2 = self.norm_2(x)
108 | x = x + self.dropout_2(self.ff(x2))
109 | return x
110 |
111 |
112 | class DynamicExpansionBlock(nn.Module):
113 | def __init__(self, d_model, num_exp, dropout_perc, eps):
114 | super().__init__()
115 | self.d_model = d_model
116 |
117 | self.num_exp = num_exp
118 | self.cond_embed = nn.Linear(d_model, d_model)
119 |
120 | self.query_exp_vectors = nn.Embedding(self.num_exp, d_model)
121 | self.bias_exp_vectors = nn.Embedding(self.num_exp, d_model)
122 |
123 | self.key_linear = nn.Linear(d_model, d_model)
124 | self.class_a_embed = nn.Linear(d_model, d_model)
125 | self.class_b_embed = nn.Linear(d_model, d_model)
126 |
127 | self.selector_embed = nn.Linear(d_model, d_model)
128 |
129 | self.dropout_class_a_fw = nn.Dropout(dropout_perc)
130 | self.dropout_class_b_fw = nn.Dropout(dropout_perc)
131 | self.dropout_class_a_bw = nn.Dropout(dropout_perc)
132 | self.dropout_class_b_bw = nn.Dropout(dropout_perc)
133 |
134 | self.Z_dropout = nn.Dropout(dropout_perc)
135 |
136 | self.eps = eps
137 |
138 | def forward(self, x, n_indexes, mask):
139 | bs, dec_len, _ = x.shape
140 |
141 | cond = self.cond_embed(x).view(bs, dec_len, 1, self.d_model)
142 | query_exp = self.query_exp_vectors(n_indexes).unsqueeze(1)
143 | bias_exp = self.bias_exp_vectors(n_indexes).unsqueeze(1)
144 | query_exp = (query_exp + cond).view(bs, dec_len * self.num_exp, self.d_model)
145 | bias_exp = (bias_exp + cond).view(bs, dec_len * self.num_exp, self.d_model)
146 |
147 | x_key = self.key_linear(x)
148 | z = torch.matmul(query_exp, x_key.transpose(-1, -2)) / (self.d_model ** 0.5)
149 | z = self.Z_dropout(z)
150 |
151 | mod_mask_1 = mask.unsqueeze(2).expand(bs, dec_len, self.num_exp, dec_len).contiguous(). \
152 | view(bs, dec_len * self.num_exp, dec_len)
153 |
154 | class_a_fw = F.relu(z)
155 | class_b_fw = F.relu(-z)
156 | class_a_fw = class_a_fw.masked_fill(mod_mask_1 == 0, 0.0)
157 | class_b_fw = class_b_fw.masked_fill(mod_mask_1 == 0, 0.0)
158 | class_a_fw = class_a_fw / (class_a_fw.sum(dim=-1, keepdim=True) + self.eps)
159 | class_b_fw = class_b_fw / (class_b_fw.sum(dim=-1, keepdim=True) + self.eps)
160 | class_a = torch.matmul(class_a_fw, self.class_a_embed(x))
161 | class_b = torch.matmul(class_b_fw, self.class_b_embed(x))
162 | class_a = self.dropout_class_a_fw(class_a)
163 | class_b = self.dropout_class_b_fw(class_b)
164 |
165 | mod_mask_2 = mask.unsqueeze(-1).expand(bs, dec_len, dec_len, self.num_exp).contiguous(). \
166 | view(bs, dec_len, dec_len * self.num_exp)
167 |
168 | class_a_bw = F.relu(z.transpose(-2, -1))
169 | class_b_bw = F.relu(-z.transpose(-2, -1))
170 | class_a_bw = class_a_bw.masked_fill(mod_mask_2 == 0, 0.0)
171 | class_b_bw = class_b_bw.masked_fill(mod_mask_2 == 0, 0.0)
172 | class_a_bw = class_a_bw / (class_a_bw.sum(dim=-1, keepdim=True) + self.eps)
173 | class_b_bw = class_b_bw / (class_b_bw.sum(dim=-1, keepdim=True) + self.eps)
174 | class_a = torch.matmul(class_a_bw, class_a + bias_exp)
175 | class_b = torch.matmul(class_b_bw, class_b + bias_exp)
176 | class_a = self.dropout_class_a_bw(class_a)
177 | class_b = self.dropout_class_b_bw(class_b)
178 |
179 | selector = torch.sigmoid(self.selector_embed(x))
180 | x_result = selector * class_a + (1 - selector) * class_b
181 |
182 | return x_result
183 |
184 |
185 | class DecoderLayer(nn.Module):
186 | def __init__(self, d_model, num_heads, d_ff, num_exp, dropout_perc, eps=1e-9):
187 | super().__init__()
188 | self.norm_1 = nn.LayerNorm(d_model)
189 | self.norm_2 = nn.LayerNorm(d_model)
190 | self.norm_3 = nn.LayerNorm(d_model)
191 |
192 | self.dropout_1 = nn.Dropout(dropout_perc)
193 | self.dropout_2 = nn.Dropout(dropout_perc)
194 | self.dropout_3 = nn.Dropout(dropout_perc)
195 |
196 | self.mha = MultiHeadAttention(d_model, num_heads, dropout_perc)
197 | self.dyn_exp = DynamicExpansionBlock(d_model, num_exp, dropout_perc, eps)
198 | self.ff = FeedForward(d_model, d_ff, dropout_perc)
199 |
200 | def forward(self, x, n_indexes, cross_connection_x, input_attention_mask, cross_attention_mask):
201 |
202 | # Pre-LayerNorm
203 | x2 = self.norm_1(x)
204 | x = x + self.dropout_1(self.dyn_exp(x=x2, n_indexes=n_indexes, mask=input_attention_mask))
205 |
206 | x2 = self.norm_2(x)
207 | x = x + self.dropout_2(self.mha(q=x2, k=cross_connection_x, v=cross_connection_x,
208 | mask=cross_attention_mask))
209 |
210 | x2 = self.norm_3(x)
211 | x = x + self.dropout_3(self.ff(x2))
212 | return x
213 |
214 |
215 |
216 | class MultiHeadAttention(nn.Module):
217 | def __init__(self, d_model, num_heads, dropout_perc):
218 | super(MultiHeadAttention, self).__init__()
219 |         assert d_model % num_heads == 0, "d_model must be a multiple of num_heads"
220 |
221 | self.d_model = d_model
222 | self.d_k = int(d_model / num_heads)
223 | self.num_heads = num_heads
224 |
225 | self.Wq = nn.Linear(d_model, self.d_k * num_heads)
226 | self.Wk = nn.Linear(d_model, self.d_k * num_heads)
227 | self.Wv = nn.Linear(d_model, self.d_k * num_heads)
228 |
229 | self.out_linear = nn.Linear(d_model, d_model)
230 |
231 | def forward(self, q, k, v, mask=None):
232 | batch_size, q_seq_len, _ = q.shape
233 | k_seq_len = k.size(1)
234 | v_seq_len = v.size(1)
235 |
236 | k_proj = self.Wk(k).view(batch_size, k_seq_len, self.num_heads, self.d_k)
237 | q_proj = self.Wq(q).view(batch_size, q_seq_len, self.num_heads, self.d_k)
238 | v_proj = self.Wv(v).view(batch_size, v_seq_len, self.num_heads, self.d_k)
239 |
240 | k_proj = k_proj.transpose(2, 1)
241 | q_proj = q_proj.transpose(2, 1)
242 | v_proj = v_proj.transpose(2, 1)
243 |
244 | sim_scores = torch.matmul(q_proj, k_proj.transpose(3, 2))
245 | sim_scores = sim_scores / self.d_k ** 0.5
246 |
247 | if mask is not None:
248 | mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
249 | sim_scores = sim_scores.masked_fill(mask == 0, value=-1e4)
250 | sim_scores = F.softmax(input=sim_scores, dim=-1)
251 |
252 | attention_applied = torch.matmul(sim_scores, v_proj)
253 | attention_applied_concatenated = attention_applied.permute(0, 2, 1, 3).contiguous()\
254 | .view(batch_size, q_seq_len, self.d_model)
255 |
256 | out = self.out_linear(attention_applied_concatenated)
257 | return out
258 |
259 |
260 | class FeedForward(nn.Module):
261 | def __init__(self, d_model, d_ff, dropout_perc):
262 | super(FeedForward, self).__init__()
263 | self.linear_1 = nn.Linear(d_model, d_ff)
264 | self.dropout = nn.Dropout(dropout_perc)
265 | self.linear_2 = nn.Linear(d_ff, d_model)
266 |
267 | def forward(self, x):
268 | x = self.dropout(F.relu(self.linear_1(x)))
269 | x = self.linear_2(x)
270 | return x
271 |
--------------------------------------------------------------------------------
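
A quick shape check for the building blocks above: MultiHeadAttention consumes (batch, length, d_model) tensors plus an optional mask in which 0 marks the positions to be ignored. A minimal sketch, assuming the repository root is on PYTHONPATH:

import torch
from models.layers import MultiHeadAttention, FeedForward

mha = MultiHeadAttention(d_model=512, num_heads=8, dropout_perc=0.0)
q = torch.rand(2, 7, 512)            # e.g. decoder states
kv = torch.rand(2, 10, 512)          # e.g. encoder output
mask = torch.ones(2, 7, 10)          # 0 would mark positions to be masked out
out = mha(q=q, k=kv, v=kv, mask=mask)
print(out.shape)                     # torch.Size([2, 7, 512])

ff = FeedForward(d_model=512, d_ff=2048, dropout_perc=0.0)
print(ff(out).shape)                 # torch.Size([2, 7, 512])
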
/onnx4tensorrt/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This package provides the scripts and implementations that
3 | support ONNX conversion and deployment on TensorRT.
4 | 
5 | It kinda breaks the "DRY" principle, but the model and backbone need some tweaks
6 | and more attention than their pure PyTorch counterparts. Since issues related to the
7 | particular contexts of ONNX and TensorRT may arise in the future, it is safer
8 | to keep the two versions separate for the moment.
9 | """
--------------------------------------------------------------------------------
/onnx4tensorrt/convert2onnx.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | import argparse
5 | import pickle
6 | import copy
7 | import onnx
8 | from argparse import Namespace
9 |
10 |
11 | import functools
12 | print = functools.partial(print, flush=True)
13 |
14 | from utils.saving_utils import partially_load_state_dict
15 | from utils.image_utils import preprocess_image
16 | from utils.language_utils import tokens2description
17 | from onnx4tensorrt.End_ExpansionNet_v2_onnx_tensorrt import create_pad_mask, create_no_peak_and_pad_mask
18 | from onnx4tensorrt.End_ExpansionNet_v2_onnx_tensorrt import End_ExpansionNet_v2_ONNX_TensorRT, \
19 | NUM_FEATURES, MAX_DECODE_STEPS
20 |
21 |
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser(description='Conversion PyTorch to ONNX')
24 | parser.add_argument('--model_dim', type=int, default=512)
25 | parser.add_argument('--N_enc', type=int, default=3)
26 | parser.add_argument('--N_dec', type=int, default=3)
27 | parser.add_argument('--vocab_path', type=str, default='./demo_material/demo_coco_tokens.pickle')
28 | parser.add_argument('--max_seq_len', type=int, default=74)
29 | parser.add_argument('--image_path_1', type=str, default='./demo_material/tatin.jpg')
30 | parser.add_argument('--image_path_2', type=str, default='./demo_material/micheal.jpg')
31 | parser.add_argument('--load_model_path', type=str, default='./github_ignore_material/saves/rf_model.pth')
32 | parser.add_argument('--output_onnx_path', type=str, default='./rf_model.onnx')
33 | parser.add_argument('--onnx_simplify', type=bool, default=False)
34 | parser.add_argument('--onnx_runtime_test', type=bool, default=False)
35 | parser.add_argument('--onnx_tensorrt_test', type=bool, default=False)
36 | parser.add_argument('--max_worker_size', type=int, default=10000)
37 | parser.add_argument('--onnx_opset', type=int, default=14)
38 | # parser.add_argument('--beam_size', type=int, default=5)
39 |
40 | args = parser.parse_args()
41 |
42 | with open(args.vocab_path, 'rb') as f:
43 | coco_tokens = pickle.load(f)
44 | sos_idx = coco_tokens['word2idx_dict'][coco_tokens['sos_str']]
45 | eos_idx = coco_tokens['word2idx_dict'][coco_tokens['eos_str']]
46 |
47 | # test the generalization of the graph using two images
48 | img_size = 384
49 | image_1 = preprocess_image(args.image_path_1, img_size)
50 | image_2 = preprocess_image(args.image_path_2, img_size)
51 |
52 | drop_args = Namespace(enc=0.0,
53 | dec=0.0,
54 | enc_input=0.0,
55 | dec_input=0.0,
56 | other=0.0)
57 | enc_exp_list = [32, 64, 128, 256, 512]
58 | dec_exp = 16
59 | num_heads = 8
60 | model = End_ExpansionNet_v2_ONNX_TensorRT(
61 | swin_img_size=img_size, swin_patch_size=4, swin_in_chans=3,
62 | swin_embed_dim=192, swin_depths=[2, 2, 18, 2], swin_num_heads=[6, 12, 24, 48],
63 | swin_window_size=12, swin_mlp_ratio=4., swin_qkv_bias=True, swin_qk_scale=None,
64 | swin_drop_rate=0.0, swin_attn_drop_rate=0.0, swin_drop_path_rate=0.0,
65 | swin_norm_layer=torch.nn.LayerNorm, swin_patch_norm=True,
66 |
67 | d_model=args.model_dim, N_enc=args.N_enc,
68 | N_dec=args.N_dec, num_heads=num_heads, ff=2048,
69 | num_exp_enc_list=enc_exp_list,
70 | num_exp_dec=16,
71 | output_word2idx=coco_tokens['word2idx_dict'],
72 | output_idx2word=coco_tokens['idx2word_list'],
73 | max_seq_len=args.max_seq_len, drop_args=drop_args)
74 |
75 | print("Loading model...")
76 | checkpoint = torch.load(args.load_model_path)
77 | partially_load_state_dict(model, checkpoint['model_state_dict'])
78 |
79 | print("===============================================")
80 | print("|| ONNX conversion ||")
81 | print("===============================================")
82 |
83 | # Masks creation - - - - -
84 | batch_size = 1
85 | enc_mask = create_pad_mask(mask_size=(batch_size, sum(enc_exp_list), NUM_FEATURES),
86 | pad_row=[0], pad_column=[0]).contiguous()
87 | no_peak_mask = create_no_peak_and_pad_mask(mask_size=(batch_size, MAX_DECODE_STEPS, MAX_DECODE_STEPS),
88 | num_pads=[0]).contiguous()
89 | cross_mask = create_pad_mask(mask_size=(batch_size, MAX_DECODE_STEPS, NUM_FEATURES),
90 | pad_row=[0], pad_column=[0]).contiguous()
91 |     # unlike the other masks, here we put 1 at the positions to be masked
92 | cross_mask = 1 - cross_mask
93 |
94 | fw_dec_mask = no_peak_mask.unsqueeze(2).expand(batch_size, MAX_DECODE_STEPS,
95 | dec_exp, MAX_DECODE_STEPS).contiguous(). \
96 | view(batch_size, MAX_DECODE_STEPS * dec_exp, MAX_DECODE_STEPS)
97 |
98 | bw_dec_mask = no_peak_mask.unsqueeze(-1).expand(batch_size,
99 | MAX_DECODE_STEPS, MAX_DECODE_STEPS, dec_exp).contiguous(). \
100 | view(batch_size, MAX_DECODE_STEPS, MAX_DECODE_STEPS * dec_exp)
101 |
102 | atten_mask = cross_mask.unsqueeze(1).repeat(1, num_heads, 1, 1)
103 |
104 | # - - - - - - - - - -
105 |
106 | print("Exporting...")
107 | model.eval()
108 | my_script = torch.jit.script(model)
109 | torch.onnx.export(
110 | my_script,
111 | (image_1, torch.tensor([sos_idx]), enc_mask, fw_dec_mask, bw_dec_mask, atten_mask),
112 | args.output_onnx_path,
113 | input_names=['enc_x', 'sos_idx', 'enc_mask', 'fw_dec_mask', 'bw_dec_mask', 'cross_mask'],
114 | output_names=['pred', 'logprobs'],
115 | export_params=True,
116 | opset_version=args.onnx_opset)
117 | print("ONNX graph conversion done. Destination: " + args.output_onnx_path)
118 | onnx_model_fp32 = onnx.load(args.output_onnx_path)
119 | onnx.checker.check_model(onnx_model_fp32)
120 | print("ONNX graph checked.")
121 |
122 |     # TO-DO: FP16 code was written but does not work for reasons yet unknown;
123 |     # tests were made on a model trained in FP32.
124 |     # What if the model were trained in FP16 instead?
125 |
126 | # from onnxconverter_common import float16
127 | # onnx_model_fp16 = float16.convert_float_to_float16(onnx_model_fp32, op_block_list=["Topk", "Normalizer"])
128 | # onnx.save(onnx_model_fp16, args.output_onnx_path + '_fp16')
129 | # print("ONNX graph FP16 version done. Destination: " + args.output_onnx_path + '_fp16')
130 |
131 | if args.onnx_simplify:
132 | print("===============================================")
133 | print("|| ONNX graph simplifcation phase ||")
134 | print("===============================================")
135 | from onnxsim import simplify
136 | onnx_model_fp32 = onnx.load(args.output_onnx_path)
137 | # onnx_model_fp16 = onnx.load(args.output_onnx_path + '_fp16')
138 | try:
139 | simplified_onnx_model_fp32, check_fp32 = simplify(onnx_model_fp32)
140 | # simplified_onnx_model_fp16, check_fp16 = simplify(onnx_model_fp16)
141 |         except Exception:
142 |             raise RuntimeError("The simplification failed. In this case, we suggest trying the command-line version.")
143 | assert check_fp32, "Simplified fp32 ONNX model could not be validated"
144 | # assert check_fp16, "Simplified fp16 ONNX model could not be validated"
145 | onnx.save(simplified_onnx_model_fp32, args.output_onnx_path)
146 | # onnx.save(simplified_onnx_model_fp16, args.output_onnx_path + '_fp16')
147 |
148 |     if args.onnx_runtime_test:
149 | import onnxruntime as ort
150 | print("===============================================")
151 | print("|| Testing on ONNX Runtime ||")
152 | print("===============================================")
153 |
154 | ort_sess = ort.InferenceSession(args.output_onnx_path)
155 | input_dict_1 = {'enc_x': image_1.numpy(), 'sos_idx': np.array([sos_idx]),
156 | 'enc_mask': enc_mask.numpy(), 'fw_dec_mask': fw_dec_mask.numpy(),
157 | 'bw_dec_mask': bw_dec_mask.numpy(), 'cross_mask': atten_mask.numpy()}
158 | outputs_ort = ort_sess.run(None, input_dict_1)
159 | output_caption = tokens2description(outputs_ort[0][0].tolist(), coco_tokens['idx2word_list'], sos_idx, eos_idx)
160 | print("ONNX Runtime result on 1st image:\n\t\t" + output_caption)
161 |
162 | input_dict_2 = copy.copy(input_dict_1)
163 | input_dict_2['enc_x'] = image_2.numpy()
164 | outputs_ort = ort_sess.run(None, input_dict_2)
165 | output_caption = tokens2description(outputs_ort[0][0].tolist(), coco_tokens['idx2word_list'], sos_idx, eos_idx)
166 | print("ONNX Runtime result on 2nd image:\n\t\t" + output_caption)
167 | print("Done.", end="\n\n")
168 |
169 | if args.onnx_tensorrt_test:
170 | import onnx_tensorrt.backend as backend
171 | print("===============================================")
172 | print("|| Testing on ONNX-TensorRT backend ||")
173 | print("===============================================")
174 |
175 |         engine = backend.prepare(onnx_model_fp32, device='CUDA:0', max_worker_size=args.max_worker_size)
176 |
177 | input_data = [image_1.numpy(), np.array([sos_idx]),
178 | enc_mask.numpy(), fw_dec_mask.numpy(), bw_dec_mask.numpy(), atten_mask.numpy()]
179 | output_data = engine.run(input_data)[0][0]
180 | output_caption = tokens2description(output_data.tolist(), coco_tokens['idx2word_list'], sos_idx, eos_idx)
181 | print("TensorRT result on 1st image:\n\t\t" + output_caption)
182 |
183 | input_data[0] = image_2.numpy()
184 | output_data = engine.run(input_data)[0][0]
185 | output_caption = tokens2description(output_data.tolist(), coco_tokens['idx2word_list'], sos_idx, eos_idx)
186 | print("TensorRT result on 2nd image:\n\t\t" + output_caption)
187 | print("Done.", end="\n\n")
188 |
189 | print("Closing.")
190 |
191 |
--------------------------------------------------------------------------------
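
Once the export above has produced an .onnx file, it can be worth confirming that the graph exposes the input and output names passed to torch.onnx.export before moving on to TensorRT. A small sketch using the onnx package; the path below is the script's default --output_onnx_path and may differ in practice:

import onnx

onnx_model = onnx.load('./rf_model.onnx')
onnx.checker.check_model(onnx_model)

print([i.name for i in onnx_model.graph.input])    # expected: enc_x, sos_idx, enc_mask, ...
print([o.name for o in onnx_model.graph.output])   # expected: pred, logprobs
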
/onnx4tensorrt/onnx2tensorrt.py:
--------------------------------------------------------------------------------
1 | # Script for ONNX to TensorRT conversion
2 | # credits to: Shakhizat Nurgaliyev (https://github.com/shahizat)
3 |
4 | import os
5 | import torch
6 | import argparse
7 | import numpy as np
8 | import pickle
9 | import tensorrt as trt
10 | import pycuda.driver as cuda
11 | import pycuda.autoinit # this is important
12 |
13 | from utils.args_utils import str2type
14 | from utils.language_utils import tokens2description
15 | from utils.image_utils import preprocess_image
16 |
17 | from onnx4tensorrt.End_ExpansionNet_v2_onnx_tensorrt import create_pad_mask, create_no_peak_and_pad_mask
18 | from onnx4tensorrt.End_ExpansionNet_v2_onnx_tensorrt import NUM_FEATURES, MAX_DECODE_STEPS
19 |
20 |
21 | if __name__ == "__main__":
22 | parser = argparse.ArgumentParser(description='ONNX 2 TensorRT')
23 | parser.add_argument('--onnx_path', type=str, default='./rf_model.onnx')
24 | parser.add_argument('--engine_path', type=str, default='./model_engine.trt')
25 | parser.add_argument('--data_type', type=str2type, default='fp32')
26 | args = parser.parse_args()
27 |
28 | with open('./demo_material/demo_coco_tokens.pickle', 'rb') as f:
29 | coco_tokens = pickle.load(f)
30 | sos_idx = coco_tokens['word2idx_dict'][coco_tokens['sos_str']]
31 | eos_idx = coco_tokens['word2idx_dict'][coco_tokens['eos_str']]
32 |
33 | # Build TensorRT engine
34 | TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
35 | trt_runtime = trt.Runtime(TRT_LOGGER)
36 |
37 | # The Onnx path is used for Onnx models.
38 | def build_engine_onnx(onnx_file_path, data_type):
39 | builder = trt.Builder(TRT_LOGGER)
40 | network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
41 | config = builder.create_builder_config()
42 | parser = trt.OnnxParser(network, TRT_LOGGER)
43 |
44 | builder.max_batch_size = 1
45 |         config.max_workspace_size = 1 << 32  # in bytes, so 2^32 -> 4 GB
46 | # Load the Onnx model and parse it in order to populate the TensorRT network.
47 |
48 | if data_type == 'fp32':
49 | pass # do nothing, is the default
50 | elif data_type == 'fp16':
51 | if builder.platform_has_fast_fp16:
52 | config.set_flag(trt.BuilderFlag.FP16)
53 | print("fp16 is supported. Setting up fp16.")
54 | else:
55 |                 print("fp16 is not supported. Using fp32 instead.")
56 | else:
57 |             raise ValueError("Unsupported type. Only the following types are supported: " +
58 |                              "fp32, fp16.")
59 |
60 | with open(onnx_file_path, "rb") as onnx_file:
61 | if not parser.parse(onnx_file.read()):
62 | print("ERROR: Failed to parse the ONNX file.")
63 | for error in range(parser.num_errors):
64 | print(parser.get_error(error))
65 | return None
66 |
67 | return builder.build_engine(network, config)
68 |
69 | img_size = 384
70 | image = preprocess_image('./demo_material/napoleon.jpg', img_size)
71 |
72 | # generate optimized graph
73 | file_already_exist = os.path.isfile(args.engine_path)
74 | if file_already_exist:
75 | print("Engine File:" + str(args.engine_path) + " already exists, loading engine instead of ONNX file.")
76 | with open(args.engine_path, "rb") as f:
77 | engine_data = f.read()
78 | engine = trt_runtime.deserialize_cuda_engine(engine_data)
79 | else:
80 | print("Building TensorRT engine from ONNX")
81 | if args.data_type == 'fp32':
82 | engine = build_engine_onnx(args.onnx_path, args.data_type)
83 | elif args.data_type == 'fp16':
84 | engine = build_engine_onnx(args.onnx_path + '_fp16', args.data_type)
85 | with open(args.engine_path, "wb") as f:
86 | f.write(engine.serialize())
87 | print("Engine written in: " + str(args.engine_path))
88 |
89 | print("Finished Building.")
90 | # engine = build_engine('./trt_fp.engine')
91 | context = engine.create_execution_context()
92 | print("Created execution context.")
93 |
94 | print("Testing first image on TensorRT")
95 | batch_size = 1
96 | inputs, outputs, bindings, stream = [], [], [], cuda.Stream()
97 | for binding in engine:
98 | size = trt.volume(engine.get_binding_shape(binding)) * batch_size
99 | dtype = trt.nptype(engine.get_binding_dtype(binding))
100 | # Allocate host and device buffers
101 | host_mem = cuda.pagelocked_empty(size, dtype)
102 | device_mem = cuda.mem_alloc(host_mem.nbytes)
103 | # Append the device buffer to device bindings.
104 | bindings.append(int(device_mem))
105 | # Append to the appropriate list.
106 | print("Binding type: " + str(dtype) + " is input: " + str(engine.binding_is_input(binding)))
107 | if engine.binding_is_input(binding):
108 | inputs.append({'host': host_mem, 'device': device_mem})
109 | else:
110 | outputs.append({'host': host_mem, 'device': device_mem})
111 |
112 | # Masking creation - - - - -
113 | enc_exp_list = [32, 64, 128, 256, 512]
114 | dec_exp = 16
115 | num_heads = 8
116 | enc_mask = create_pad_mask(mask_size=(batch_size, sum(enc_exp_list), NUM_FEATURES),
117 | pad_row=[0], pad_column=[0]).contiguous()
118 | no_peak_mask = create_no_peak_and_pad_mask(mask_size=(batch_size, MAX_DECODE_STEPS, MAX_DECODE_STEPS),
119 | num_pads=[0]).contiguous()
120 | cross_mask = create_pad_mask(mask_size=(batch_size, MAX_DECODE_STEPS, NUM_FEATURES),
121 | pad_row=[0], pad_column=[0]).contiguous()
122 | cross_mask = 1 - cross_mask
123 |
124 | fw_dec_mask = no_peak_mask.unsqueeze(2).expand(batch_size, MAX_DECODE_STEPS,
125 | dec_exp, MAX_DECODE_STEPS).contiguous(). \
126 | view(batch_size, MAX_DECODE_STEPS * dec_exp, MAX_DECODE_STEPS)
127 |
128 | bw_dec_mask = no_peak_mask.unsqueeze(-1).expand(batch_size,
129 | MAX_DECODE_STEPS, MAX_DECODE_STEPS,
130 | dec_exp).contiguous(). \
131 | view(batch_size, MAX_DECODE_STEPS, MAX_DECODE_STEPS * dec_exp)
132 |
133 | atten_mask = cross_mask.unsqueeze(1).repeat(1, num_heads, 1, 1)
134 |
135 | # - - - - - - - - - -
136 |
137 | # Set input values
138 | if args.data_type == 'fp32':
139 | inputs[0]['host'] = np.ravel(image).astype(np.float32)
140 | inputs[1]['host'] = np.array([sos_idx]).astype(np.int32)
141 | inputs[2]['host'] = np.array(enc_mask).astype(np.int32)
142 | inputs[3]['host'] = np.array(fw_dec_mask).astype(np.int32)
143 | inputs[4]['host'] = np.array(bw_dec_mask).astype(np.int32)
144 | inputs[5]['host'] = np.array(atten_mask).astype(np.int32)
145 | elif args.data_type == 'fp16':
146 | inputs[0]['host'] = np.ravel(image).astype(np.float16)
147 | inputs[1]['host'] = np.array([sos_idx]).astype(np.int32)
148 | inputs[2]['host'] = np.array(enc_mask).astype(np.int32)
149 | inputs[3]['host'] = np.array(fw_dec_mask).astype(np.float16)
150 | inputs[4]['host'] = np.array(bw_dec_mask).astype(np.float16)
151 | inputs[5]['host'] = np.array(atten_mask).astype(np.int32)
152 |
153 | # Transfer input data to the GPU.
154 | for inp in inputs:
155 | cuda.memcpy_htod_async(inp['device'], inp['host'], stream)
156 | # Execute model
157 | context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)
158 | # Transfer predictions back from the GPU.
159 | for out in outputs:
160 | cuda.memcpy_dtoh_async(out['host'], out['device'], stream)
161 | # Synchronize the stream
162 | stream.synchronize()
163 | print(outputs[0]['host'].tolist())
164 | output_caption = tokens2description(outputs[0]['host'].tolist(), coco_tokens['idx2word_list'], sos_idx, eos_idx)
165 | output_probs = outputs[1]['host'].tolist()
166 | print(output_caption)
167 | print(output_probs)
168 |
--------------------------------------------------------------------------------
/optims/radam.py:
--------------------------------------------------------------------------------
1 | # Credits to Liyuan Liu
2 | # https://github.com/LiyuanLucasLiu/RAdam/blob/master/radam/radam.py
3 |
4 | import math
5 | import torch
6 | from torch.optim.optimizer import Optimizer, required
7 |
8 |
9 | class RAdam(Optimizer):
10 |
11 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False):
12 | if not 0.0 <= lr:
13 | raise ValueError("Invalid learning rate: {}".format(lr))
14 | if not 0.0 <= eps:
15 | raise ValueError("Invalid epsilon value: {}".format(eps))
16 | if not 0.0 <= betas[0] < 1.0:
17 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
18 | if not 0.0 <= betas[1] < 1.0:
19 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
20 |
21 | self.degenerated_to_sgd = degenerated_to_sgd
22 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
23 | for param in params:
24 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
25 | param['buffer'] = [[None, None, None] for _ in range(10)]
26 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
27 | buffer=[[None, None, None] for _ in range(10)])
28 | super(RAdam, self).__init__(params, defaults)
29 |
30 | def __setstate__(self, state):
31 | super(RAdam, self).__setstate__(state)
32 |
33 | def step(self, closure=None):
34 |
35 | loss = None
36 | if closure is not None:
37 | loss = closure()
38 |
39 | for group in self.param_groups:
40 |
41 | for p in group['params']:
42 | if p.grad is None:
43 | continue
44 | grad = p.grad.data.float()
45 | if grad.is_sparse:
46 | raise RuntimeError('RAdam does not support sparse gradients')
47 |
48 | p_data_fp32 = p.data.float()
49 |
50 | state = self.state[p]
51 |
52 | if len(state) == 0:
53 | state['step'] = 0
54 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
55 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
56 | else:
57 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
58 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
59 |
60 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
61 | beta1, beta2 = group['betas']
62 |
63 | exp_avg_sq.mul_(beta2).addcmul_(value=1 - beta2, tensor1=grad, tensor2=grad)
64 | exp_avg.mul_(beta1).add_(alpha=1 - beta1, other=grad)
65 |
66 | state['step'] += 1
67 | buffered = group['buffer'][int(state['step'] % 10)]
68 | if state['step'] == buffered[0]:
69 | N_sma, step_size = buffered[1], buffered[2]
70 | else:
71 | buffered[0] = state['step']
72 | beta2_t = beta2 ** state['step']
73 | N_sma_max = 2 / (1 - beta2) - 1
74 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
75 | buffered[1] = N_sma
76 |
77 | # more conservative since it's an approximated value
78 | if N_sma >= 5:
79 | step_size = math.sqrt(
80 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
81 | N_sma_max - 2)) / (1 - beta1 ** state['step'])
82 | elif self.degenerated_to_sgd:
83 | step_size = 1.0 / (1 - beta1 ** state['step'])
84 | else:
85 | step_size = -1
86 | buffered[2] = step_size
87 |
88 | # more conservative since it's an approximated value
89 | if N_sma >= 5:
90 | if group['weight_decay'] != 0:
91 | p_data_fp32.add_(alpha=-group['weight_decay'] * group['lr'], other=p_data_fp32)
92 | denom = exp_avg_sq.sqrt().add_(group['eps'])
93 | p_data_fp32.addcdiv_(value=-step_size * group['lr'], tensor1=exp_avg, tensor2=denom)
94 | p.data.copy_(p_data_fp32)
95 | elif step_size > 0:
96 | if group['weight_decay'] != 0:
97 | p_data_fp32.add_(alpha=-group['weight_decay'] * group['lr'], other=p_data_fp32)
98 | p_data_fp32.add_(alpha=-step_size * group['lr'], other=exp_avg)
99 | p.data.copy_(p_data_fp32)
100 |
101 | return loss
102 |
103 |
104 | class PlainRAdam(Optimizer):
105 |
106 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False):
107 | if not 0.0 <= lr:
108 | raise ValueError("Invalid learning rate: {}".format(lr))
109 | if not 0.0 <= eps:
110 | raise ValueError("Invalid epsilon value: {}".format(eps))
111 | if not 0.0 <= betas[0] < 1.0:
112 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
113 | if not 0.0 <= betas[1] < 1.0:
114 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
115 |
116 | self.degenerated_to_sgd = degenerated_to_sgd
117 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
118 |
119 | super(PlainRAdam, self).__init__(params, defaults)
120 |
121 | def __setstate__(self, state):
122 | super(PlainRAdam, self).__setstate__(state)
123 |
124 | def step(self, closure=None):
125 |
126 | loss = None
127 | if closure is not None:
128 | loss = closure()
129 |
130 | for group in self.param_groups:
131 |
132 | for p in group['params']:
133 | if p.grad is None:
134 | continue
135 | grad = p.grad.data.float()
136 | if grad.is_sparse:
137 | raise RuntimeError('RAdam does not support sparse gradients')
138 |
139 | p_data_fp32 = p.data.float()
140 |
141 | state = self.state[p]
142 |
143 | if len(state) == 0:
144 | state['step'] = 0
145 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
146 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
147 | else:
148 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
149 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
150 |
151 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
152 | beta1, beta2 = group['betas']
153 |
154 | exp_avg_sq.mul_(beta2).addcmul_(value=1 - beta2, tensor1=grad, tensor2=grad)
155 | exp_avg.mul_(beta1).add_(alpha=1 - beta1, other=grad)
156 |
157 | state['step'] += 1
158 | beta2_t = beta2 ** state['step']
159 | N_sma_max = 2 / (1 - beta2) - 1
160 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
161 |
162 | # more conservative since it's an approximated value
163 | if N_sma >= 5:
164 | if group['weight_decay'] != 0:
165 | p_data_fp32.add_(alpha=-group['weight_decay'] * group['lr'], other=p_data_fp32)
166 | step_size = group['lr'] * math.sqrt(
167 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
168 | N_sma_max - 2)) / (1 - beta1 ** state['step'])
169 | denom = exp_avg_sq.sqrt().add_(group['eps'])
170 | p_data_fp32.addcdiv_(value=-step_size, tensor1=exp_avg, tensor2=denom)
171 | p.data.copy_(p_data_fp32)
172 | elif self.degenerated_to_sgd:
173 | if group['weight_decay'] != 0:
174 | p_data_fp32.add_(alpha=-group['weight_decay'] * group['lr'], other=p_data_fp32)
175 | step_size = group['lr'] / (1 - beta1 ** state['step'])
176 | p_data_fp32.add_(alpha=-step_size, other=exp_avg)
177 | p.data.copy_(p_data_fp32)
178 |
179 | return loss
180 |
181 |
182 | class AdamW(Optimizer):
183 |
184 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0):
185 | if not 0.0 <= lr:
186 | raise ValueError("Invalid learning rate: {}".format(lr))
187 | if not 0.0 <= eps:
188 | raise ValueError("Invalid epsilon value: {}".format(eps))
189 | if not 0.0 <= betas[0] < 1.0:
190 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
191 | if not 0.0 <= betas[1] < 1.0:
192 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
193 |
194 | defaults = dict(lr=lr, betas=betas, eps=eps,
195 | weight_decay=weight_decay, warmup=warmup)
196 | super(AdamW, self).__init__(params, defaults)
197 |
198 | def __setstate__(self, state):
199 | super(AdamW, self).__setstate__(state)
200 |
201 | def step(self, closure=None):
202 | loss = None
203 | if closure is not None:
204 | loss = closure()
205 |
206 | for group in self.param_groups:
207 |
208 | for p in group['params']:
209 | if p.grad is None:
210 | continue
211 | grad = p.grad.data.float()
212 | if grad.is_sparse:
213 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
214 |
215 | p_data_fp32 = p.data.float()
216 |
217 | state = self.state[p]
218 |
219 | if len(state) == 0:
220 | state['step'] = 0
221 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
222 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
223 | else:
224 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
225 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
226 |
227 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
228 | beta1, beta2 = group['betas']
229 |
230 | state['step'] += 1
231 |
232 | exp_avg_sq.mul_(beta2).addcmul_(value=1 - beta2, tensor1=grad, tensor2=grad)
233 | exp_avg.mul_(beta1).add_(alpha=1 - beta1, other=grad)
234 |
235 | denom = exp_avg_sq.sqrt().add_(group['eps'])
236 | bias_correction1 = 1 - beta1 ** state['step']
237 | bias_correction2 = 1 - beta2 ** state['step']
238 |
239 | if group['warmup'] > state['step']:
240 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
241 | else:
242 | scheduled_lr = group['lr']
243 |
244 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1
245 |
246 | if group['weight_decay'] != 0:
247 | p_data_fp32.add_(alpha=-group['weight_decay'] * scheduled_lr, other=p_data_fp32)
248 |
249 | p_data_fp32.addcdiv_(value=-step_size, tensor1=exp_avg, tensor2=denom)
250 |
251 | p.data.copy_(p_data_fp32)
252 |
253 | return loss
254 |
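255 |
256 | # Usage sketch (illustrative; the hyper-parameters are placeholders and `model`
257 | # stands for any torch.nn.Module): all three classes follow the standard
258 | # torch.optim interface.
259 | #
260 | #   optimizer = RAdam(model.parameters(), lr=1e-4, weight_decay=0.0)
261 | #   optimizer.zero_grad()
262 | #   loss.backward()
263 | #   optimizer.step()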
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # NOTE: Installing whatever versions of torch, torchvision, h5py and Pillow
2 | # fit your machine should work in most cases. One possible working configuration
3 | # is reported below, but these specific versions are not mandatory.
4 |
5 | h5py==3.8.0
6 | numpy==1.21.6
7 | Pillow==10.2.0
8 | torch==1.12.1+cu113
9 | torchvision==0.13.1+cu113
10 |
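11 | # Illustrative install command, assuming the official PyTorch cu113 wheel index:
12 | #   pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113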
--------------------------------------------------------------------------------
/requirements_wTensorRT.txt:
--------------------------------------------------------------------------------
1 | # NOTE: Installing whatever versions of torch, torchvision, h5py and Pillow
2 | # fit your machine should work in most cases. One possible working configuration
3 | # is reported below, but these specific versions are not mandatory.
4 |
5 | h5py==3.8.0
6 | numpy==1.21.6
7 | onnx==1.14.0
8 | onnxruntime==1.14.1
9 | onnxsim==0.4.17
10 | Pillow==10.2.0
11 | pycuda==2022.1
12 | tensorrt==8.0.1.6
13 | torch==1.12.1+cu113
14 | torchvision==0.13.1+cu113
15 |
--------------------------------------------------------------------------------
/results_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jchenghu/ExpansionNet_v2/eb7f1c98d00cdaa61b60c43852155b3742fe5b45/results_image.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import math
4 | import torch
5 | import argparse
6 | from argparse import Namespace
7 | from utils.args_utils import str2list, str2bool
8 | import copy
9 | from time import time
10 |
11 | import torch.distributed as dist
12 | import torch.multiprocessing as mp
13 | from torch.nn.parallel import DistributedDataParallel as DDP
14 |
15 | from models.ensemble_captioning_model import EsembleCaptioningModel
16 | from data.coco_dataloader import CocoDataLoader
17 | from data.coco_dataset import CocoDatasetKarpathy
18 | from utils import language_utils
19 | from utils.language_utils import compute_num_pads
20 | from eval.eval import COCOEvalCap
21 |
22 | import warnings
23 | warnings.filterwarnings("ignore", category=UserWarning)
24 |
25 | import functools
26 | print = functools.partial(print, flush=True)
27 |
28 |
29 | def convert_time_as_hhmmss(ticks):
30 | return str(int(ticks / 60)) + " m " + \
31 | str(int(ticks) % 60) + " s"
32 |
33 |
34 | def compute_evaluation_loss(loss_function,
35 | model,
36 | data_set,
37 | data_loader,
38 | num_samples,
39 | sub_batch_size,
40 | dataset_split,
41 | rank=0,
42 | verbose=False):
43 | model.eval()
44 |
45 | sb_size = sub_batch_size
46 |
47 | tot_loss = 0
48 | num_sub_batch = math.ceil(num_samples / sb_size)
49 | tot_num_tokens = 0
50 | for sb_it in range(num_sub_batch):
51 | from_idx = sb_it * sb_size
52 | to_idx = min((sb_it + 1) * sb_size, num_samples)
53 |
54 | sub_batch_input_x, sub_batch_target_y, sub_batch_input_x_num_pads, sub_batch_target_y_num_pads, \
55 | = data_loader.get_batch_samples(img_idx_batch_list=list(range(from_idx, to_idx)),
56 | dataset_split=dataset_split)
57 | sub_batch_input_x = sub_batch_input_x.to(rank)
58 | sub_batch_target_y = sub_batch_target_y.to(rank)
59 |
60 |
61 |
62 | tot_num_tokens += sub_batch_target_y.size(1)*sub_batch_target_y.size(0) - \
63 | sum(sub_batch_target_y_num_pads)
64 | pred = model(enc_x=sub_batch_input_x,
65 | dec_x=sub_batch_target_y[:, :-1],
66 | enc_x_num_pads=sub_batch_input_x_num_pads,
67 | dec_x_num_pads=sub_batch_target_y_num_pads,
68 | apply_softmax=False)
69 | tot_loss += loss_function(pred, sub_batch_target_y[:, 1:],
70 | data_set.get_pad_token_idx(),
71 | divide_by_non_zeros=False).item()
72 | del sub_batch_input_x, sub_batch_target_y, pred
73 | torch.cuda.empty_cache()
74 | tot_loss /= tot_num_tokens
75 | if verbose and rank == 0:
76 | print("Validation Loss on " + str(num_samples) + " samples: " + str(tot_loss))
77 |
78 | return tot_loss
79 |
80 |
81 | def evaluate_model(ddp_model,
82 | y_idx2word_list,
83 | beam_size, max_seq_len,
84 | sos_idx, eos_idx,
85 | rank, ddp_sync_port,
86 | parallel_batches=16,
87 |
88 | indexes=[0],
89 | data_loader=None,
90 | dataset_split=CocoDatasetKarpathy.TrainSet_ID,
91 | use_images_instead_of_features=False,
92 |
93 | verbose=True,
94 | stanford_model_path="./eval/get_stanford_models.sh"):
95 |
96 | start_time = time()
97 |
98 | sub_list_predictions = []
99 | validate_y = []
100 | num_samples = len(indexes)
101 |
102 | ddp_model.eval()
103 | with torch.no_grad():
104 | sb_size = parallel_batches
105 | num_iter_sub_batches = math.ceil(len(indexes) / sb_size)
106 | for sb_it in range(num_iter_sub_batches):
107 | last_iter = sb_it == num_iter_sub_batches - 1
108 | if last_iter:
109 | from_idx = sb_it * sb_size
110 | to_idx = num_samples
111 | else:
112 | from_idx = sb_it * sb_size
113 | to_idx = (sb_it + 1) * sb_size
114 |
115 | if use_images_instead_of_features:
116 | sub_batch_x = [data_loader.get_images_by_idx(i, dataset_split=dataset_split, transf_mode='test').unsqueeze(0)
117 | for i in list(range(from_idx, to_idx))]
118 | sub_batch_x = torch.cat(sub_batch_x).to(rank)
119 | sub_batch_x_num_pads = [0] * sub_batch_x.size(0)
120 | else:
121 | sub_batch_x = [data_loader.get_vis_features_by_idx(i, dataset_split=dataset_split)
122 | for i in list(range(from_idx, to_idx))]
123 | sub_batch_x = torch.nn.utils.rnn.pad_sequence(sub_batch_x, batch_first=True).to(rank)
124 | sub_batch_x_num_pads = compute_num_pads(sub_batch_x)
125 |
126 | validate_y += [data_loader.get_all_image_captions_by_idx(i, dataset_split=dataset_split) \
127 | for i in list(range(from_idx, to_idx))]
128 |
129 | beam_search_kwargs = {'beam_size': beam_size,
130 | 'beam_max_seq_len': max_seq_len,
131 | 'sample_or_max': 'max',
132 | 'how_many_outputs': 1,
133 | 'sos_idx': sos_idx,
134 | 'eos_idx': eos_idx}
135 |
136 | output_words, _ = ddp_model(enc_x=sub_batch_x,
137 | enc_x_num_pads=sub_batch_x_num_pads,
138 | mode='beam_search', **beam_search_kwargs)
139 |
140 | output_words = [output_words[i][0] for i in range(len(output_words))]
141 |
142 | pred_sentence = language_utils.convert_allsentences_idx2word(output_words, y_idx2word_list)
143 | for sentence in pred_sentence:
144 | sub_list_predictions.append(' '.join(sentence[1:-1])) # remove EOS and SOS
145 |
146 | del sub_batch_x, sub_batch_x_num_pads, output_words
147 |
148 | ddp_model.train()
149 |
150 | if rank == 0 and verbose:
151 | # dirty code to leave the evaluation code untouched
152 | list_predictions = [sub_predictions for sub_predictions in sub_list_predictions]
153 | list_list_references = [[validate_y[i][j] for j in range(len(validate_y[i]))] for i in range(len(validate_y))]
154 |
155 | gts_dict = dict()
156 | for i in range(len(list_list_references)):
157 | gts_dict[i] = [{u'image_id': i, u'caption': list_list_references[i][j]}
158 | for j in range(len(list_list_references[i]))]
159 |
160 | pred_dict = dict()
161 | for i in range(len(list_predictions)):
162 | pred_dict[i] = [{u'image_id': i, u'caption': list_predictions[i]}]
163 |
164 | coco_eval = COCOEvalCap(gts_dict, pred_dict, list(range(len(list_predictions))),
165 | get_stanford_models_path=stanford_model_path)
166 | score_results = coco_eval.evaluate(bleu=True, rouge=True, cider=True, spice=True, meteor=True, verbose=False)
167 | elapsed_ticks = time() - start_time
168 |         print("Evaluation Phase over " + str(len(validate_y)) + " samples, BeamSize: " + str(beam_size) +
169 |               " elapsed: " + str(int(elapsed_ticks / 60)) + " m " + str(int(elapsed_ticks % 60)) + ' s')
170 | print(score_results)
171 |
172 | if rank == 0:
173 | return pred_dict, gts_dict
174 |
175 | return None, None
176 |
177 |
178 | def evaluate_model_on_set(ddp_model,
179 | caption_idx2word_list,
180 | sos_idx, eos_idx,
181 | num_samples,
182 | data_loader,
183 | dataset_split,
184 | eval_max_len,
185 | rank, ddp_sync_port,
186 | parallel_batches=16,
187 | beam_sizes=[1],
188 | stanford_model_path='./eval/get_stanford_models.sh',
189 | use_images_instead_of_features=False,
190 | get_predictions=False):
191 |
192 | with torch.no_grad():
193 | ddp_model.eval()
194 |
195 | for beam in beam_sizes:
196 | pred_dict, gts_dict = evaluate_model(ddp_model,
197 | y_idx2word_list=caption_idx2word_list,
198 | beam_size=beam, max_seq_len=eval_max_len,
199 | sos_idx=sos_idx, eos_idx=eos_idx,
200 | rank=rank,
201 | ddp_sync_port=ddp_sync_port,
202 | parallel_batches=parallel_batches,
203 | indexes=list(range(num_samples)),
204 | data_loader=data_loader,
205 | dataset_split=dataset_split,
206 | use_images_instead_of_features=use_images_instead_of_features,
207 | verbose=True,
208 | stanford_model_path=stanford_model_path)
209 |
210 | if rank == 0 and get_predictions:
211 | return pred_dict, gts_dict
212 |
213 | return None, None
214 |
215 |
216 | def get_ensemble_model(reference_model,
217 | checkpoints_paths,
218 | rank=0):
219 | model_list = []
220 | for i in range(len(checkpoints_paths)):
221 | model = copy.deepcopy(reference_model)
222 | model.to(rank)
223 | map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
224 | checkpoint = torch.load(checkpoints_paths[i],
225 | map_location=map_location)
226 | model.load_state_dict(checkpoint['model_state_dict'])
227 | model_list.append(model)
228 |
229 | model = EsembleCaptioningModel(model_list, rank).to(rank)
230 | ddp_model = DDP(model, device_ids=[rank])
231 | return ddp_model
232 |
233 |
234 | def test(rank, world_size,
235 | is_end_to_end,
236 | model_args,
237 | is_ensemble,
238 | coco_dataset,
239 | eval_parallel_batch_size,
240 | eval_beam_sizes,
241 | show_predictions,
242 | array_of_init_seeds,
243 | model_max_len,
244 | save_model_path,
245 | ddp_sync_port):
246 |     print("[GPU: " + str(rank) + "] Process " + str(rank) + " working...")
247 |
248 | os.environ['MASTER_ADDR'] = 'localhost'
249 | os.environ['MASTER_PORT'] = ddp_sync_port
250 | dist.init_process_group("nccl", rank=rank, world_size=world_size)
251 |
252 | img_size = 384
253 | if is_end_to_end:
254 | from models.End_ExpansionNet_v2 import End_ExpansionNet_v2
255 | model = End_ExpansionNet_v2(swin_img_size=img_size, swin_patch_size=4, swin_in_chans=3,
256 | swin_embed_dim=192, swin_depths=[2, 2, 18, 2], swin_num_heads=[6, 12, 24, 48],
257 | swin_window_size=12, swin_mlp_ratio=4., swin_qkv_bias=True, swin_qk_scale=None,
258 | swin_drop_rate=0.0, swin_attn_drop_rate=0.0, swin_drop_path_rate=0.1,
259 | swin_norm_layer=torch.nn.LayerNorm, swin_ape=False, swin_patch_norm=True,
260 | swin_use_checkpoint=False,
261 | final_swin_dim=1536,
262 |
263 | d_model=model_args.model_dim, N_enc=model_args.N_enc,
264 | N_dec=model_args.N_dec, num_heads=8, ff=2048,
265 | num_exp_enc_list=[32, 64, 128, 256, 512],
266 | num_exp_dec=16,
267 | output_word2idx=coco_dataset.caption_word2idx_dict,
268 | output_idx2word=coco_dataset.caption_idx2word_list,
269 | max_seq_len=model_max_len, drop_args=model_args.drop_args,
270 | rank=rank)
271 | else:
272 | from models.ExpansionNet_v2 import ExpansionNet_v2
273 | model = ExpansionNet_v2(d_model=model_args.model_dim, N_enc=model_args.N_enc,
274 | N_dec=model_args.N_dec, num_heads=8, ff=2048,
275 | num_exp_enc_list=[32, 64, 128, 256, 512],
276 | num_exp_dec=16,
277 | output_word2idx=coco_dataset.caption_word2idx_dict,
278 | output_idx2word=coco_dataset.caption_idx2word_list,
279 | max_seq_len=model_max_len, drop_args=model_args.drop_args,
280 | img_feature_dim=1536,
281 | rank=rank)
282 |
283 | model.to(rank)
284 | ddp_model = DDP(model, device_ids=[rank])
285 |
286 | data_loader = CocoDataLoader(coco_dataset=coco_dataset,
287 | batch_size=1,
288 | num_procs=world_size,
289 | array_of_init_seeds=array_of_init_seeds,
290 | dataloader_mode='image_wise',
291 | resize_image_size=img_size if is_end_to_end else None,
292 | rank=rank,
293 | verbose=False)
294 |
295 | if not is_ensemble:
296 | print("Not ensemble")
297 | map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
298 | checkpoint = torch.load(save_model_path, map_location=map_location)
299 | model.load_state_dict(checkpoint['model_state_dict'], strict=is_end_to_end)
300 | else:
301 | print("Ensembling Evaluation")
302 | list_checkpoints = os.listdir(save_model_path)
303 | checkpoints_list = [save_model_path + elem for elem in list_checkpoints if elem.endswith('.pth')]
304 | print("Detected checkpoints: " + str(checkpoints_list))
305 |
306 | if len(checkpoints_list) == 0:
307 | print("No checkpoints found")
308 | dist.destroy_process_group()
309 | exit(-1)
310 | ddp_model = get_ensemble_model(model, checkpoints_list, rank=rank)
311 |
312 | print("Evaluation on Validation Set")
313 | evaluate_model_on_set(ddp_model, coco_dataset.caption_idx2word_list,
314 | coco_dataset.get_sos_token_idx(), coco_dataset.get_eos_token_idx(),
315 | coco_dataset.val_num_images,
316 | data_loader,
317 | CocoDatasetKarpathy.ValidationSet_ID, model_max_len,
318 | rank, ddp_sync_port,
319 | parallel_batches=eval_parallel_batch_size,
320 | use_images_instead_of_features=is_end_to_end,
321 | beam_sizes=eval_beam_sizes)
322 |
323 | print("Evaluation on Test Set")
324 | pred_dict, gts_dict = evaluate_model_on_set(ddp_model, coco_dataset.caption_idx2word_list,
325 | coco_dataset.get_sos_token_idx(), coco_dataset.get_eos_token_idx(),
326 | coco_dataset.test_num_images,
327 | data_loader,
328 | CocoDatasetKarpathy.TestSet_ID, model_max_len,
329 | rank, ddp_sync_port,
330 | parallel_batches=eval_parallel_batch_size,
331 | use_images_instead_of_features=is_end_to_end,
332 | beam_sizes=eval_beam_sizes,
333 | get_predictions=show_predictions)
334 |
335 | if rank == 0 and show_predictions:
336 | with open("predictions.txt", 'w') as f:
337 | for i in range(len(pred_dict)):
338 | prediction = pred_dict[i][0]['caption']
339 | ground_truth_list = [gts_dict[i][j]['caption'] for j in range(len(gts_dict[i]))]
340 | f.write(str(i) + '----------------------------------------------------------------------' + '\n')
341 | f.write('Pred: ' + str(prediction) + '\n')
342 | f.write('Gt: ' + str(ground_truth_list) + '\n')
343 |
344 | print("[GPU: " + str(rank) + " ] Closing...")
345 | dist.destroy_process_group()
346 |
347 |
348 | def spawn_train_processes(is_end_to_end,
349 | model_args,
350 | is_ensemble,
351 | coco_dataset,
352 | eval_parallel_batch_size,
353 | eval_beam_sizes,
354 | show_predictions,
355 | num_gpus,
356 | ddp_sync_port,
357 | save_model_path
358 | ):
359 |
360 | max_sequence_length = coco_dataset.max_seq_len + 20
361 | print("Max sequence length: " + str(max_sequence_length))
362 | print("y vocabulary size: " + str(len(coco_dataset.caption_word2idx_dict)))
363 |
364 | world_size = torch.cuda.device_count()
365 |     print("Available processes / GPUs: " + str(world_size))
366 |     assert (num_gpus <= world_size), "requested num gpus higher than the number of available gpus"
367 | print("Requested num GPUs: " + str(num_gpus))
368 |
369 | array_of_init_seeds = [random.random() for _ in range(10)]
370 | mp.spawn(test,
371 | args=(num_gpus,
372 | is_end_to_end,
373 | model_args,
374 | is_ensemble,
375 | coco_dataset,
376 | eval_parallel_batch_size,
377 | eval_beam_sizes,
378 | show_predictions,
379 | array_of_init_seeds,
380 | max_sequence_length,
381 | save_model_path,
382 | ddp_sync_port),
383 | nprocs=num_gpus,
384 | join=True)
385 |
386 |
387 | if __name__ == "__main__":
388 |
389 | parser = argparse.ArgumentParser(description='Image Captioning')
390 | parser.add_argument('--model_dim', type=int, default=512)
391 | parser.add_argument('--N_enc', type=int, default=3)
392 | parser.add_argument('--N_dec', type=int, default=3)
393 | parser.add_argument('--show_predictions', type=str2bool, default=False)
394 |
395 | parser.add_argument('--is_end_to_end', type=str2bool, default=True)
396 | parser.add_argument('--is_ensemble', type=str2bool, default=False)
397 |
398 | parser.add_argument('--num_gpus', type=int, default=1)
399 | parser.add_argument('--ddp_sync_port', type=int, default=12354)
400 | parser.add_argument('--save_model_path', type=str, default='./github_ignore_material/saves/')
401 |
402 | parser.add_argument('--eval_parallel_batch_size', type=int, default=16)
403 | parser.add_argument('--eval_beam_sizes', type=str2list, default=[3])
404 |
405 | parser.add_argument('--images_path', type=str, default="./github_ignore_material/raw_data/")
406 | parser.add_argument('--preproc_images_hdf5_filepath', type=str, default=None)
407 | parser.add_argument('--features_path', type=str, default='./github_ignore_material/raw_data/')
408 | parser.add_argument('--captions_path', type=str, default='./github_ignore_material/raw_data/')
409 |
410 | args = parser.parse_args()
411 | args.ddp_sync_port = str(args.ddp_sync_port)
412 |
413 |     assert (args.eval_parallel_batch_size % args.num_gpus == 0), \
414 |         "the requested eval parallel batch size must be a multiple of the number of gpus"
415 |
416 | print("is_ensemble: " + str(args.is_ensemble))
417 | print("eval parallel batch_size: " + str(args.eval_parallel_batch_size))
418 | print("ddp_sync_port: " + str(args.ddp_sync_port))
419 | print("save_model_path: " + str(args.save_model_path))
420 |
421 | drop_args = Namespace(enc=0.0,
422 | dec=0.0,
423 | enc_input=0.0,
424 | dec_input=0.0,
425 | other=0.0)
426 |
427 | model_args = Namespace(model_dim=args.model_dim,
428 | N_enc=args.N_enc,
429 | N_dec=args.N_dec,
430 | dropout=0.0,
431 | drop_args=drop_args
432 | )
433 |
434 | coco_dataset = CocoDatasetKarpathy(
435 | images_path=args.images_path,
436 | coco_annotations_path=args.captions_path + "dataset_coco.json",
437 | preproc_images_hdf5_filepath=args.preproc_images_hdf5_filepath if args.is_end_to_end else None,
438 | precalc_features_hdf5_filepath=None if args.is_end_to_end else args.features_path,
439 | limited_num_train_images=None,
440 | limited_num_val_images=5000)
441 |
442 | spawn_train_processes(is_end_to_end=args.is_end_to_end,
443 | model_args=model_args,
444 | is_ensemble=args.is_ensemble,
445 | coco_dataset=coco_dataset,
446 | eval_parallel_batch_size=args.eval_parallel_batch_size,
447 | eval_beam_sizes=args.eval_beam_sizes,
448 | show_predictions=args.show_predictions,
449 | num_gpus=args.num_gpus,
450 | ddp_sync_port=args.ddp_sync_port,
451 | save_model_path=args.save_model_path
452 | )
453 |
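454 | # Illustrative invocation (paths and values below are placeholders):
455 | #   python test.py --is_end_to_end True --num_gpus 1 \
456 | #       --save_model_path ./github_ignore_material/saves/<checkpoint>.pth \
457 | #       --eval_beam_sizes [3] --show_predictions True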
--------------------------------------------------------------------------------
/utils/args_utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | # thanks Maxim from: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
5 | def str2bool(v):
6 | if isinstance(v, bool):
7 | return v
8 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
9 | return True
10 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
11 | return False
12 | else:
13 | raise argparse.ArgumentTypeError('Boolean value expected.')
14 |
15 |
16 | def str2list(v):
17 | if '[' in v and ']' in v:
18 | return list(map(int, v.strip('[]').split(',')))
19 | else:
20 | raise argparse.ArgumentTypeError('Input expected in the form [b1,b2,b3,...]')
21 |
22 |
23 | def str2type(v):
24 | if v.lower() == 'fp32' or v.lower() == 'fp16':
25 | return v.lower()
26 | else:
27 | raise argparse.ArgumentTypeError('Invalid type, currently supported type: [fp32, fp16]')
28 |
29 |
30 |
31 | def scheduler_type_choice(v):
32 | if v == 'annealing' or v == 'custom_warmup_anneal':
33 | return v
34 | else:
35 | raise argparse.ArgumentTypeError('Argument must be either '
36 | '\'annealing\', '
37 | '\'custom_warmup_anneal\'')
38 |
39 |
40 | def optim_type_choice(v):
41 | if v == 'adam' or v == 'radam':
42 | return v
43 | else:
44 | raise argparse.ArgumentTypeError('Argument must be either \'adam\', '
45 | '\'radam\'.')
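46 |
47 |
48 | # Usage sketch (illustrative), mirroring the argument definitions in test.py:
49 | #   parser.add_argument('--eval_beam_sizes', type=str2list, default=[3])    # "--eval_beam_sizes [1,3]" -> [1, 3]
50 | #   parser.add_argument('--show_predictions', type=str2bool, default=False) # "--show_predictions yes" -> True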
--------------------------------------------------------------------------------
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import torchvision
3 | from PIL import Image as PIL_Image
4 |
5 |
6 | def preprocess_image(image_path, img_size):
7 | transf_1 = torchvision.transforms.Compose([torchvision.transforms.Resize((img_size, img_size))])
8 | transf_2 = torchvision.transforms.Compose([torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
9 | std=[0.229, 0.224, 0.225])])
10 |
11 | pil_image = PIL_Image.open(image_path)
12 |     if pil_image.mode != 'RGB':
13 |         pil_image = pil_image.convert("RGB")  # convert instead of creating a blank image, so the pixel content is preserved
14 | preprocess_pil_image = transf_1(pil_image)
15 | image = torchvision.transforms.ToTensor()(preprocess_pil_image)
16 | image = transf_2(image)
17 | return image.unsqueeze(0)
18 |
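19 |
20 | # Quick sanity check (illustrative; assumes it is run from the repository root,
21 | # where ./demo_material/napoleon.jpg is available, with the 384x384 input size
22 | # used by the end-to-end model in this repository):
23 | if __name__ == '__main__':
24 |     image = preprocess_image('./demo_material/napoleon.jpg', img_size=384)
25 |     print(image.shape)  # expected: torch.Size([1, 3, 384, 384])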
--------------------------------------------------------------------------------
/utils/language_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 |
4 | def compute_num_pads(list_vis_features):
5 | max_len = -1
6 | for vis_features in list_vis_features:
7 | num_vis_features = len(vis_features)
8 | if num_vis_features > max_len:
9 | max_len = num_vis_features
10 | num_pad_vector = []
11 | for vis_features in list_vis_features:
12 | num_pad_vector.append(max_len - len(vis_features))
13 | return num_pad_vector
14 |
15 |
16 | def remove_punctuations(sentences):
17 | punctuations = ["''", "'", "``", "`", ".", "?", "!", ",", ":", "-", "--", "...", ";"]
18 | res_sentences_list = []
19 | for i in range(len(sentences)):
20 | res_sentence = []
21 | for word in sentences[i].split(' '):
22 | if word not in punctuations:
23 | res_sentence.append(word)
24 | res_sentences_list.append(' '.join(res_sentence))
25 | return res_sentences_list
26 |
27 |
28 | def lowercase_and_clean_trailing_spaces(sentences):
29 | return [(sentences[i].lower()).rstrip() for i in range(len(sentences))]
30 |
31 |
32 | def add_space_between_non_alphanumeric_symbols(sentences):
33 | return [re.sub(r'([^\w0-9])', r" \1 ", sentences[i]) for i in range(len(sentences))]
34 |
35 |
36 | def tokenize(list_sentences):
37 | res_sentences_list = []
38 | for i in range(len(list_sentences)):
39 | sentence = list_sentences[i].split(' ')
40 | while '' in sentence:
41 | sentence.remove('')
42 | res_sentences_list.append(sentence)
43 | return res_sentences_list
44 |
45 |
46 | def convert_vector_word2idx(sentence, word2idx_dict):
47 | return [word2idx_dict[word] for word in sentence]
48 |
49 |
50 | def convert_allsentences_word2idx(sentences, word2idx_dict):
51 | return [convert_vector_word2idx(sentences[i], word2idx_dict) for i in range(len(sentences))]
52 |
53 |
54 | def convert_vector_idx2word(sentence, idx2word_list):
55 | return [idx2word_list[idx] for idx in sentence]
56 |
57 |
58 | def convert_allsentences_idx2word(sentences, idx2word_list):
59 | return [convert_vector_idx2word(sentences[i], idx2word_list) for i in range(len(sentences))]
60 |
61 |
62 | def tokens2description(tokens, idx2word_list, sos_idx, eos_idx):
63 | desc = []
64 | for tok in tokens:
65 | if tok == sos_idx:
66 | continue
67 | if tok == eos_idx:
68 | break
69 | desc.append(tok)
70 | desc = convert_vector_idx2word(desc, idx2word_list)
71 | desc[-1] = desc[-1] + '.'
72 | pred = ' '.join(desc).capitalize()
73 | return pred
74 |
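75 |
76 | # Worked example (illustrative, using a toy vocabulary instead of the real COCO one):
77 | if __name__ == '__main__':
78 |     print(compute_num_pads([[1, 2, 3], [1, 2]]))  # -> [0, 1]
79 |     toy_idx2word = ['SOS', 'EOS', 'a', 'cat', 'on', 'the', 'mat']
80 |     print(tokens2description([0, 2, 3, 4, 5, 6, 1], toy_idx2word, sos_idx=0, eos_idx=1))
81 |     # -> "A cat on the mat."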
--------------------------------------------------------------------------------
/utils/masking.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def create_pad_mask(mask_size, pad_row, pad_column, rank=0):
5 | batch_size, output_seq_len, input_seq_len = mask_size
6 | mask = torch.ones(size=(batch_size, output_seq_len, input_seq_len), dtype=torch.int8).to(rank)
7 |
8 | for batch_idx in range(batch_size):
9 | mask[batch_idx, :, (input_seq_len - pad_column[batch_idx]):] = 0
10 | mask[batch_idx, (output_seq_len - pad_row[batch_idx]):, :] = 0
11 | return mask
12 |
13 |
14 | def create_no_peak_and_pad_mask(mask_size, num_pads, rank=0):
15 | batch_size, seq_len, seq_len = mask_size
16 | mask = torch.tril(torch.ones(size=(seq_len, seq_len), dtype=torch.int8),
17 | diagonal=0).unsqueeze(0).repeat(batch_size, 1, 1).to(rank)
18 | for batch_idx in range(batch_size):
19 | mask[batch_idx, :, seq_len - num_pads[batch_idx]:] = 0
20 | mask[batch_idx, (seq_len - num_pads[batch_idx]):, :] = 0
21 | return mask
22 |
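23 |
24 | # Worked example (illustrative): a no-peak mask for a batch holding one sequence
25 | # of length 4 with 1 padding position. Expected output (up to print formatting):
26 | #   tensor([[[1, 0, 0, 0],
27 | #            [1, 1, 0, 0],
28 | #            [1, 1, 1, 0],
29 | #            [0, 0, 0, 0]]], dtype=torch.int8)
30 | if __name__ == '__main__':
31 |     print(create_no_peak_and_pad_mask((1, 4, 4), num_pads=[1], rank='cpu'))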
--------------------------------------------------------------------------------
/utils/saving_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import torch
4 | from datetime import datetime
5 |
6 | from torch.nn.parameter import Parameter
7 |
8 | def load_most_recent_checkpoint(model,
9 | optimizer=None,
10 | scheduler=None,
11 | data_loader=None,
12 | rank=0,
13 | save_model_path='./', datetime_format='%Y-%m-%d-%H:%M:%S',
14 | verbose=True):
15 | ls_files = os.listdir(save_model_path)
16 | most_recent_checkpoint_datetime = None
17 | most_recent_checkpoint_filename = None
18 | most_recent_checkpoint_info = 'no_additional_info'
19 | for file_name in ls_files:
20 | if file_name.startswith('checkpoint_'):
21 | _, datetime_str, _, info, _ = file_name.split('_')
22 | file_datetime = datetime.strptime(datetime_str, datetime_format)
23 | if (most_recent_checkpoint_datetime is None) or \
24 | (most_recent_checkpoint_datetime is not None and
25 | file_datetime > most_recent_checkpoint_datetime):
26 | most_recent_checkpoint_datetime = file_datetime
27 | most_recent_checkpoint_filename = file_name
28 | most_recent_checkpoint_info = info
29 |
30 | if most_recent_checkpoint_filename is not None:
31 | if verbose:
32 | print("Loading: " + str(save_model_path + most_recent_checkpoint_filename))
33 | map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
34 | checkpoint = torch.load(save_model_path + most_recent_checkpoint_filename,
35 | map_location=map_location)
36 | model.load_state_dict(checkpoint['model_state_dict'])
37 | if optimizer is not None:
38 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
39 | if scheduler is not None:
40 | scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
41 | if data_loader is not None:
42 | data_loader.load_state(checkpoint['data_loader_state_dict'])
43 | return True, most_recent_checkpoint_info
44 | else:
45 | if verbose:
46 | print("Loading: no checkpoint found in " + str(save_model_path))
47 | return False, most_recent_checkpoint_info
48 |
49 |
50 | def save_last_checkpoint(model,
51 | optimizer,
52 | scheduler,
53 | data_loader,
54 | save_model_path='./',
55 | num_max_checkpoints=3, datetime_format='%Y-%m-%d-%H:%M:%S',
56 | additional_info='noinfo',
57 | verbose=True):
58 |
59 | checkpoint = {
60 | 'model_state_dict': model.state_dict(),
61 | 'optimizer_state_dict': optimizer.state_dict(),
62 | 'scheduler_state_dict': scheduler.state_dict(),
63 | 'data_loader_state_dict': data_loader.save_state(),
64 | }
65 |
66 | ls_files = os.listdir(save_model_path)
67 | oldest_checkpoint_datetime = None
68 | oldest_checkpoint_filename = None
69 | num_check_points = 0
70 | for file_name in ls_files:
71 | if file_name.startswith('checkpoint_'):
72 | num_check_points += 1
73 | _, datetime_str, _, _, _ = file_name.split('_')
74 | file_datetime = datetime.strptime(datetime_str, datetime_format)
75 | if (oldest_checkpoint_datetime is None) or \
76 | (oldest_checkpoint_datetime is not None and file_datetime < oldest_checkpoint_datetime):
77 | oldest_checkpoint_datetime = file_datetime
78 | oldest_checkpoint_filename = file_name
79 |
80 | if oldest_checkpoint_filename is not None and num_check_points == num_max_checkpoints:
81 | os.remove(save_model_path + oldest_checkpoint_filename)
82 |
83 | new_checkpoint_filename = 'checkpoint_' + datetime.now().strftime(datetime_format) + \
84 | '_epoch' + str(data_loader.get_epoch_it()) + \
85 | 'it' + str(data_loader.get_batch_it()) + \
86 | 'bs' + str(data_loader.get_batch_size()) + \
87 | '_' + str(additional_info) + '_.pth'
88 | if verbose:
89 | print("Saved to " + str(new_checkpoint_filename))
90 | torch.save(checkpoint, save_model_path + new_checkpoint_filename)
91 |
92 |
93 | def partially_load_state_dict(model, state_dict, verbose=False, max_num_print=5):
94 | own_state = model.state_dict()
95 |
96 | count_print = 0
97 | for name, param in state_dict.items():
98 | if name not in own_state:
99 | if verbose:
100 | print("Not found: " + str(name))
101 | continue
102 | if isinstance(param, Parameter):
103 | param = param.data
104 | own_state[name].copy_(param)
105 | if verbose:
106 | if count_print < max_num_print:
107 | print("Found: " + str(name))
108 | count_print += 1
109 |
110 |
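111 | # Note on the checkpoint naming scheme (the concrete values below are placeholders):
112 | # save_last_checkpoint() produces file names such as
113 | #   checkpoint_2023-01-01-12:30:00_epoch2it1500bs48_xe_.pth
114 | # and load_most_recent_checkpoint() recovers the datetime and the additional info
115 | # field ('xe' here) by splitting the name on '_'.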
--------------------------------------------------------------------------------