├── .gitignore ├── .gitmodules ├── README.md ├── data └── coco_caption │ └── captions_val2014.json ├── flm ├── __init__.py ├── config.py ├── datamodules │ ├── __init__.py │ ├── coco_caption_karpathy_datamodule.py │ ├── conceptual_caption12m_datamodule.py │ ├── conceptual_caption_datamodule.py │ ├── datamodule_base.py │ ├── f30k_caption_karpathy_datamodule.py │ ├── laion100m_datamodule.py │ ├── laion_datamodule.py │ ├── multitask_datamodule.py │ ├── nlvr2_datamodule.py │ ├── sbu_datamodule.py │ ├── snli_datamodule.py │ ├── vg_caption_datamodule.py │ └── vqav2_datamodule.py ├── datasets │ ├── __init__.py │ ├── base_dataset.py │ ├── coco_caption_karpathy_dataset.py │ ├── conceptual_caption12m_dataset.py │ ├── conceptual_caption_dataset.py │ ├── f30k_caption_karpathy_dataset.py │ ├── laion100m_dataset.py │ ├── laion_dataset.py │ ├── nlvr2_dataset.py │ ├── sbu_caption_dataset.py │ ├── snli_dataset.py │ ├── vg_caption_dataset.py │ └── vqav2_dataset.py ├── gadgets │ ├── __init__.py │ └── my_metrics.py ├── modules │ ├── __init__.py │ ├── bert_model.py │ ├── clip_model.py │ ├── dist_utils.py │ ├── flm_module.py │ ├── flm_tools.py │ ├── heads.py │ ├── meter_utils.py │ └── objectives.py ├── transforms │ ├── __init__.py │ ├── randaug.py │ ├── transform.py │ └── utils.py └── utils │ ├── __init__.py │ ├── find_newest_ckpt.py │ ├── glossary.py │ ├── utils.py │ ├── whole_word_masking.py │ ├── write_coco_karpathy.py │ ├── write_conceptual_caption.py │ ├── write_conceptual_caption12M_cloud.py │ ├── write_conceptual_caption_cloud.py │ ├── write_f30k_karpathy.py │ ├── write_nlvr2.py │ ├── write_sbu.py │ ├── write_snli.py │ ├── write_vg.py │ ├── write_vqa.py │ └── write_winoground.py ├── imgs ├── LMs.png └── pipeline.png ├── requirements.txt └── run.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "flm/pycocoevalcap"] 2 | path = flm/pycocoevalcap 3 | url = https://github.com/salaniz/pycocoevalcap.git 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FLM 2 | Official code for "Accelerating Vision-Language Pretraining with Free Language Modeling" (CVPR 2023) 3 | 4 | Paper: https://arxiv.org/abs/2303.14038 5 | 6 | 7 | ## Introduction 8 | 9 | 10 | ![](imgs/LMs.png) 11 | The state of the arts in vision-language pretraining (VLP) achieves exemplary performance but suffers from high training costs resulting from slow convergence and long training time, especially on large-scale web datasets. An essential obstacle to training efficiency lies in the entangled prediction rate (percentage of tokens for reconstruction) and corruption rate (percentage of corrupted tokens) in masked language modeling (MLM), that is, a proper corruption rate is achieved at the cost of a large portion of output tokens being excluded from prediction loss. 12 | 13 | Free language modeling (FLM) is a new language modeling method that enables a 100% prediction rate with arbitrary corruption rates. FLM successfully frees the prediction rate from the tie-up with the corruption rate while allowing the corruption spans to be customized for each token to be predicted. 
Given the same GPU time, FLM-trained models learn better and faster by exploiting bidirectional contexts more flexibly. 14 | 15 | ![](imgs/pipeline.png) 16 | 17 |
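To make the decoupling concrete, below is a minimal sketch, assuming PyTorch. The helper name `build_flm_context_masks` is hypothetical and it simplifies the repository's actual `flm_collator` in `flm/datamodules/datamodule_base.py` (which reconstructs tokens along a random order): every token stays a prediction target, while each target independently corrupts a random fraction of its context.

```python
# Sketch only: per-target corruption masks with a 100% prediction rate.
import torch

def build_flm_context_masks(attention_mask: torch.Tensor, corruption_rate: float) -> torch.Tensor:
    """attention_mask: (B, L), 1 for real tokens, 0 for padding.
    Returns an additive mask (B, L, L): row t is the context visible when
    reconstructing token t (0 = visible, -10000 = corrupted)."""
    bsz, seq_len = attention_mask.shape
    valid = attention_mask.bool()
    keep = torch.rand(bsz, seq_len, seq_len) >= corruption_rate   # corrupt each context token w.p. corruption_rate
    keep &= valid.unsqueeze(1) & valid.unsqueeze(2)               # never attend to padding
    keep &= ~torch.eye(seq_len, dtype=torch.bool)                 # a target never sees itself
    return torch.zeros(bsz, seq_len, seq_len).masked_fill_(~keep, -10000.0)

masks = build_flm_context_masks(torch.tensor([[1, 1, 1, 1, 0]]), corruption_rate=0.3)
print(masks.shape)  # torch.Size([1, 5, 5]): one customized context per target token
```

By contrast, MLM only computes a loss on the corrupted positions (about 15% with the default `mlm_prob = 0.15`), so raising its prediction rate would force a higher corruption rate as well.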
18 | 19 | ## Install 20 | ``` 21 | pip install -r requirements.txt 22 | ``` 23 | ## Dataset Preparation 24 | We follow [ViLT](https://github.com/dandelin/ViLT) and use `pyarrow` to serialize the datasets. See this [link](https://github.com/dandelin/ViLT/blob/master/DATA.md) for details. 25 | 26 | ## Pretraining 27 | ```bash 28 | export MASTER_ADDR=$DIST_0_IP 29 | export MASTER_PORT=$DIST_0_PORT 30 | export NODE_RANK=$DIST_RANK 31 | 32 | python run.py with data_root= exp_name="pretrain_FLM_4m" \ 33 | num_gpus=8 resume_from=None fix_exp_version=True \ 34 | flm text_roberta image_size=288 clip32 causal_flm \ 35 | precision=16 max_steps=30000 learning_rate=0.00008 \ 36 | batch_size=4096 per_gpu_batchsize=64 warmup_steps=0.05 37 | ``` 38 | #### Pretrained Checkpoints 39 | FLM-CLIP32-RoBERTa (resolution: 288^2) pre-trained on GCC+SBU+COCO+VG [link](https://github.com/TencentARC/FLM/releases/download/checkpoints/pretrain_4m.ckpt) 40 | 41 | FLM-CLIP32-RoBERTa fintuned on VQAv2 (resolution: 576^2) [link](https://github.com/TencentARC/FLM/releases/download/checkpoints/pretrain_4M_ft_vqa.ckpt) 42 | 43 | FLM-CLIP32-RoBERTa fintuned on NLVR2 (resolution: 288^2) [link](https://github.com/TencentARC/FLM/releases/download/checkpoints/pretrain_4m_ft_nlvr2.ckpt) 44 | 45 | ## Evaluation on Downstream Tasks 46 | #### Visual Question Answering (VQA v2) 47 | ```bash 48 | # training: 4 gpu 49 | python run.py with data_root= exp_name="pretrain_FLM_4m_ft_vqa_train" \ 50 | num_gpus=4 resume_from=None fix_exp_version=True load_path="pretrain_4m.ckpt" \ 51 | ft_vqa text_roberta image_size=576 clip32 causal_flm \ 52 | learning_rate=0.000005 batch_size=512 per_gpu_batchsize=32 log_dir='result_ft' clip_randaug 53 | 54 | # testing: 4 gpu 55 | python run.py with data_root= exp_name="pretrain_FLM_4m_ft_vqa_test" \ 56 | num_gpus=4 load_path="pretrain_4M_ft_vqa.ckpt" \ 57 | ft_vqa text_roberta image_size=576 clip32 causal_flm \ 58 | per_gpu_batchsize=32 log_dir='result_ft' test_only=True skip_test_step=True 59 | ``` 60 | 61 | #### Natural Language for Visual Reasoning 62 | ```bash 63 | # training: 1 gpu 64 | python run.py with data_root= exp_name="pretrain_FLM_4m_ft_nlvr2_train" \ 65 | num_gpus=1 resume_from=None fix_exp_version=True load_path="pretrain_4m.ckpt" \ 66 | ft_nlvr2 text_roberta image_size=288 clip32 causal_flm \ 67 | learning_rate=0.00001 batch_size=256 per_gpu_batchsize=32 log_dir='result_ft' clip_randaug 68 | 69 | # testing: 1 gpu 70 | python run.py with data_root= exp_name="pretrain_FLM_4m_ft_nlvr2_test" \ 71 | num_gpus=1 load_path="pretrain_4M_ft_nlvr2.ckpt" \ 72 | ft_nlvr2 text_roberta image_size=288 clip32 causal_flm \ 73 | per_gpu_batchsize=32 log_dir='result_ft' test_only=True skip_test_step=True 74 | ``` 75 | 76 | #### Image Captioning 77 | ```bash 78 | # training: 4 gpu 79 | python run.py with data_root= exp_name="pretrain_FLM_4m_ft_cap_coco_train" \ 80 | num_gpus=4 resume_from=None fix_exp_version=True load_path="pretrain_4m.ckpt" \ 81 | ft_cap_coco text_roberta image_size=288 clip32 causal_flm \ 82 | learning_rate=0.000003 batch_size=256 per_gpu_batchsize=64 log_dir='result_ft' clip_randaug 83 | 84 | # testing: 4 gpu 85 | python run.py with data_root= exp_name="pretrain_FLM_4m_ft_cap_coco_test" \ 86 | num_gpus=4 load_path="pretrain_4M_ft_cap.ckpt" \ 87 | ft_cap_coco text_roberta image_size=384 clip32 causal_flm \ 88 | per_gpu_batchsize=64 log_dir='result_ft' test_only=True skip_test_step=True 89 | ``` 90 | 91 | #### Image-Text Retrieval 92 | 93 | ```bash 94 | # training: 8 gpu 95 | python run.py with 
data_root= exp_name="pretrain_FLM_4m_ft_irtr_f30k_train" \ 96 | num_gpus=8 resume_from=None fix_exp_version=True load_path="pretrain_4m.ckpt" \ 97 | ft_irtr_f30k text_roberta image_size=384 clip32 causal_flm precision=16 \ 98 | learning_rate=0.000005 batch_size=512 per_gpu_batchsize=8 log_dir='result_ft' clip_randaug 99 | 100 | # testing: 8 gpu 101 | python run.py with data_root= exp_name="pretrain_FLM_4m_ft_irtr_f30k_test" \ 102 | num_gpus=8 load_path="pretrain_4M_ft_irtr_f30k.ckpt" \ 103 | ft_irtr_f30k text_roberta image_size=384 clip32 causal_flm \ 104 | per_gpu_batchsize=8 log_dir='result_ft' test_only=True skip_test_step=True 105 | ``` 106 | 107 | 108 | ## Citation 109 | ``` 110 | @misc{wang2023accelerating, 111 | title={Accelerating Vision-Language Pretraining with Free Language Modeling}, 112 | author={Teng Wang and Yixiao Ge and Feng Zheng and Ran Cheng and Ying Shan and Xiaohu Qie and Ping Luo}, 113 | year={2023}, 114 | eprint={2303.14038}, 115 | archivePrefix={arXiv}, 116 | primaryClass={cs.CV} 117 | } 118 | ``` 119 | 120 | ## Acknowledgements 121 | The code is highly based on [METER](https://github.com/zdou0830/METER) and [ViLT](https://github.com/dandelin/ViLT). -------------------------------------------------------------------------------- /flm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/FLM/bd8b19d9f3a00ac6d4e58c9766957032036bffe8/flm/__init__.py -------------------------------------------------------------------------------- /flm/config.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | 3 | ex = Experiment("FLM") 4 | 5 | 6 | def _loss_names(d): 7 | ret = { 8 | "itm": 0, 9 | "mlm": 0, # used for pretraining MLM-based models 10 | "ar": 0, # used for pretraining AR-based models or finetuning on captioning tasks 11 | "flm": 0, # used for pretraining FLM-based models 12 | "vqa": 0, 13 | "nlvr2": 0, 14 | "irtr": 0, 15 | } 16 | ret.update(d) 17 | return ret 18 | 19 | 20 | @ex.config 21 | def config(): 22 | only_use_cls_for_flm = False 23 | 24 | debug = False 25 | log_path = "" 26 | is_causal_mask = False 27 | 28 | causal_mask_w_post_cls = False 29 | get_caption_metric = False 30 | get_mlm_caption_metric = False 31 | get_cl_recall_metric = False 32 | get_cl_itm_recall_metric = False 33 | 34 | skip_test_step = False 35 | 36 | flm_backbone = False 37 | temperature = 0.05 38 | random_flm_mask = False 39 | disable_flm_shuffle = False 40 | flm_mask_prob = 0. 
41 | text_encoder_from_scratch = False 42 | full_att_mask_for_eval = False 43 | full_att_mask = False 44 | enable_flm_aux_lm_loss = False 45 | flm_aux_lm_loss_l2r_weight = 1.0 46 | flm_aux_lm_loss_r2l_weight = 1.0 47 | 48 | span_corruption_rate = 0 49 | 50 | share_lm_scorer_weights = True 51 | 52 | max_dataset_len = -1 53 | 54 | hidden_size_for_fusion = 768 55 | 56 | caption_prompt = None 57 | add_new_bos_token = False 58 | prepend_bos_token = False 59 | append_eos_token = False 60 | 61 | # webdataset 62 | allow_val_webdataset = False 63 | 64 | # adaptive top bottom layer number for flm 65 | num_reconstructor_bottom_layer = 6 66 | num_reconstructor_top_layer = 6 67 | num_bottom_layer = 6 68 | 69 | # enable_prefix_LM=False 70 | prefix_lm_alpha = 1.0 71 | flm_prediction_rate = 1.0 72 | 73 | # exp name 74 | exp_name = "flm" 75 | seed = 2022 76 | datasets = ["coco", "vg", "sbu", "gcc"] 77 | loss_names = _loss_names({"itm": 1, "mlm": 1}) 78 | # hloss_weights = _hloss_weights({'lmcl': 0.1}) 79 | # this is a desired batch size; pl trainer will accumulate gradients when per step batch is smaller. 80 | batch_size = 4096 81 | 82 | prepare_data_per_node = True 83 | # Image setting 84 | train_transform_keys = ["clip"] 85 | val_transform_keys = ["clip"] 86 | image_size = 224 87 | patch_size = 32 88 | draw_false_image = 1 89 | image_only = False 90 | resolution_before = 224 91 | 92 | # Text Setting 93 | vqav2_label_size = 3129 94 | max_text_len = 50 95 | tokenizer = ".cache/bert-base-uncased" 96 | vocab_size = 30522 97 | whole_word_masking = False # note that whole_word_masking does not work for RoBERTa 98 | mlm_prob = 0.15 99 | draw_false_text = 0 100 | 101 | # Transformer Setting 102 | num_top_layer = 6 103 | input_image_embed_size = 768 104 | input_text_embed_size = 768 105 | vit = 'ViT-B/32' 106 | hidden_size = 768 107 | num_heads = 12 108 | num_heads_fusion = 12 109 | num_layers = 6 110 | mlp_ratio = 4 111 | drop_rate = 0.1 112 | # truncate_bottom_text_encoder_layer = False 113 | 114 | # Optimizer Setting 115 | optim_type = "adamw" 116 | learning_rate = 1e-5 117 | weight_decay = 0.01 118 | decay_power = 1 119 | max_epoch = 100 120 | max_steps = 100000 121 | warmup_steps = 10000 122 | end_lr = 0 123 | lr_mult_head = 5 # multiply lr for downstream heads 124 | lr_mult_cross_modal = 5 # multiply lr for the cross-modal module 125 | 126 | # Downstream Setting 127 | get_recall_metric = False 128 | 129 | # PL Trainer Setting 130 | resume_from = None 131 | fast_dev_run = False 132 | val_check_interval = 0.2 133 | num_sanity_val_steps = 2 134 | test_only = False 135 | ckpt_save_top_k = 1 136 | 137 | # below params varies with the environment 138 | data_root = "" 139 | log_dir = "result" 140 | per_gpu_batchsize = 0 # you should define this manually with per_gpu_batch_size=# 141 | num_gpus = 8 142 | # num_nodes = 1 143 | load_path = "" 144 | fix_exp_version = False 145 | num_workers = 8 146 | precision = 32 147 | 148 | 149 | @ex.named_config 150 | def causal_flm(): 151 | is_causal_mask = True 152 | causal_mask_w_post_cls = True 153 | flm_backbone = True 154 | 155 | 156 | @ex.named_config 157 | def causal_lm(): 158 | is_causal_mask = True 159 | causal_mask_w_post_cls = True 160 | 161 | 162 | @ex.named_config 163 | def mlm(): 164 | exp_name = "mlm" 165 | # datasets = ["gcc"] 166 | loss_names = _loss_names({"mlm": 1}) 167 | batch_size = 4096 168 | max_epoch = 10 169 | max_steps = 100000 170 | warmup_steps = 0.1 171 | whole_word_masking = True 172 | 173 | 174 | @ex.named_config 175 | def ar(): 176 | exp_name = 
"ar" 177 | # datasets = ["gcc"] 178 | loss_names = _loss_names({"ar": 1}) 179 | batch_size = 4096 180 | max_epoch = 10 181 | max_steps = 100000 182 | warmup_steps = 0.1 183 | whole_word_masking = True 184 | 185 | 186 | @ex.named_config 187 | def flm(): 188 | exp_name = "flm" 189 | # datasets = ["gcc"] 190 | loss_names = _loss_names({"flm": 1}) 191 | batch_size = 4096 192 | max_epoch = 10 193 | max_steps = 100000 194 | warmup_steps = 0.1 195 | whole_word_masking = True 196 | 197 | is_causal_mask = True 198 | causal_mask_w_post_cls = True 199 | # disable_cross_modal_image_layer=True 200 | # cross_modal_layer='text_only' 201 | flm_backbone = True 202 | enable_flm_aux_lm_loss = True 203 | 204 | 205 | @ex.named_config 206 | def flm_itm(): 207 | exp_name = "flm_itm" 208 | # datasets = ["gcc"] 209 | loss_names = _loss_names({"flm": 1, "itm": 1}) 210 | batch_size = 4096 211 | max_epoch = 10 212 | max_steps = 100000 213 | warmup_steps = 0.1 214 | whole_word_masking = True 215 | enable_flm_aux_lm_loss = True 216 | 217 | 218 | @ex.named_config 219 | def ft_nlvr2(): 220 | exp_name = "finetune_nlvr2" 221 | datasets = ["nlvr2"] 222 | loss_names = _loss_names({"nlvr2": 1}) 223 | batch_size = 256 224 | max_epoch = 10 225 | max_steps = None 226 | warmup_steps = 0.1 227 | draw_false_image = 0 228 | learning_rate = 1e-5 229 | lr_mult_head = 10 230 | lr_mult_cross_modal = 5 231 | tokenizer = ".cache/bert-base-uncased" 232 | max_text_len = 50 233 | input_text_embed_size = 768 234 | vit = 'ViT-B/32' 235 | train_transform_keys = ["clip"] 236 | val_transform_keys = ["clip"] 237 | input_image_embed_size = 768 238 | image_size = 288 239 | 240 | 241 | @ex.named_config 242 | def ft_vqa(): 243 | exp_name = "finetune_vqa" 244 | datasets = ["vqa"] 245 | loss_names = _loss_names({"vqa": 1}) 246 | batch_size = 512 247 | max_epoch = 10 248 | max_steps = None 249 | warmup_steps = 0.1 250 | draw_false_image = 0 251 | learning_rate = 5e-6 252 | val_check_interval = 0.5 253 | lr_mult_head = 50 254 | lr_mult_cross_modal = 5 255 | tokenizer = ".cache/bert-base-uncased" 256 | max_text_len = 50 257 | input_text_embed_size = 768 258 | vit = 'ViT-B/32' 259 | train_transform_keys = ["clip"] 260 | val_transform_keys = ["clip"] 261 | input_image_embed_size = 768 262 | image_size = 576 263 | 264 | 265 | @ex.named_config 266 | def ft_irtr_coco(): 267 | exp_name = "finetune_irtr_coco" 268 | datasets = ["coco"] 269 | loss_names = _loss_names({"itm": 0.5, "irtr": 1}) 270 | batch_size = 512 271 | max_epoch = 10 272 | max_steps = None 273 | warmup_steps = 0.1 274 | get_recall_metric = True 275 | draw_false_text = 15 276 | learning_rate = 5e-6 277 | lr_mult_head = 5 278 | lr_mult_cross_modal = 5 279 | tokenizer = ".cache/bert-base-uncased" 280 | input_text_embed_size = 768 281 | vit = 'ViT-B/32' 282 | train_transform_keys = ["clip"] 283 | val_transform_keys = ["clip"] 284 | input_image_embed_size = 768 285 | image_size = 384 286 | 287 | 288 | @ex.named_config 289 | def ft_cap_coco(): 290 | exp_name = "finetune_caption_coco" 291 | 292 | loss_names = _loss_names({"ar": 0.5}) 293 | batch_size = 256 294 | max_epoch = 20 295 | max_steps = None 296 | warmup_steps = 0.1 297 | get_caption_metric = True 298 | get_mlm_caption_metric = False 299 | get_recall_metric = False 300 | draw_false_text = 0 301 | learning_rate = 3e-5 302 | lr_mult_head = 5 303 | lr_mult_cross_modal = 5 304 | tokenizer = ".cache/bert-base-uncased" 305 | input_text_embed_size = 768 306 | vit = 'ViT-B/32' 307 | train_transform_keys = ["clip"] 308 | val_transform_keys = ["clip"] 309 
| input_image_embed_size = 768 310 | image_size = 384 311 | 312 | caption_prompt = '' 313 | add_new_bos_token = True 314 | prepend_bos_token = True 315 | append_eos_token = True 316 | datasets = ["coco"] 317 | per_gpu_batchsize = 64 318 | 319 | 320 | # @ex.named_config 321 | # def add_bos_eos_tokens(): 322 | # add_new_bos_token=True 323 | # prepend_bos_token=True 324 | # append_eos_token=True 325 | 326 | @ex.named_config 327 | def zs_irtr_coco(): 328 | test_only = True 329 | skip_test_step = True 330 | get_recall_metric = True 331 | get_cl_recall_metric = False 332 | 333 | exp_name = "zs_irtr_coco" 334 | datasets = ["coco"] 335 | loss_names = _loss_names({"itm": 0.5, "irtr": 1}) 336 | batch_size = 512 337 | max_epoch = 10 338 | max_steps = None 339 | warmup_steps = 0.1 340 | get_recall_metric = True 341 | draw_false_text = 15 342 | learning_rate = 5e-6 343 | lr_mult_head = 5 344 | lr_mult_cross_modal = 5 345 | tokenizer = ".cache/bert-base-uncased" 346 | input_text_embed_size = 768 347 | vit = 'ViT-B/32' 348 | train_transform_keys = ["clip"] 349 | val_transform_keys = ["clip"] 350 | input_image_embed_size = 768 351 | image_size = 384 352 | 353 | 354 | @ex.named_config 355 | def ft_irtr_f30k(): 356 | exp_name = "finetune_irtr_f30k" 357 | datasets = ["f30k"] 358 | loss_names = _loss_names({"itm": 0.5, "irtr": 1}) 359 | batch_size = 512 360 | max_epoch = 10 361 | max_steps = None 362 | warmup_steps = 0.1 363 | get_recall_metric = True 364 | draw_false_text = 15 365 | learning_rate = 5e-6 366 | lr_mult_head = 5 367 | lr_mult_cross_modal = 5 368 | tokenizer = ".cache/bert-base-uncased" 369 | input_text_embed_size = 768 370 | vit = 'ViT-B/32' 371 | train_transform_keys = ["clip"] 372 | val_transform_keys = ["clip"] 373 | input_image_embed_size = 768 374 | image_size = 384 375 | 376 | 377 | @ex.named_config 378 | def ft_cl_itm_irtr_f30k(): 379 | exp_name = "finetune_irtr_f30k" 380 | datasets = ["f30k"] 381 | loss_names = _loss_names({"itm": 0.5, "irtr": 1, "cl": 1}) 382 | batch_size = 512 383 | max_epoch = 10 384 | max_steps = None 385 | warmup_steps = 0.1 386 | get_recall_metric = False 387 | get_cl_itm_recall_metric = True 388 | draw_false_text = 15 389 | learning_rate = 5e-6 390 | lr_mult_head = 5 391 | lr_mult_cross_modal = 5 392 | tokenizer = ".cache/bert-base-uncased" 393 | input_text_embed_size = 768 394 | vit = 'ViT-B/32' 395 | train_transform_keys = ["clip"] 396 | val_transform_keys = ["clip"] 397 | input_image_embed_size = 768 398 | image_size = 384 399 | 400 | 401 | @ex.named_config 402 | def zs_irtr_f30k(): 403 | test_only = True 404 | skip_test_step = True 405 | get_recall_metric = True 406 | get_cl_recall_metric = False 407 | 408 | exp_name = "zeroshot_irtr_f30k" 409 | datasets = ["f30k"] 410 | loss_names = _loss_names({"itm": 0.5, "irtr": 1}) 411 | batch_size = 512 412 | max_epoch = 10 413 | max_steps = None 414 | warmup_steps = 0.1 415 | get_recall_metric = True 416 | draw_false_text = 15 417 | learning_rate = 5e-6 418 | lr_mult_head = 5 419 | lr_mult_cross_modal = 5 420 | tokenizer = ".cache/bert-base-uncased" 421 | input_text_embed_size = 768 422 | vit = 'ViT-B/32' 423 | train_transform_keys = ["clip"] 424 | val_transform_keys = ["clip"] 425 | input_image_embed_size = 768 426 | image_size = 384 427 | 428 | 429 | @ex.named_config 430 | def ft_cl_irtr_f30k(): 431 | exp_name = "finetune_cl_irtr_f30k" 432 | datasets = ["f30k"] 433 | loss_names = _loss_names({"cl": 1.0}) 434 | batch_size = 512 435 | max_epoch = 10 436 | max_steps = None 437 | warmup_steps = 0.1 438 | 
get_recall_metric = False 439 | get_cl_recall_metric = True 440 | draw_false_text = 15 441 | learning_rate = 5e-6 442 | lr_mult_head = 5 443 | lr_mult_cross_modal = 5 444 | tokenizer = ".cache/bert-base-uncased" 445 | input_text_embed_size = 768 446 | vit = 'ViT-B/32' 447 | train_transform_keys = ["clip"] 448 | val_transform_keys = ["clip"] 449 | input_image_embed_size = 768 450 | image_size = 384 451 | 452 | 453 | @ex.named_config 454 | def zs_cl_irtr_f30k(): 455 | test_only = True 456 | skip_test_step = True 457 | get_recall_metric = False 458 | get_cl_recall_metric = True 459 | 460 | exp_name = "zs_cl_irtr_f30k" 461 | datasets = ["f30k"] 462 | loss_names = _loss_names({"cl": 1.0}) 463 | batch_size = 512 464 | max_epoch = 10 465 | max_steps = None 466 | warmup_steps = 0.1 467 | get_recall_metric = False 468 | get_cl_recall_metric = True 469 | draw_false_text = 15 470 | learning_rate = 5e-6 471 | lr_mult_head = 5 472 | lr_mult_cross_modal = 5 473 | tokenizer = ".cache/bert-base-uncased" 474 | input_text_embed_size = 768 475 | vit = 'ViT-B/32' 476 | train_transform_keys = ["clip"] 477 | val_transform_keys = ["clip"] 478 | input_image_embed_size = 768 479 | image_size = 384 480 | 481 | 482 | @ex.named_config 483 | def zs_cl_irtr_coco(): 484 | test_only = True 485 | skip_test_step = True 486 | get_recall_metric = False 487 | get_cl_recall_metric = True 488 | 489 | exp_name = "zs_cl_irtr_coco" 490 | datasets = ["coco"] 491 | loss_names = _loss_names({"cl": 0.5}) 492 | batch_size = 512 493 | max_epoch = 10 494 | max_steps = None 495 | warmup_steps = 0.1 496 | get_recall_metric = False 497 | get_cl_recall_metric = True 498 | draw_false_text = 15 499 | learning_rate = 5e-6 500 | lr_mult_head = 5 501 | lr_mult_cross_modal = 5 502 | tokenizer = ".cache/bert-base-uncased" 503 | input_text_embed_size = 768 504 | vit = 'ViT-B/32' 505 | train_transform_keys = ["clip"] 506 | val_transform_keys = ["clip"] 507 | input_image_embed_size = 768 508 | image_size = 384 509 | 510 | 511 | @ex.named_config 512 | def ft_snli_clip_bert(): 513 | exp_name = "finetune_snli" 514 | datasets = ["snli"] 515 | loss_names = _loss_names({"snli": 1}) 516 | batch_size = 64 517 | max_epoch = 5 518 | max_steps = None 519 | warmup_steps = 0.1 520 | draw_false_image = 0 521 | learning_rate = 2e-6 522 | lr_mult_head = 10 523 | lr_mult_cross_modal = 5 524 | tokenizer = ".cache/bert-base-uncased" 525 | max_text_len = 50 526 | input_text_embed_size = 768 527 | vit = 'ViT-B/32' 528 | train_transform_keys = ["clip"] 529 | val_transform_keys = ["clip"] 530 | input_image_embed_size = 768 531 | image_size = 384 532 | 533 | 534 | # Named configs for "etc" which are orthogonal to "env" and "task", need to be added at the end 535 | 536 | # vision encoder 537 | @ex.named_config 538 | def swin32_base224(): 539 | vit = "swin_base_patch4_window7_224_in22k" 540 | patch_size = 32 541 | image_size = 224 542 | train_transform_keys = ["imagenet"] 543 | val_transform_keys = ["imagenet"] 544 | input_image_embed_size = 1024 545 | resolution_before = 224 546 | 547 | 548 | @ex.named_config 549 | def swin32_base384(): 550 | vit = "swin_base_patch4_window12_384_in22k" 551 | patch_size = 32 552 | image_size = 384 553 | train_transform_keys = ["imagenet"] 554 | val_transform_keys = ["imagenet"] 555 | input_image_embed_size = 1024 556 | resolution_before = 384 557 | 558 | 559 | @ex.named_config 560 | def swin32_large384(): 561 | vit = "swin_large_patch4_window12_384_in22k" 562 | patch_size = 32 563 | image_size = 384 564 | train_transform_keys = ["imagenet"] 
565 | val_transform_keys = ["imagenet"] 566 | input_image_embed_size = 1536 567 | resolution_before = 384 568 | 569 | 570 | @ex.named_config 571 | def clip32(): 572 | vit = 'ViT-B/32' 573 | patch_size = 32 574 | train_transform_keys = ["clip"] 575 | val_transform_keys = ["clip"] 576 | input_image_embed_size = 768 577 | 578 | 579 | @ex.named_config 580 | def clip16(): 581 | vit = 'ViT-B/16' 582 | patch_size = 16 583 | train_transform_keys = ["clip"] 584 | val_transform_keys = ["clip"] 585 | input_image_embed_size = 768 586 | 587 | 588 | @ex.named_config 589 | def clip14(): 590 | vit = 'ViT-L/14' 591 | patch_size = 14 592 | train_transform_keys = ["clip"] 593 | val_transform_keys = ["clip"] 594 | input_image_embed_size = 1024 595 | 596 | 597 | @ex.named_config 598 | def clip14_336(): 599 | vit = 'ViT-L/14@336px' 600 | image_size = 336 601 | patch_size = 14 602 | train_transform_keys = ["clip"] 603 | val_transform_keys = ["clip"] 604 | input_image_embed_size = 1024 605 | 606 | 607 | @ex.named_config 608 | def mae_vit_huge_patch14(): 609 | vit = 'mae_vit_huge_patch14' 610 | image_size = 224 611 | patch_size = 14 612 | train_transform_keys = ["mae"] 613 | val_transform_keys = ["mae"] 614 | 615 | 616 | @ex.named_config 617 | def mae_vit_large_patch16(): 618 | vit = 'mae_vit_large_patch16' 619 | image_size = 224 620 | patch_size = 16 621 | train_transform_keys = ["mae"] 622 | val_transform_keys = ["mae"] 623 | 624 | 625 | @ex.named_config 626 | def mae_vit_base_patch16(): 627 | vit = 'mae_vit_base_patch16' 628 | image_size = 224 629 | patch_size = 16 630 | train_transform_keys = ["mae"] 631 | val_transform_keys = ["mae"] 632 | 633 | # text encoder 634 | 635 | 636 | @ex.named_config 637 | def text_roberta(): 638 | tokenizer = ".cache/roberta-base" 639 | vocab_size = 50265 640 | input_text_embed_size = 768 641 | 642 | 643 | # @ex.named_config 644 | # def text_clip(): 645 | # tokenizer = ".cache/roberta-base" 646 | # vocab_size = 50265 647 | # input_text_embed_size = 768 648 | 649 | @ex.named_config 650 | def text_roberta_large(): 651 | tokenizer = ".cache/roberta-large" 652 | vocab_size = 50265 653 | input_text_embed_size = 1024 654 | 655 | 656 | # random augmentation 657 | @ex.named_config 658 | def imagenet_randaug(): 659 | train_transform_keys = ["imagenet_randaug"] 660 | 661 | 662 | @ex.named_config 663 | def clip_randaug(): 664 | train_transform_keys = ["clip_randaug"] 665 | 666 | 667 | @ex.named_config 668 | def mae_randaug(): 669 | train_transform_keys = ["mae_randaug"] 670 | -------------------------------------------------------------------------------- /flm/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | from .vg_caption_datamodule import VisualGenomeCaptionDataModule 2 | from .f30k_caption_karpathy_datamodule import F30KCaptionKarpathyDataModule 3 | from .coco_caption_karpathy_datamodule import CocoCaptionKarpathyDataModule 4 | from .conceptual_caption_datamodule import ConceptualCaptionDataModule 5 | from .sbu_datamodule import SBUCaptionDataModule 6 | from .vqav2_datamodule import VQAv2DataModule 7 | from .nlvr2_datamodule import NLVR2DataModule 8 | from .snli_datamodule import SNLIDataModule 9 | from .conceptual_caption12m_datamodule import ConceptualCaption12mDataModule 10 | # from .conceptual_caption8m_datamodule import ConceptualCaption8mDataModule 11 | from .laion_datamodule import LaionDataModule 12 | from .laion100m_datamodule import Laion100mDataModule 13 | # from .wino_datamodule import WinoDataModule 14 | 15 | 
_datamodules = { 16 | "vg": VisualGenomeCaptionDataModule, 17 | "f30k": F30KCaptionKarpathyDataModule, 18 | "coco": CocoCaptionKarpathyDataModule, 19 | "gcc": ConceptualCaptionDataModule, 20 | "sbu": SBUCaptionDataModule, 21 | "vqa": VQAv2DataModule, 22 | "nlvr2": NLVR2DataModule, 23 | "snli": SNLIDataModule, 24 | "gcc12m": ConceptualCaption12mDataModule, 25 | # "gcc8m": ConceptualCaption8mDataModule, 26 | "laion": LaionDataModule, 27 | "laion100m": Laion100mDataModule, 28 | # "wino": WinoDataModule 29 | } 30 | -------------------------------------------------------------------------------- /flm/datamodules/coco_caption_karpathy_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import CocoCaptionKarpathyDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | # COCO Caption datamodule 6 | class CocoCaptionKarpathyDataModule(BaseDataModule): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | 10 | @property 11 | def dataset_cls(self): 12 | return CocoCaptionKarpathyDataset 13 | 14 | @property 15 | def dataset_cls_no_false(self): 16 | return CocoCaptionKarpathyDataset 17 | 18 | @property 19 | def dataset_name(self): 20 | return "coco" 21 | -------------------------------------------------------------------------------- /flm/datamodules/conceptual_caption12m_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import ConceptualCaption12mDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | # Conceptual Caption 12M datamodule 6 | class ConceptualCaption12mDataModule(BaseDataModule): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | 10 | @property 11 | def dataset_cls(self): 12 | return ConceptualCaption12mDataset 13 | 14 | @property 15 | def dataset_name(self): 16 | return "gcc" 17 | -------------------------------------------------------------------------------- /flm/datamodules/conceptual_caption_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import ConceptualCaptionDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | # Conceptual Caption 3M datamodule 6 | class ConceptualCaptionDataModule(BaseDataModule): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | 10 | @property 11 | def dataset_cls(self): 12 | return ConceptualCaptionDataset 13 | 14 | @property 15 | def dataset_name(self): 16 | return "gcc" 17 | -------------------------------------------------------------------------------- /flm/datamodules/datamodule_base.py: -------------------------------------------------------------------------------- 1 | from random import shuffle 2 | import torch 3 | import functools 4 | from pytorch_lightning import LightningDataModule 5 | from torch.utils.data import DataLoader 6 | from transformers import ( 7 | DataCollatorForLanguageModeling, 8 | # DataCollatorForWholeWordMask, 9 | BertTokenizer, 10 | RobertaTokenizer, 11 | ) 12 | 13 | from flm.utils.whole_word_masking import DataCollatorForWholeWordMask 14 | 15 | 16 | class text_preprocessor(): 17 | """prepend or append special tokens""" 18 | 19 | def __init__(self, config) -> None: 20 | self.prepend_bos = config['add_new_bos_token'] and config['prepend_bos_token'] 21 | self.append_eos = config['add_new_bos_token'] and config['append_eos_token'] 22 | 23 | def __call__(self, text): 24 | text = 
text.rstrip().rstrip('.').rstrip() + '.' 25 | if self.prepend_bos: 26 | text = '' + ' ' + text 27 | if self.append_eos: 28 | text = text + ' ' + '' 29 | return text 30 | 31 | 32 | def flm_collator(attention_mask, mask_ratio, disable_shuffle=True, label_strategy='none'): 33 | """get flm masks and labels""" 34 | text_len = attention_mask.sum(1) 35 | bs, max_len = attention_mask.size() 36 | flm_masks = -10000. * torch.ones(bs, max_len, max_len) 37 | # attention_mask.unsqueeze(dim=2) * attention_mask.unsqueeze(dim=1) 38 | flm_random_ids = [] 39 | mask_num = torch.distributions.Binomial( 40 | text_len.float() - 1, mask_ratio).sample().int() 41 | for i in range(len(text_len)): 42 | flm_random_id = torch.randperm(text_len[i] - 1) + 1 43 | flm_random_id = flm_random_id[:text_len[i] - 1 - mask_num[i]] 44 | if disable_shuffle: 45 | flm_random_id = torch.sort(flm_random_id)[0] 46 | flm_random_ids.append(flm_random_id) 47 | # print(flm_random_id) 48 | for j in range(len(flm_random_id)): 49 | if flm_random_id[j] < 0: 50 | break 51 | else: 52 | flm_masks[i, 53 | flm_random_id[j:j + 1].repeat(j+1), 54 | flm_random_id[:j+1]] = 0 55 | 56 | flm_label = None 57 | if label_strategy == 'none': 58 | pass 59 | else: 60 | 61 | if label_strategy == 'object': 62 | pass 63 | elif label_strategy == 'concrete': 64 | pass 65 | return flm_random_ids, flm_masks, flm_label 66 | 67 | 68 | def sep_collator(flatten_encodings, mlm_collator, mask_ratio, pred_corr_ratio) -> None: 69 | if pred_corr_ratio > 1: 70 | repeat_num = int(pred_corr_ratio) 71 | group_mlms = [[] for i in range(repeat_num)] 72 | mlms = mlm_collator(flatten_encodings) 73 | # print('mlms', mlms) 74 | for idx, flatten_encoding in enumerate(flatten_encodings): 75 | token_num = len(flatten_encoding['attention_mask']) 76 | chunk_size = token_num // repeat_num + 1 77 | org_input_id = torch.tensor(flatten_encoding['input_ids']) 78 | mlm_input_id = mlms['input_ids'][idx] 79 | mlm_labels = mlms['labels'][idx] 80 | ava_mask_reg = torch.tensor(flatten_encoding['attention_mask']) * ( 81 | 1 - torch.tensor(flatten_encoding['special_tokens_mask'])) 82 | perm = torch.randperm(token_num) 83 | groups = perm.split(chunk_size) 84 | assert len(groups) == repeat_num 85 | for i in range(repeat_num): 86 | group_mask = torch.zeros(token_num).long() 87 | group_mask[groups[i]] = 1 88 | group_input_id = org_input_id * \ 89 | (1-group_mask) + mlm_input_id * group_mask 90 | group_label = -100 * torch.ones(token_num).long() 91 | group_label[group_mask.bool()] = mlm_labels[group_mask.bool()] 92 | group_mlm = {'input_ids': group_input_id, 93 | 'labels': group_label} 94 | group_mlms[i].append(group_mlm) 95 | # print(group_mask) 96 | for i in range(repeat_num): 97 | group_mlms[i] = {'input_ids': torch.stack([_['input_ids'] for _ in group_mlms[i]]), 98 | 'labels': torch.stack([_['labels'] for _ in group_mlms[i]])} 99 | return group_mlms 100 | 101 | elif pred_corr_ratio < 1: 102 | mlms = mlm_collator(flatten_encodings) 103 | group_labels = [] 104 | # print('mlms', mlms) 105 | for idx, flatten_encoding in enumerate(flatten_encodings): 106 | token_num = len(flatten_encoding['attention_mask']) 107 | mlm_input_id = mlms['input_ids'][idx] 108 | mlm_labels = mlms['labels'][idx] 109 | perm = torch.randperm(token_num)[:int(token_num * pred_corr_ratio)] 110 | group_label = -100 * torch.ones(token_num).long() 111 | group_label[perm] = mlm_labels[perm] 112 | group_labels.append(group_label) 113 | 114 | group_mlm = {'input_ids': mlms['input_ids'], 115 | 'labels': torch.stack(group_labels, dim=0)} 116 | 
return group_mlm 117 | 118 | 119 | def get_pretrained_tokenizer(from_pretrained): 120 | if torch.distributed.is_initialized(): 121 | if torch.distributed.get_rank() == 0: 122 | if 'roberta' in from_pretrained: 123 | RobertaTokenizer.from_pretrained(from_pretrained) 124 | else: 125 | BertTokenizer.from_pretrained( 126 | from_pretrained, do_lower_case="uncased" in from_pretrained 127 | ) 128 | torch.distributed.barrier() 129 | 130 | if 'roberta' in from_pretrained: 131 | return RobertaTokenizer.from_pretrained(from_pretrained) 132 | elif 'gpt2' in from_pretrained: 133 | from transformers import GPT2Tokenizer, GPT2Model 134 | return GPT2Tokenizer.from_pretrained('gpt2') 135 | return BertTokenizer.from_pretrained( 136 | from_pretrained, do_lower_case="uncased" in from_pretrained 137 | ) 138 | 139 | 140 | class BaseDataModule(LightningDataModule): 141 | def __init__(self, _config): 142 | super().__init__() 143 | self.data_dir = _config["data_root"] 144 | 145 | self.num_workers = _config["num_workers"] 146 | self.batch_size = _config["per_gpu_batchsize"] 147 | self.eval_batch_size = self.batch_size 148 | 149 | self.image_size = _config["image_size"] 150 | self.max_text_len = _config["max_text_len"] 151 | self.draw_false_image = _config["draw_false_image"] 152 | self.draw_false_text = _config["draw_false_text"] 153 | self.image_only = _config["image_only"] 154 | 155 | self.train_transform_keys = ( 156 | ["default_train"] 157 | if len(_config["train_transform_keys"]) == 0 158 | else _config["train_transform_keys"] 159 | ) 160 | 161 | self.val_transform_keys = ( 162 | ["default_val"] 163 | if len(_config["val_transform_keys"]) == 0 164 | else _config["val_transform_keys"] 165 | ) 166 | 167 | tokenizer = _config["tokenizer"] 168 | self.tokenizer = get_pretrained_tokenizer(tokenizer) 169 | if _config['add_new_bos_token']: 170 | self.tokenizer.add_tokens(['', '']) 171 | self.vocab_size = self.tokenizer.vocab_size 172 | 173 | collator = ( 174 | DataCollatorForWholeWordMask 175 | if _config["whole_word_masking"] 176 | else DataCollatorForLanguageModeling 177 | ) 178 | 179 | self.mlm_collator = {'mlm_collator': 180 | collator(tokenizer=self.tokenizer, 181 | mlm=True, 182 | mlm_probability=_config["mlm_prob"]), 183 | "flm_collator": 184 | functools.partial( 185 | flm_collator, 186 | mask_ratio=_config["flm_mask_prob"], 187 | disable_shuffle=_config["disable_flm_shuffle"]), 188 | } 189 | 190 | self.text_preprocessor = text_preprocessor(_config) 191 | self.setup_flag = False 192 | self.max_dataset_len = _config.get('max_dataset_len', -1) 193 | 194 | @property 195 | def dataset_cls(self): 196 | raise NotImplementedError("return tuple of dataset class") 197 | 198 | @property 199 | def dataset_name(self): 200 | raise NotImplementedError("return name of dataset") 201 | 202 | def set_train_dataset(self): 203 | self.train_dataset = self.dataset_cls( 204 | self.data_dir, 205 | self.train_transform_keys, 206 | split="train", 207 | image_size=self.image_size, 208 | max_text_len=self.max_text_len, 209 | draw_false_image=self.draw_false_image, 210 | draw_false_text=self.draw_false_text, 211 | image_only=self.image_only, 212 | tokenizer=self.tokenizer, 213 | disable_sep_mlm=False, 214 | text_preprocessor=self.text_preprocessor, 215 | max_dataset_len=self.max_dataset_len 216 | ) 217 | 218 | def set_val_dataset(self): 219 | self.val_dataset = self.dataset_cls( 220 | self.data_dir, 221 | self.val_transform_keys, 222 | split="val", 223 | image_size=self.image_size, 224 | max_text_len=self.max_text_len, 225 | 
draw_false_image=self.draw_false_image, 226 | draw_false_text=self.draw_false_text, 227 | image_only=self.image_only, 228 | tokenizer=self.tokenizer, 229 | text_preprocessor=self.text_preprocessor, 230 | max_dataset_len=self.max_dataset_len 231 | ) 232 | 233 | if hasattr(self, "dataset_cls_no_false"): 234 | self.val_dataset_no_false = self.dataset_cls_no_false( 235 | self.data_dir, 236 | self.val_transform_keys, 237 | split="val", 238 | image_size=self.image_size, 239 | max_text_len=self.max_text_len, 240 | draw_false_image=0, 241 | draw_false_text=0, 242 | image_only=self.image_only, 243 | tokenizer=self.tokenizer, 244 | text_preprocessor=self.text_preprocessor, 245 | max_dataset_len=self.max_dataset_len 246 | ) 247 | 248 | def make_no_false_val_dset(self, image_only=False): 249 | return self.dataset_cls_no_false( 250 | self.data_dir, 251 | self.val_transform_keys, 252 | split="val", 253 | image_size=self.image_size, 254 | max_text_len=self.max_text_len, 255 | draw_false_image=0, 256 | draw_false_text=0, 257 | image_only=image_only, 258 | tokenizer=self.tokenizer, 259 | text_preprocessor=self.text_preprocessor, 260 | max_dataset_len=self.max_dataset_len 261 | ) 262 | 263 | def set_test_dataset(self): 264 | self.test_dataset = self.dataset_cls( 265 | self.data_dir, 266 | self.val_transform_keys, 267 | split="test", 268 | image_size=self.image_size, 269 | max_text_len=self.max_text_len, 270 | draw_false_image=self.draw_false_image, 271 | draw_false_text=self.draw_false_text, 272 | image_only=self.image_only, 273 | tokenizer=self.tokenizer, 274 | text_preprocessor=self.text_preprocessor, 275 | max_dataset_len=self.max_dataset_len 276 | ) 277 | 278 | def setup(self, stage): 279 | if not self.setup_flag: 280 | self.set_train_dataset() 281 | self.set_val_dataset() 282 | self.set_test_dataset() 283 | 284 | self.train_dataset.tokenizer = self.tokenizer 285 | self.val_dataset.tokenizer = self.tokenizer 286 | self.test_dataset.tokenizer = self.tokenizer 287 | 288 | self.setup_flag = True 289 | 290 | def train_dataloader(self): 291 | loader = DataLoader( 292 | self.train_dataset, 293 | batch_size=self.batch_size, 294 | shuffle=True, 295 | num_workers=self.num_workers, 296 | pin_memory=True, 297 | collate_fn=self.train_dataset.collate, 298 | ) 299 | return loader 300 | 301 | def val_dataloader(self): 302 | loader = DataLoader( 303 | self.val_dataset, 304 | batch_size=self.eval_batch_size, 305 | shuffle=False, 306 | num_workers=self.num_workers, 307 | pin_memory=True, 308 | collate_fn=self.val_dataset.collate, 309 | ) 310 | return loader 311 | 312 | def test_dataloader(self): 313 | loader = DataLoader( 314 | self.test_dataset, 315 | batch_size=self.eval_batch_size, 316 | shuffle=False, 317 | num_workers=self.num_workers, 318 | pin_memory=True, 319 | collate_fn=self.test_dataset.collate, 320 | ) 321 | return loader 322 | -------------------------------------------------------------------------------- /flm/datamodules/f30k_caption_karpathy_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import F30KCaptionKarpathyDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | # Flickr30K datamodule 6 | class F30KCaptionKarpathyDataModule(BaseDataModule): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | 10 | @property 11 | def dataset_cls(self): 12 | return F30KCaptionKarpathyDataset 13 | 14 | @property 15 | def dataset_cls_no_false(self): 16 | return F30KCaptionKarpathyDataset 17 | 18 | @property 19 
| def dataset_name(self): 20 | return "f30k" 21 | 22 | def train_dataloader(self): 23 | loader = DataLoader( 24 | self.train_dataset, 25 | batch_size=self.batch_size, 26 | shuffle=True, 27 | num_workers=0, 28 | pin_memory=True, 29 | collate_fn=self.train_dataset.collate, 30 | ) 31 | return loader 32 | 33 | def val_dataloader(self): 34 | loader = DataLoader( 35 | self.val_dataset, 36 | batch_size=self.eval_batch_size, 37 | shuffle=False, 38 | num_workers=0, 39 | pin_memory=True, 40 | collate_fn=self.val_dataset.collate, 41 | ) 42 | return loader 43 | 44 | def test_dataloader(self): 45 | loader = DataLoader( 46 | self.test_dataset, 47 | batch_size=self.eval_batch_size, 48 | shuffle=False, 49 | num_workers=0, 50 | pin_memory=True, 51 | collate_fn=self.test_dataset.collate, 52 | ) 53 | return loader 54 | -------------------------------------------------------------------------------- /flm/datamodules/laion100m_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import Laion100mDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | # LAION-100M datamodule, a random subset of LAION-400M 6 | class Laion100mDataModule(BaseDataModule): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | 10 | @property 11 | def dataset_cls(self): 12 | return Laion100mDataset 13 | 14 | @property 15 | def dataset_name(self): 16 | return "laion" 17 | -------------------------------------------------------------------------------- /flm/datamodules/laion_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import LaionDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | # LAION-400M datamodule 6 | class LaionDataModule(BaseDataModule): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | 10 | @property 11 | def dataset_cls(self): 12 | return LaionDataset 13 | 14 | @property 15 | def dataset_name(self): 16 | return "laion" 17 | -------------------------------------------------------------------------------- /flm/datamodules/multitask_datamodule.py: -------------------------------------------------------------------------------- 1 | from builtins import hasattr 2 | import functools 3 | 4 | from pytorch_lightning import LightningDataModule 5 | from torch.utils.data import DataLoader 6 | from torch.utils.data.dataset import ConcatDataset 7 | from torch.utils.data.distributed import DistributedSampler 8 | 9 | from . 
import _datamodules 10 | import webdataset as wds 11 | 12 | 13 | # datamodule for mutiple datasets 14 | class MTDataModule(LightningDataModule): 15 | def __init__(self, _config, dist=False): 16 | datamodule_keys = _config["datasets"] 17 | assert len(datamodule_keys) > 0 18 | 19 | super().__init__() 20 | 21 | self.dm_keys = datamodule_keys 22 | self.dm_dicts = {key: _datamodules[key]( 23 | _config) for key in datamodule_keys} 24 | self.dms = [v for k, v in self.dm_dicts.items()] 25 | 26 | self.batch_size = self.dms[0].batch_size 27 | self.vocab_size = self.dms[0].vocab_size 28 | self.num_workers = self.dms[0].num_workers 29 | 30 | self.dist = dist 31 | self.allow_val_webdataset = _config['allow_val_webdataset'] 32 | 33 | def prepare_data(self): 34 | for dm in self.dms: 35 | dm.prepare_data() 36 | 37 | def setup(self, stage): 38 | def check_webdataset(dataset): 39 | if hasattr(dataset, 'inner_dataset'): 40 | return True 41 | 42 | for dm in self.dms: 43 | dm.setup(stage) 44 | 45 | if check_webdataset(self.dms[0].train_dataset): 46 | assert len( 47 | self.dms) == 1, 'does not support webdataset instance larger than 1' 48 | self.train_dataset = self.dms[0].train_dataset.inner_dataset 49 | # self.train_dataset.append(wds.batched(self.batch_size)) 50 | else: 51 | self.train_dataset = ConcatDataset( 52 | [dm.train_dataset for dm in self.dms]) 53 | 54 | if check_webdataset(self.dms[0].val_dataset) and self.allow_val_webdataset: 55 | self.val_dataset = self.dms[0].val_dataset.inner_dataset 56 | # self.val_dataset.append(wds.batched(self.batch_size)) 57 | else: 58 | self.val_dataset = ConcatDataset( 59 | [dm.val_dataset for dm in self.dms]) 60 | 61 | if check_webdataset(self.dms[0].test_dataset) and self.allow_val_webdataset: 62 | self.test_dataset = self.dms[0].test_dataset.inner_dataset 63 | # self.test_dataset.append(wds.batched(self.batch_size)) 64 | else: 65 | self.test_dataset = ConcatDataset( 66 | [dm.test_dataset for dm in self.dms]) 67 | 68 | self.tokenizer = self.dms[0].tokenizer 69 | 70 | self.train_collate = functools.partial( 71 | self.dms[0].train_dataset.collate, mlm_collator=self.dms[0].mlm_collator 72 | ) 73 | self.val_collate = functools.partial( 74 | self.dms[0].val_dataset.collate, mlm_collator=self.dms[0].mlm_collator 75 | ) 76 | self.test_collate = functools.partial( 77 | self.dms[0].test_dataset.collate, mlm_collator=self.dms[0].mlm_collator 78 | ) 79 | 80 | if self.dist: 81 | if isinstance(self.train_dataset, wds.DataPipeline): 82 | self.train_sampler = None 83 | else: 84 | self.train_sampler = DistributedSampler( 85 | self.train_dataset, shuffle=True) 86 | if isinstance(self.val_dataset, wds.DataPipeline) and self.allow_val_webdataset: 87 | self.val_sampler = None 88 | else: 89 | self.val_sampler = DistributedSampler( 90 | self.val_dataset, shuffle=True) 91 | if isinstance(self.test_dataset, wds.DataPipeline) and self.allow_val_webdataset: 92 | self.test_sampler = None 93 | else: 94 | self.test_sampler = DistributedSampler( 95 | self.test_dataset, shuffle=False) 96 | 97 | else: 98 | self.train_sampler = None 99 | self.val_sampler = None 100 | self.test_sampler = None 101 | 102 | def train_dataloader(self): 103 | loader = DataLoader( 104 | self.train_dataset, 105 | batch_size=self.batch_size, 106 | sampler=self.train_sampler, 107 | num_workers=self.num_workers, 108 | collate_fn=self.train_collate, 109 | ) 110 | return loader 111 | 112 | def val_dataloader(self, batch_size=None): 113 | loader = DataLoader( 114 | self.val_dataset, 115 | batch_size=batch_size if batch_size is not 
None else self.batch_size, 116 | sampler=self.val_sampler, 117 | num_workers=self.num_workers, 118 | collate_fn=self.val_collate, 119 | ) 120 | return loader 121 | 122 | def test_dataloader(self): 123 | loader = DataLoader( 124 | self.test_dataset, 125 | batch_size=self.batch_size, 126 | sampler=self.test_sampler, 127 | num_workers=self.num_workers, 128 | collate_fn=self.test_collate, 129 | ) 130 | return loader 131 | -------------------------------------------------------------------------------- /flm/datamodules/nlvr2_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import NLVR2Dataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | # NLVR2 datamodule 6 | class NLVR2DataModule(BaseDataModule): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | 10 | @property 11 | def dataset_cls(self): 12 | return NLVR2Dataset 13 | 14 | @property 15 | def dataset_name(self): 16 | return "nlvr2" 17 | -------------------------------------------------------------------------------- /flm/datamodules/sbu_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import SBUCaptionDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | # SBU Caption datamodule 6 | class SBUCaptionDataModule(BaseDataModule): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | 10 | @property 11 | def dataset_cls(self): 12 | return SBUCaptionDataset 13 | 14 | @property 15 | def dataset_name(self): 16 | return "sbu" 17 | -------------------------------------------------------------------------------- /flm/datamodules/snli_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import SNLIDataset 2 | from .datamodule_base import BaseDataModule 3 | from collections import defaultdict 4 | 5 | 6 | # SNLI datamodule 7 | class SNLIDataModule(BaseDataModule): 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | 11 | @property 12 | def dataset_cls(self): 13 | return SNLIDataset 14 | 15 | @property 16 | def dataset_name(self): 17 | return "snli" 18 | -------------------------------------------------------------------------------- /flm/datamodules/vg_caption_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import VisualGenomeCaptionDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | # VisualGenome datamodule 6 | class VisualGenomeCaptionDataModule(BaseDataModule): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | 10 | @property 11 | def dataset_cls(self): 12 | return VisualGenomeCaptionDataset 13 | 14 | @property 15 | def dataset_name(self): 16 | return "vg" 17 | -------------------------------------------------------------------------------- /flm/datamodules/vqav2_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import VQAv2Dataset 2 | from .datamodule_base import BaseDataModule 3 | from collections import defaultdict 4 | 5 | 6 | # VQAv2 datamodule 7 | class VQAv2DataModule(BaseDataModule): 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | 11 | @property 12 | def dataset_cls(self): 13 | return VQAv2Dataset 14 | 15 | @property 16 | def dataset_name(self): 17 | return "vqa" 18 | 19 | def setup(self, stage): 20 | super().setup(stage) 
21 | 22 | train_answers = self.train_dataset.table["answers"].to_pandas( 23 | ).tolist() 24 | val_answers = self.val_dataset.table["answers"].to_pandas().tolist() 25 | train_labels = self.train_dataset.table["answer_labels"].to_pandas( 26 | ).tolist() 27 | val_labels = self.val_dataset.table["answer_labels"].to_pandas( 28 | ).tolist() 29 | 30 | all_answers = [c for c in train_answers + val_answers if c is not None] 31 | all_answers = [l for lll in all_answers for ll in lll for l in ll] 32 | all_labels = [c for c in train_labels + val_labels if c is not None] 33 | all_labels = [l for lll in all_labels for ll in lll for l in ll] 34 | 35 | self.answer2id = {k: v for k, v in zip(all_answers, all_labels)} 36 | sorted_a2i = sorted(self.answer2id.items(), key=lambda x: x[1]) 37 | self.num_class = max(self.answer2id.values()) + 1 38 | 39 | self.id2answer = defaultdict(lambda: "unknown") 40 | for k, v in sorted_a2i: 41 | self.id2answer[v] = k 42 | -------------------------------------------------------------------------------- /flm/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .vg_caption_dataset import VisualGenomeCaptionDataset 2 | from .coco_caption_karpathy_dataset import CocoCaptionKarpathyDataset 3 | from .f30k_caption_karpathy_dataset import F30KCaptionKarpathyDataset 4 | from .conceptual_caption_dataset import ConceptualCaptionDataset 5 | from .conceptual_caption12m_dataset import ConceptualCaption12mDataset 6 | from .sbu_caption_dataset import SBUCaptionDataset 7 | from .vqav2_dataset import VQAv2Dataset 8 | from .nlvr2_dataset import NLVR2Dataset 9 | from .snli_dataset import SNLIDataset 10 | from .laion_dataset import LaionDataset 11 | from .laion100m_dataset import Laion100mDataset 12 | -------------------------------------------------------------------------------- /flm/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import io 4 | import pyarrow as pa 5 | import os 6 | import pdb 7 | from PIL import Image 8 | from ..transforms import keys_to_transforms 9 | import pdb 10 | import copy 11 | 12 | 13 | class BaseDataset(torch.utils.data.Dataset): 14 | def __init__( 15 | self, 16 | data_dir: str, 17 | transform_keys: list, 18 | image_size: int, 19 | names: list, 20 | text_column_name: str = "", 21 | remove_duplicate=True, 22 | max_text_len=40, 23 | max_dataset_len=-1, 24 | draw_false_image=0, 25 | draw_false_text=0, 26 | image_only=False, 27 | tokenizer=None, 28 | disable_sep_mlm=True, 29 | text_preprocessor=None, 30 | ): 31 | """ 32 | data_dir : where dataset file *.arrow lives; existence should be guaranteed via DataModule.prepare_data 33 | transform_keys : keys for generating augmented views of images 34 | text_column_name : pyarrow table column name that has list of strings as elements 35 | """ 36 | assert len(transform_keys) >= 1 37 | super().__init__() 38 | 39 | self.transforms = keys_to_transforms(transform_keys, size=image_size) 40 | self.clip_transform = False 41 | for transform_key in transform_keys: 42 | if 'clip' in transform_key: 43 | self.clip_transform = True 44 | break 45 | self.text_column_name = text_column_name 46 | self.names = names 47 | self.max_text_len = max_text_len 48 | self.draw_false_image = draw_false_image 49 | self.draw_false_text = draw_false_text 50 | self.image_only = image_only 51 | self.data_dir = data_dir 52 | self.disable_sep_mlm = disable_sep_mlm 53 | self.text_preprocessor = 
text_preprocessor 54 | 55 | if len(names) != 0: 56 | tables = [ 57 | pa.ipc.RecordBatchFileReader( 58 | pa.memory_map(f"{data_dir}/{name}.arrow", "r") 59 | ).read_all() 60 | for name in names 61 | if os.path.isfile(f"{data_dir}/{name}.arrow") 62 | ] 63 | self.table_names = list() 64 | for i, name in enumerate(names): 65 | self.table_names += [name] * len(tables[i]) 66 | 67 | if max_dataset_len != -1: 68 | self.table = pa.concat_tables(tables, promote=True)[ 69 | :max_dataset_len] 70 | print(' truncate the dataset with length: {}'.format(max_dataset_len)) 71 | else: 72 | self.table = pa.concat_tables(tables, promote=True) 73 | 74 | if text_column_name != "": 75 | self.text_column_name = text_column_name 76 | self.all_texts = self.table[text_column_name].to_pandas( 77 | ).tolist() 78 | if type(self.all_texts[0][0]) == str: 79 | if type(self.all_texts[0]) == str: 80 | self.all_texts = [ 81 | [self.text_preprocessor(text)] for text in self.all_texts] 82 | else: 83 | self.all_texts = ( 84 | [list(set([self.text_preprocessor(text) for text in texts])) 85 | for texts in self.all_texts] 86 | if remove_duplicate 87 | else self.all_texts 88 | ) 89 | else: # snli 90 | self.all_texts = ( 91 | [[t[1].strip() for t in texts] 92 | for texts in self.all_texts] 93 | ) 94 | else: 95 | self.all_texts = list() 96 | 97 | self.index_mapper = dict() 98 | if text_column_name != "" and not self.image_only: 99 | j = 0 100 | for i, texts in enumerate(self.all_texts): 101 | for _j in range(len(texts)): 102 | self.index_mapper[j] = (i, _j) 103 | j += 1 104 | else: 105 | for i in range(len(self.table)): 106 | self.index_mapper[i] = (i, None) 107 | # print(' Dataset length', len(self.index_mapper)) 108 | 109 | else: 110 | self.index_mapper = dict() 111 | self.all_texts = list() 112 | 113 | @property 114 | def corpus(self): 115 | return [text for texts in self.all_texts for text in texts] 116 | 117 | def __len__(self): 118 | return len(self.index_mapper) 119 | 120 | def get_raw_image(self, index, image_key="image"): 121 | index, caption_index = self.index_mapper[index] 122 | image_bytes = io.BytesIO(self.table[image_key][index].as_py()) 123 | image_bytes.seek(0) 124 | if self.clip_transform: 125 | return Image.open(image_bytes).convert("RGBA") 126 | else: 127 | return Image.open(image_bytes).convert("RGB") 128 | 129 | def get_image(self, index, image_key="image"): 130 | image = self.get_raw_image(index, image_key=image_key) 131 | image_tensor = [tr(image) for tr in self.transforms] 132 | return { 133 | "image": image_tensor, 134 | "img_index": self.index_mapper[index][0], 135 | "cap_index": self.index_mapper[index][1], 136 | "raw_index": index, 137 | } 138 | 139 | def get_false_image(self, rep, image_key="image"): 140 | """get false images for image-text matching loss""" 141 | random_index = random.randint(0, len(self.index_mapper) - 1) 142 | image = self.get_raw_image(random_index, image_key=image_key) 143 | image_tensor = [tr(image) for tr in self.transforms] 144 | return {f"false_image_{rep}": image_tensor} 145 | 146 | def get_text(self, raw_index): 147 | index, caption_index = self.index_mapper[raw_index] 148 | 149 | text = self.all_texts[index][caption_index] 150 | encoding = self.tokenizer( 151 | text, 152 | padding="max_length", 153 | truncation=True, 154 | max_length=self.max_text_len, 155 | return_special_tokens_mask=True, 156 | ) 157 | return { 158 | "text": (text, encoding), 159 | "img_index": index, 160 | "cap_index": caption_index, 161 | "raw_index": raw_index, 162 | } 163 | 164 | def get_false_text(self, 
rep): 165 | """get false text for image-text matching loss""" 166 | random_index = random.randint(0, len(self.index_mapper) - 1) 167 | 168 | index, caption_index = self.index_mapper[random_index] 169 | text = self.all_texts[index][caption_index] 170 | encoding = self.tokenizer( 171 | text, 172 | truncation=True, 173 | max_length=self.max_text_len, 174 | return_special_tokens_mask=True, 175 | ) 176 | return {f"false_text_{rep}": (text, encoding)} 177 | 178 | def get_suite(self, index): 179 | result = None 180 | while result is None: 181 | try: 182 | ret = dict() 183 | ret.update(self.get_image(index)) 184 | if not self.image_only: 185 | txt = self.get_text(index) 186 | ret.update( 187 | {"replica": True if txt["cap_index"] > 0 else False}) 188 | ret.update(txt) 189 | 190 | for i in range(self.draw_false_image): 191 | ret.update(self.get_false_image(i)) 192 | for i in range(self.draw_false_text): 193 | ret.update(self.get_false_text(i)) 194 | result = True 195 | except Exception as e: 196 | print( 197 | f"Error while read file idx {index} in {self.names[0]} -> {e}") 198 | index = random.randint(0, len(self.index_mapper) - 1) 199 | return ret 200 | 201 | def collate(self, batch, mlm_collator): 202 | batch_size = len(batch) 203 | keys = set([key for b in batch for key in b.keys()]) 204 | raw_dict_batch = { 205 | k: [dic[k] if k in dic else None for dic in batch] for k in keys} 206 | 207 | img_keys = [k for k in list(raw_dict_batch.keys()) if "image" in k] 208 | img_sizes = list() 209 | 210 | for img_key in img_keys: 211 | img = raw_dict_batch[img_key] 212 | img_sizes += [ii.shape for i in img if i is not None for ii in i] 213 | 214 | for size in img_sizes: 215 | assert ( 216 | len(size) == 3 217 | ), f"Collate error, an image should be in shape of (3, H, W), instead of given {size}" 218 | 219 | if len(img_keys) != 0: 220 | max_height = max([i[1] for i in img_sizes]) 221 | max_width = max([i[2] for i in img_sizes]) 222 | 223 | for img_key in img_keys: 224 | img = raw_dict_batch[img_key] 225 | view_size = len(img[0]) 226 | 227 | new_images = [ 228 | torch.zeros(batch_size, 3, max_height, max_width) 229 | for _ in range(view_size) 230 | ] 231 | 232 | for bi in range(batch_size): 233 | orig_batch = img[bi] 234 | for vi in range(view_size): 235 | if orig_batch is None: 236 | new_images[vi][bi] = None 237 | else: 238 | orig = img[bi][vi] 239 | new_images[vi][bi, :, : orig.shape[1], 240 | : orig.shape[2]] = orig 241 | 242 | raw_dict_batch[img_key] = new_images 243 | 244 | txt_keys = [k for k in list(raw_dict_batch.keys()) if "text" in k] 245 | 246 | if len(txt_keys) != 0: 247 | texts = [[d[0] for d in raw_dict_batch[txt_key]] 248 | for txt_key in txt_keys] 249 | encodings = [[d[1] for d in raw_dict_batch[txt_key]] 250 | for txt_key in txt_keys] 251 | flatten_encodings = [e for encoding in encodings for e in encoding] 252 | flatten_mlms = mlm_collator['mlm_collator'](flatten_encodings) 253 | is_sep_mlm = type( 254 | flatten_mlms) == list and not self.disable_sep_mlm 255 | flatten_mlms_all = flatten_mlms if type( 256 | flatten_mlms) == list else [flatten_mlms] 257 | 258 | dict_batch_sep_mlm = {'batch': []} 259 | for flatten_mlms in flatten_mlms_all: 260 | dict_batch = copy.deepcopy(raw_dict_batch) 261 | for i, txt_key in enumerate(txt_keys): 262 | texts, encodings = ( 263 | [d[0] for d in dict_batch[txt_key]], 264 | [d[1] for d in dict_batch[txt_key]], 265 | ) 266 | 267 | mlm_ids, mlm_labels = ( 268 | flatten_mlms["input_ids"][batch_size * 269 | (i): batch_size * (i + 1)], 270 | 
flatten_mlms["labels"][batch_size * 271 | (i): batch_size * (i + 1)], 272 | ) 273 | 274 | input_ids = torch.zeros_like(mlm_ids) 275 | attention_mask = torch.zeros_like(mlm_ids) 276 | for _i, encoding in enumerate(encodings): 277 | _input_ids, _attention_mask = ( 278 | torch.tensor(encoding["input_ids"]), 279 | torch.tensor(encoding["attention_mask"]), 280 | ) 281 | input_ids[_i, : len(_input_ids)] = _input_ids 282 | attention_mask[_i, : len( 283 | _attention_mask)] = _attention_mask 284 | 285 | lm_labels = input_ids[:, 1:] 286 | 287 | if 'prefixLM_collator' in mlm_collator: 288 | plm_att_mask, prefix_lm_labels = mlm_collator['prefixLM_collator']( 289 | attention_mask, input_ids) 290 | lm_labels = prefix_lm_labels[:, 1:] 291 | dict_batch[f"{txt_key}_prefixlm_masks"] = plm_att_mask 292 | 293 | dict_batch[txt_key] = texts 294 | dict_batch[f"{txt_key}_ids"] = input_ids 295 | dict_batch[f"{txt_key}_labels"] = torch.full_like( 296 | input_ids, -100) 297 | dict_batch[f"{txt_key}_ids_mlm"] = mlm_ids 298 | dict_batch[f"{txt_key}_labels_mlm"] = mlm_labels 299 | dict_batch[f"{txt_key}_labels_lm"] = lm_labels 300 | dict_batch[f"{txt_key}_masks"] = attention_mask 301 | dict_batch.update(self.get_flm_batch( 302 | attention_mask, input_ids, mlm_collator, txt_key)) 303 | 304 | dict_batch_sep_mlm['batch'].append(dict_batch) 305 | if not is_sep_mlm: 306 | dict_batch['is_sep_mlm'] = False 307 | return dict_batch 308 | if is_sep_mlm: 309 | dict_batch_sep_mlm['is_sep_mlm'] = True 310 | return dict_batch_sep_mlm 311 | return raw_dict_batch 312 | 313 | def get_flm_batch(self, attention_mask, input_ids, mlm_collator, txt_key): 314 | dict_batch = {} 315 | all_mask_ids = attention_mask * \ 316 | self.tokenizer.convert_tokens_to_ids('') 317 | text_len = attention_mask.sum(1) 318 | all_mask_ids[:, 0] = input_ids[:, 0] 319 | all_mask_ids[torch.arange(len( 320 | text_len)), text_len - 1] = input_ids[torch.arange(len(text_len)), text_len - 1] 321 | dict_batch[f"{txt_key}_all_masks_ids"] = all_mask_ids 322 | flm_random_ids, flm_masks, flm_label = mlm_collator['flm_collator']( 323 | attention_mask) 324 | dict_batch[f"{txt_key}_flm_mask_ids"] = flm_random_ids 325 | dict_batch[f"{txt_key}_flm_masks"] = flm_masks 326 | dict_batch[f"{txt_key}_flm_labels"] = flm_label 327 | return dict_batch 328 | -------------------------------------------------------------------------------- /flm/datasets/coco_caption_karpathy_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | import io 3 | from PIL import Image 4 | 5 | 6 | # COCO Caption (with Karpathy split) Dataset 7 | class CocoCaptionKarpathyDataset(BaseDataset): 8 | def __init__(self, *args, split="", **kwargs): 9 | assert split in ["train", "val", "test"] 10 | self.split = split 11 | 12 | if split == "train": 13 | names = ["coco_caption_karpathy_train", 14 | "coco_caption_karpathy_restval"] 15 | elif split == "val": 16 | names = ["coco_caption_karpathy_test"] 17 | elif split == "test": 18 | names = ["coco_caption_karpathy_test"] 19 | 20 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 21 | 22 | def __getitem__(self, index): 23 | suite = self.get_suite(index) 24 | 25 | if "test" in self.split: 26 | _index, _question_index = self.index_mapper[index] 27 | iid = self.table["image_id"][_index].as_py() 28 | iid = int(iid.split(".")[0].split("_")[-1]) 29 | suite.update({"iid": iid}) 30 | 31 | return suite 32 | 
-------------------------------------------------------------------------------- /flm/datasets/conceptual_caption12m_dataset.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from glob import glob 3 | from .base_dataset import BaseDataset 4 | from .conceptual_caption_dataset import ConceptualCaptionDataset 5 | import io 6 | from PIL import Image 7 | 8 | 9 | # Conceptual Caption 12M Dataset 10 | class ConceptualCaption12mDataset(ConceptualCaptionDataset): 11 | def __init__(self, *args, split="", **kwargs): 12 | assert split in ["train", "val", "test"] 13 | if split == "test": 14 | split = "val" 15 | 16 | if split == "train": 17 | names = [f"conceptual_caption12M_train_{i}" for i in range(96)] 18 | elif split == "val": 19 | # names = [f"conceptual_caption_val_{i}" for i in range(1)] 20 | names = [] 21 | 22 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 23 | -------------------------------------------------------------------------------- /flm/datasets/conceptual_caption_dataset.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .base_dataset import BaseDataset 3 | 4 | 5 | # Conceptual Caption 3M Dataset 6 | class ConceptualCaptionDataset(BaseDataset): 7 | def __init__(self, *args, split="", **kwargs): 8 | assert split in ["train", "val", "test"] 9 | if split == "test": 10 | split = "val" 11 | 12 | if split == "train": 13 | names = [f"conceptual_caption_train_{i}" for i in range(29)] 14 | elif split == "val": 15 | names = [] 16 | 17 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 18 | 19 | def __getitem__(self, index): 20 | return self.get_suite(index) 21 | 22 | def get_text(self, raw_index): 23 | index, caption_index = self.index_mapper[raw_index] 24 | 25 | text = self.all_texts[index][caption_index] 26 | encoding = self.tokenizer( 27 | text, 28 | padding="max_length", 29 | truncation=True, 30 | max_length=self.max_text_len, 31 | return_special_tokens_mask=True, 32 | ) 33 | return { 34 | "text": (text, encoding), 35 | "img_index": index, 36 | "cap_index": caption_index, 37 | "raw_index": raw_index, 38 | } 39 | -------------------------------------------------------------------------------- /flm/datasets/f30k_caption_karpathy_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | 3 | 4 | # Flickr30K Dataset 5 | class F30KCaptionKarpathyDataset(BaseDataset): 6 | def __init__(self, *args, split="", **kwargs): 7 | assert split in ["train", "val", "test"] 8 | 9 | if split == "train": 10 | names = ["f30k_caption_karpathy_train", 11 | "f30k_caption_karpathy_val"] 12 | elif split == "val": 13 | names = ["f30k_caption_karpathy_test"] 14 | elif split == "test": 15 | names = ["f30k_caption_karpathy_test"] 16 | 17 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 18 | 19 | def __getitem__(self, index): 20 | return self.get_suite(index) 21 | -------------------------------------------------------------------------------- /flm/datasets/laion100m_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_webdataset import WebDataset 2 | import io 3 | from PIL import Image 4 | 5 | 6 | # a 100M subset of Laion-400M Dataset 7 | class Laion100mDataset(WebDataset): 8 | def __init__(self, *args, split="", **kwargs): 9 | assert split in ["train", "val", "test"] 10 | self.split = 
split 11 | 12 | if split == 'train': 13 | location = "/group/30042/public_datasets/LAION-400M/raw/data/{00001..10689}.tar" 14 | infinite_loader = False 15 | elif split == "val": 16 | location = '/group/30042/public_datasets/LAION-400M/raw/data/00000.tar' 17 | infinite_loader = False 18 | elif split == 'test': 19 | location = '/group/30042/public_datasets/LAION-400M/raw/data/00000.tar' 20 | infinite_loader = False 21 | super().__init__(*args, **kwargs, infinite_loader=infinite_loader, 22 | location=location, text_column_name="caption") 23 | -------------------------------------------------------------------------------- /flm/datasets/laion_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_webdataset import WebDataset 2 | import io 3 | from PIL import Image 4 | 5 | 6 | # Laion-400M Dataset 7 | class LaionDataset(WebDataset): 8 | def __init__(self, *args, split="", **kwargs): 9 | assert split in ["train", "val", "test"] 10 | self.split = split 11 | 12 | if split == 'train': 13 | # location = "/group/30042/public_datasets/LAION-400M/raw/data/38872.tar" 14 | # location = "/group/30042/public_datasets/LAION-400M/raw/data/{00000..42757}.tar" 15 | location = "/group/30042/public_datasets/LAION-400M/raw/data/{00001..42757}.tar" 16 | infinite_loader = True 17 | elif split == "val": 18 | location = '/group/30042/public_datasets/LAION-400M/raw/data/00000.tar' 19 | infinite_loader = False 20 | elif split == 'test': 21 | location = '/group/30042/public_datasets/LAION-400M/raw/data/00000.tar' 22 | infinite_loader = False 23 | super().__init__(*args, **kwargs, infinite_loader=infinite_loader, 24 | location=location, text_column_name="caption") 25 | -------------------------------------------------------------------------------- /flm/datasets/nlvr2_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | import sys 3 | import random 4 | 5 | 6 | # NLVR2 3M Dataset 7 | class NLVR2Dataset(BaseDataset): 8 | def __init__(self, *args, split="", **kwargs): 9 | assert split in ["train", "val", "test"] 10 | self.split = split 11 | 12 | if split == "train": 13 | names = ["nlvr2_train"] 14 | elif split == "val": 15 | names = ["nlvr2_dev", "nlvr2_test1"] 16 | elif split == "test": 17 | names = ["nlvr2_dev", "nlvr2_test1"] 18 | 19 | super().__init__( 20 | *args, 21 | **kwargs, 22 | names=names, 23 | text_column_name="questions", 24 | remove_duplicate=False, 25 | ) 26 | 27 | def __getitem__(self, index): 28 | result = None 29 | while result is None: 30 | try: 31 | image_tensor_0 = self.get_image( 32 | index, image_key="image_0")["image"] 33 | image_tensor_1 = self.get_image( 34 | index, image_key="image_1")["image"] 35 | text = self.get_text(index)["text"] 36 | result = True 37 | except: 38 | print( 39 | f"error while read file idx {index} in {self.names[0]}", 40 | file=sys.stderr, 41 | ) 42 | index = random.randint(0, len(self.index_mapper) - 1) 43 | 44 | index, question_index = self.index_mapper[index] 45 | answers = self.table["answers"][index][question_index].as_py() 46 | answers = answers == "True" 47 | 48 | return { 49 | "image_0": image_tensor_0, 50 | "image_1": image_tensor_1, 51 | "text": text, 52 | "answers": answers, 53 | "table_name": self.table_names[index], 54 | } 55 | -------------------------------------------------------------------------------- /flm/datasets/sbu_caption_dataset.py: -------------------------------------------------------------------------------- 1 | 
# flake8: noqa 2 | from glob import glob 3 | from .base_dataset import BaseDataset 4 | import io 5 | from PIL import Image 6 | 7 | 8 | # SBU Caption 3M Dataset 9 | class SBUCaptionDataset(BaseDataset): 10 | def __init__(self, *args, split="", **kwargs): 11 | assert split in ["train", "val", "test"] 12 | if split == "test": 13 | split = "val" 14 | 15 | if split == "train": 16 | names = [f"sbu_{i}" for i in range(9)] 17 | elif split == "val": 18 | names = [] 19 | 20 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 21 | 22 | def __getitem__(self, index): 23 | return self.get_suite(index) 24 | -------------------------------------------------------------------------------- /flm/datasets/snli_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | 3 | 4 | # SNLI 3M Dataset 5 | class SNLIDataset(BaseDataset): 6 | def __init__(self, *args, split="", **kwargs): 7 | assert split in ["train", "val", "test"] 8 | self.split = split 9 | 10 | if split == "train": 11 | names = ["snli_train"] 12 | elif split == "val": 13 | names = ["snli_dev", "snli_test"] 14 | elif split == "test": 15 | names = ["snli_dev", "snli_test"] 16 | 17 | super().__init__( 18 | *args, 19 | **kwargs, 20 | names=names, 21 | text_column_name="sentences", 22 | remove_duplicate=False, 23 | ) 24 | 25 | def __getitem__(self, index): 26 | image_tensor = self.get_image(index)["image"] 27 | text = self.get_text(index)["text"] 28 | 29 | index, question_index = self.index_mapper[index] 30 | 31 | labels = self.table["labels"][index][question_index].as_py() 32 | 33 | return { 34 | "image": image_tensor, 35 | "text": text, 36 | "labels": labels, 37 | "table_name": self.table_names[index], 38 | } 39 | -------------------------------------------------------------------------------- /flm/datasets/vg_caption_dataset.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .base_dataset import BaseDataset 3 | import io 4 | from PIL import Image 5 | 6 | 7 | # Visual Genome Dataset 8 | class VisualGenomeCaptionDataset(BaseDataset): 9 | def __init__(self, *args, split="", **kwargs): 10 | assert split in ["train", "val", "test"] 11 | if split == "test": 12 | split = "val" 13 | 14 | if split == "train": 15 | names = ["vg"] 16 | elif split == "val": 17 | names = [] 18 | 19 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 20 | 21 | def __getitem__(self, index): 22 | return self.get_suite(index) 23 | -------------------------------------------------------------------------------- /flm/datasets/vqav2_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | 3 | 4 | # VQAv2 Dataset 5 | class VQAv2Dataset(BaseDataset): 6 | def __init__(self, *args, split="", **kwargs): 7 | assert split in ["train", "val", "test"] 8 | self.split = split 9 | 10 | if split == "train": 11 | names = ["vqav2_train", "vqav2_val"] 12 | elif split == "val": 13 | names = ["vqav2_val"] 14 | elif split == "test": 15 | names = ["vqav2_test"] 16 | 17 | super().__init__( 18 | *args, 19 | **kwargs, 20 | names=names, 21 | text_column_name="questions", 22 | remove_duplicate=False, 23 | ) 24 | 25 | def __getitem__(self, index): 26 | image_tensor = self.get_image(index)["image"] 27 | text = self.get_text(index)["text"] 28 | 29 | index, question_index = self.index_mapper[index] 30 | qid = 
self.table["question_id"][index][question_index].as_py() 31 | 32 | if self.split != "test": 33 | answers = self.table["answers"][index][question_index].as_py() 34 | labels = self.table["answer_labels"][index][question_index].as_py() 35 | scores = self.table["answer_scores"][index][question_index].as_py() 36 | else: 37 | answers = list() 38 | labels = list() 39 | scores = list() 40 | 41 | return { 42 | "image": image_tensor, 43 | "text": text, 44 | "vqa_answer": answers, 45 | "vqa_labels": labels, 46 | "vqa_scores": scores, 47 | "qid": qid, 48 | } 49 | -------------------------------------------------------------------------------- /flm/gadgets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/FLM/bd8b19d9f3a00ac6d4e58c9766957032036bffe8/flm/gadgets/__init__.py -------------------------------------------------------------------------------- /flm/gadgets/my_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_lightning.metrics import Metric 3 | 4 | 5 | class Accuracy(Metric): 6 | """log the accuracy metric""" 7 | 8 | def __init__(self, dist_sync_on_step=False): 9 | super().__init__(dist_sync_on_step=dist_sync_on_step) 10 | self.add_state("correct", default=torch.tensor( 11 | 0.0), dist_reduce_fx="sum") 12 | self.add_state("total", default=torch.tensor( 13 | 0.0), dist_reduce_fx="sum") 14 | 15 | def update(self, logits, target, ignore_index=-100): 16 | logits, target = ( 17 | logits.detach().to(self.correct.device), 18 | target.detach().to(self.correct.device), 19 | ) 20 | preds = logits.argmax(dim=-1) 21 | preds = preds[target != ignore_index] 22 | target = target[target != ignore_index] 23 | if target.numel() == 0: 24 | return 1 25 | 26 | assert preds.shape == target.shape 27 | 28 | self.correct += torch.sum(preds == target) 29 | self.total += target.numel() 30 | 31 | def compute(self): 32 | return self.correct / self.total 33 | 34 | 35 | class Scalar(Metric): 36 | def __init__(self, dist_sync_on_step=False): 37 | super().__init__(dist_sync_on_step=dist_sync_on_step) 38 | self.add_state("scalar", default=torch.tensor( 39 | 0.0), dist_reduce_fx="sum") 40 | self.add_state("total", default=torch.tensor( 41 | 0.0), dist_reduce_fx="sum") 42 | 43 | def update(self, scalar): 44 | if isinstance(scalar, torch.Tensor): 45 | scalar = scalar.detach().to(self.scalar.device) 46 | else: 47 | scalar = torch.tensor(scalar).float().to(self.scalar.device) 48 | self.scalar += scalar 49 | self.total += 1 50 | 51 | def compute(self): 52 | return self.scalar / self.total 53 | 54 | 55 | class VQAScore(Metric): 56 | """calculate and log the VQA accuracy""" 57 | 58 | def __init__(self, dist_sync_on_step=False): 59 | super().__init__(dist_sync_on_step=dist_sync_on_step) 60 | self.add_state("score", default=torch.tensor( 61 | 0.0), dist_reduce_fx="sum") 62 | self.add_state("total", default=torch.tensor( 63 | 0.0), dist_reduce_fx="sum") 64 | 65 | def update(self, logits, target): 66 | logits, target = ( 67 | logits.detach().float().to(self.score.device), 68 | target.detach().float().to(self.score.device), 69 | ) 70 | logits = torch.max(logits, 1)[1] 71 | one_hots = torch.zeros(*target.size()).to(target) 72 | one_hots.scatter_(1, logits.view(-1, 1), 1) 73 | scores = one_hots * target 74 | 75 | self.score += scores.sum() 76 | self.total += len(logits) 77 | 78 | def compute(self): 79 | return self.score / self.total 80 | 
-------------------------------------------------------------------------------- /flm/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .flm_module import FLMTransformerSS 2 | -------------------------------------------------------------------------------- /flm/modules/clip_model.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # CLIP 3 | # Modified from https://github.com/openai/CLIP/blob/main/clip/model.py 4 | # Copyright (c) OpenAI 5 | # ------------------------------------------------------------------------ 6 | 7 | import warnings 8 | from tqdm import tqdm 9 | import urllib 10 | import hashlib 11 | import os 12 | from collections import OrderedDict 13 | from typing import Tuple, Union 14 | 15 | import numpy as np 16 | import torch 17 | from torch import nn 18 | 19 | 20 | class LayerNorm(nn.LayerNorm): 21 | """Subclass torch's LayerNorm to handle fp16.""" 22 | 23 | def forward(self, x: torch.Tensor): 24 | orig_type = x.dtype 25 | ret = super().forward(x.type(torch.float32)) 26 | return ret.type(orig_type) 27 | 28 | 29 | class QuickGELU(nn.Module): 30 | def forward(self, x: torch.Tensor): 31 | return x * torch.sigmoid(1.702 * x) 32 | 33 | 34 | class ResidualAttentionBlock(nn.Module): 35 | def __init__(self, d_model: int, 36 | n_head: int, 37 | attn_mask: torch.Tensor = None): 38 | super().__init__() 39 | 40 | self.attn = nn.MultiheadAttention(d_model, n_head) 41 | self.ln_1 = LayerNorm(d_model) 42 | self.mlp = nn.Sequential(OrderedDict([ 43 | ("c_fc", nn.Linear(d_model, d_model * 4)), 44 | ("gelu", QuickGELU()), 45 | ("c_proj", nn.Linear(d_model * 4, d_model)) 46 | ])) 47 | self.ln_2 = LayerNorm(d_model) 48 | self.attn_mask = attn_mask 49 | 50 | def attention(self, x: torch.Tensor, x_mask: torch.Tensor): 51 | if x_mask is not None: 52 | x_mask = x_mask.to(dtype=torch.bool, device=x.device) 53 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \ 54 | if self.attn_mask is not None else None 55 | return self.attn(x, x, x, 56 | need_weights=False, 57 | attn_mask=self.attn_mask, 58 | key_padding_mask=x_mask)[0] 59 | 60 | def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None): 61 | x = x + self.attention(self.ln_1(x), x_mask) 62 | x = x + self.mlp(self.ln_2(x)) 63 | return x 64 | 65 | 66 | class Transformer(nn.Module): 67 | def __init__(self, width: int, layers: int, 68 | heads: int, attn_mask: torch.Tensor = None): 69 | super().__init__() 70 | self.width = width 71 | self.layers = layers 72 | self.resblocks = nn.Sequential( 73 | *[ResidualAttentionBlock(width, heads, attn_mask) 74 | for _ in range(layers-1)]) 75 | 76 | def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None): 77 | for block in self.resblocks: 78 | x = block(x, x_mask) 79 | return x 80 | 81 | 82 | class VisualTransformer(nn.Module): 83 | def __init__(self, input_resolution: int, patch_size: int, width: int, 84 | layers: int, heads: int, output_dim: int, 85 | resolution_after: int): 86 | super().__init__() 87 | self.input_resolution = input_resolution 88 | self.output_dim = output_dim 89 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, 90 | kernel_size=patch_size, stride=patch_size, 91 | bias=False) 92 | 93 | scale = width ** -0.5 94 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 95 | self.positional_embedding = nn.Parameter( 96 | scale * torch.randn( 97 | (resolution_after // 
patch_size) ** 2 + 1, width)) 98 | self.ln_pre = LayerNorm(width) 99 | 100 | self.transformer = Transformer(width, layers, heads) 101 | self.ln_post = LayerNorm(width) 102 | 103 | def forward(self, x: torch.Tensor, x_mask): 104 | x = self.conv1(x) # shape = [*, width, grid, grid] 105 | # shape = [*, width, grid ** 2] 106 | x = x.reshape(x.shape[0], x.shape[1], -1) 107 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 108 | t = self.class_embedding.to( 109 | x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, 110 | device=x.device) 111 | x = torch.cat([t, x], dim=1) # shape = [*, grid ** 2 + 1, width] 112 | x = x + self.positional_embedding.to(x.dtype) 113 | x = self.ln_pre(x) 114 | 115 | x = x.permute(1, 0, 2) # NLD -> LND 116 | x = self.transformer(x, x_mask) 117 | x = x.permute(1, 0, 2) # LND -> NLD 118 | 119 | x = self.ln_post(x) 120 | 121 | return x 122 | 123 | 124 | class CLIP(nn.Module): 125 | def __init__(self, 126 | embed_dim: int, 127 | # vision 128 | image_resolution: int, 129 | vision_layers: Union[Tuple[int, int, int, int], int], 130 | vision_width: int, 131 | vision_patch_size: int, 132 | # text 133 | context_length: int, 134 | vocab_size: int, 135 | transformer_width: int, 136 | transformer_heads: int, 137 | transformer_layers: int, 138 | resolution_after=224, 139 | ): 140 | super().__init__() 141 | 142 | self.context_length = context_length 143 | 144 | vision_heads = vision_width // 64 145 | self.visual = VisualTransformer( 146 | input_resolution=image_resolution, 147 | patch_size=vision_patch_size, 148 | width=vision_width, 149 | layers=vision_layers, 150 | heads=vision_heads, 151 | output_dim=embed_dim, 152 | resolution_after=resolution_after, 153 | ) 154 | 155 | self.vocab_size = vocab_size 156 | self.positional_embedding = nn.Parameter( 157 | torch.empty(self.context_length, transformer_width)) 158 | self.ln_final = LayerNorm(transformer_width) 159 | 160 | self.initialize_parameters() 161 | 162 | def initialize_parameters(self): 163 | nn.init.normal_(self.positional_embedding, std=0.01) 164 | 165 | proj_std = (self.visual.transformer.width ** -0.5) * \ 166 | ((2 * self.visual.transformer.layers) ** -0.5) 167 | attn_std = self.visual.transformer.width ** -0.5 168 | fc_std = (2 * self.visual.transformer.width) ** -0.5 169 | for block in self.visual.transformer.resblocks: 170 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 171 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 172 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 173 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 174 | 175 | @property 176 | def dtype(self): 177 | return self.visual.conv1.weight.dtype 178 | 179 | def forward(self, image, image_mask=None): 180 | return self.visual(image.type(self.dtype), image_mask) 181 | 182 | 183 | _MODELS = { 184 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 185 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 186 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 187 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 188 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 189 | "ViT-B/32": 
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 190 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 191 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 192 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 193 | } 194 | 195 | 196 | def _download(url: str, root: str = os.path.expanduser(".cache/clip")): 197 | os.makedirs(root, exist_ok=True) 198 | filename = os.path.basename(url) 199 | 200 | expected_sha256 = url.split("/")[-2] 201 | download_target = os.path.join(root, filename) 202 | 203 | if os.path.exists(download_target) and not os.path.isfile(download_target): 204 | raise RuntimeError( 205 | f"{download_target} exists and is not a regular file") 206 | 207 | if os.path.isfile(download_target): 208 | if hashlib.sha256( 209 | open(download_target, "rb").read()).hexdigest() \ 210 | == expected_sha256: 211 | return download_target 212 | else: 213 | warnings.warn( 214 | f"{download_target} exists, but the SHA256 checksum does not \ 215 | match; re-downloading the file") 216 | 217 | with urllib.request.urlopen(url) as source, \ 218 | open(download_target, "wb") as output: 219 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, 220 | unit='iB', unit_scale=True) as loop: 221 | while True: 222 | buffer = source.read(8192) 223 | if not buffer: 224 | break 225 | 226 | output.write(buffer) 227 | loop.update(len(buffer)) 228 | 229 | if hashlib.sha256( 230 | open(download_target, "rb").read()).hexdigest() != expected_sha256: 231 | raise RuntimeError( 232 | "Model has been downloaded \ 233 | but the SHA256 checksum does not not match") 234 | 235 | return download_target 236 | 237 | 238 | def adapt_position_encoding(model, patch_size=32, after=384, 239 | suffix='visual.positional_embedding'): 240 | keys = [k for k in model if k.endswith(suffix)] 241 | assert len(keys) == 1 242 | key = keys[0] 243 | origin_pos_embed = model[key] 244 | origin_dim2 = False 245 | if len(origin_pos_embed.shape) == 2: 246 | origin_dim2 = True 247 | origin_pos_embed = origin_pos_embed.unsqueeze(0) 248 | grid_before = int(np.sqrt(origin_pos_embed.shape[1] - 1)) 249 | before = int(grid_before*patch_size) 250 | assert (before % patch_size) == 0 251 | grid_after = after // patch_size 252 | assert (after % patch_size) == 0 253 | embed_dim = origin_pos_embed.shape[-1] 254 | 255 | pos_embed = origin_pos_embed[0, 1:, :].reshape( 256 | (grid_before, grid_before, embed_dim)) 257 | new_size = (grid_after, grid_after) 258 | pos_embed = torch.nn.functional.interpolate(pos_embed.permute( 259 | (2, 0, 1)).unsqueeze(0), size=new_size, mode='bicubic') 260 | pos_embed = pos_embed.squeeze(0).permute( 261 | (1, 2, 0)).reshape((-1, embed_dim)) 262 | pos_embed = torch.cat( 263 | (origin_pos_embed[0, 0:1, :], pos_embed), dim=0).unsqueeze(0) 264 | assert pos_embed.shape == (1, grid_after * grid_after + 1, embed_dim) 265 | if origin_dim2: 266 | assert pos_embed.shape[0] == 1 267 | pos_embed = pos_embed.squeeze(0) 268 | model[key] = pos_embed 269 | return model 270 | 271 | 272 | def build_model(name, resolution_after=224): 273 | if name in _MODELS: 274 | model_path = _download(_MODELS[name]) 275 | elif os.path.isfile(name): 276 | model_path = name 277 | else: 278 | raise 
RuntimeError(f"Model {name} not found; \ 279 | available models = {available_models()}") 280 | try: 281 | model = torch.jit.load(model_path, map_location="cpu") 282 | state_dict = None 283 | except RuntimeError: 284 | if jit: 285 | warnings.warn( 286 | f"File {model_path} is not a JIT archive. \ 287 | Loading as a state dict instead") 288 | jit = False 289 | state_dict = torch.load(model_path, map_location="cpu") 290 | state_dict = state_dict or model.state_dict() 291 | 292 | vision_width = state_dict["visual.conv1.weight"].shape[0] 293 | vision_layers = len([k for k in state_dict.keys() if k.startswith( 294 | "visual.") and k.endswith(".attn.in_proj_weight")]) 295 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 296 | grid_size = round( 297 | (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 298 | image_resolution = vision_patch_size * grid_size 299 | 300 | embed_dim = state_dict["text_projection"].shape[1] 301 | context_length = state_dict["positional_embedding"].shape[0] 302 | vocab_size = state_dict["token_embedding.weight"].shape[0] 303 | transformer_width = state_dict["ln_final.weight"].shape[0] 304 | transformer_heads = transformer_width // 64 305 | transformer_layers = len(set( 306 | k.split(".")[2] for k in state_dict 307 | if k.startswith("transformer.resblocks"))) 308 | 309 | model = CLIP( 310 | embed_dim, 311 | image_resolution, vision_layers, vision_width, vision_patch_size, 312 | context_length, vocab_size, transformer_width, transformer_heads, 313 | transformer_layers, resolution_after, 314 | ) 315 | 316 | for key in ["input_resolution", "context_length", "vocab_size"]: 317 | if key in state_dict: 318 | del state_dict[key] 319 | 320 | model_dict = model.state_dict() 321 | pretrained_dict = state_dict 322 | if resolution_after != image_resolution: 323 | pretrained_dict = adapt_position_encoding( 324 | pretrained_dict, 325 | after=resolution_after, 326 | patch_size=vision_patch_size) 327 | # 1. filter out unnecessary keys 328 | pretrained_dict = {k: v for k, 329 | v in pretrained_dict.items() if k in model_dict} 330 | # 2. overwrite entries in the existing state dict 331 | model_dict.update(pretrained_dict) 332 | # 3. load the new state dict 333 | model.load_state_dict(model_dict) 334 | return model 335 | -------------------------------------------------------------------------------- /flm/modules/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | This file contains primitives for multi-gpu communication. 4 | This is useful when doing distributed training. 5 | """ 6 | 7 | import functools 8 | import logging 9 | import numpy as np 10 | import pickle 11 | import torch 12 | import torch.distributed as dist 13 | 14 | import torch 15 | 16 | _LOCAL_PROCESS_GROUP = None 17 | """ 18 | A torch process group which only includes processes that on the same machine as the current process. 19 | This variable is set when processes are spawned by `launch()` in "engine/launch.py". 
20 | """ 21 | 22 | 23 | def get_world_size() -> int: 24 | if not dist.is_available(): 25 | return 1 26 | if not dist.is_initialized(): 27 | return 1 28 | return dist.get_world_size() 29 | 30 | 31 | def get_rank() -> int: 32 | if not dist.is_available(): 33 | return 0 34 | if not dist.is_initialized(): 35 | return 0 36 | return dist.get_rank() 37 | 38 | 39 | def get_local_rank() -> int: 40 | """ 41 | Returns: 42 | The rank of the current process within the local (per-machine) process group. 43 | """ 44 | if not dist.is_available(): 45 | return 0 46 | if not dist.is_initialized(): 47 | return 0 48 | assert _LOCAL_PROCESS_GROUP is not None 49 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 50 | 51 | 52 | def get_local_size() -> int: 53 | """ 54 | Returns: 55 | The size of the per-machine process group, 56 | i.e. the number of processes per machine. 57 | """ 58 | if not dist.is_available(): 59 | return 1 60 | if not dist.is_initialized(): 61 | return 1 62 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 63 | 64 | 65 | def is_main_process() -> bool: 66 | return get_rank() == 0 67 | 68 | 69 | def synchronize(): 70 | """ 71 | Helper function to synchronize (barrier) among all processes when 72 | using distributed training 73 | """ 74 | if not dist.is_available(): 75 | return 76 | if not dist.is_initialized(): 77 | return 78 | world_size = dist.get_world_size() 79 | if world_size == 1: 80 | return 81 | dist.barrier() 82 | 83 | 84 | @functools.lru_cache() 85 | def _get_global_gloo_group(): 86 | """ 87 | Return a process group based on gloo backend, containing all the ranks 88 | The result is cached. 89 | """ 90 | if dist.get_backend() == "nccl": 91 | return dist.new_group(backend="gloo") 92 | else: 93 | return dist.group.WORLD 94 | 95 | 96 | def _serialize_to_tensor(data, group): 97 | backend = dist.get_backend(group) 98 | assert backend in ["gloo", "nccl"] 99 | device = torch.device("cpu" if backend == "gloo" else "cuda") 100 | 101 | buffer = pickle.dumps(data) 102 | if len(buffer) > 1024 ** 3: 103 | logger = logging.getLogger(__name__) 104 | logger.warning( 105 | "Rank {} trying to all-gather {:.2f} GB of data on device\ 106 | {}".format(get_rank(), len(buffer) / (1024 ** 3), device) 107 | ) 108 | storage = torch.ByteStorage.from_buffer(buffer) 109 | tensor = torch.ByteTensor(storage).to(device=device) 110 | return tensor 111 | 112 | 113 | def _pad_to_largest_tensor(tensor, group): 114 | """ 115 | Returns: 116 | list[int]: size of the tensor, on each rank 117 | Tensor: padded tensor that has the max size 118 | """ 119 | world_size = dist.get_world_size(group=group) 120 | assert ( 121 | world_size >= 1 122 | ), "comm.gather/all_gather must be called from ranks within the given group!" 
123 | local_size = torch.tensor( 124 | [tensor.numel()], dtype=torch.int64, device=tensor.device) 125 | size_list = [ 126 | torch.zeros([1], dtype=torch.int64, device=tensor.device) 127 | for _ in range(world_size) 128 | ] 129 | dist.all_gather(size_list, local_size, group=group) 130 | size_list = [int(size.item()) for size in size_list] 131 | 132 | max_size = max(size_list) 133 | 134 | # we pad the tensor because torch all_gather does not support 135 | # gathering tensors of different shapes 136 | if local_size != max_size: 137 | padding = torch.zeros( 138 | (max_size - local_size,), dtype=torch.uint8, device=tensor.device 139 | ) 140 | tensor = torch.cat((tensor, padding), dim=0) 141 | return size_list, tensor 142 | 143 | 144 | def all_gather(data, group=None): 145 | """ 146 | Run all_gather on arbitrary picklable data (not necessarily tensors). 147 | 148 | Args: 149 | data: any picklable object 150 | group: a torch process group. By default, will use a group which 151 | contains all ranks on gloo backend. 152 | 153 | Returns: 154 | list[data]: list of data gathered from each rank 155 | """ 156 | if get_world_size() == 1: 157 | return [data] 158 | if group is None: 159 | group = _get_global_gloo_group() 160 | if dist.get_world_size(group) == 1: 161 | return [data] 162 | 163 | tensor = _serialize_to_tensor(data, group) 164 | 165 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 166 | max_size = max(size_list) 167 | 168 | # receiving Tensor from all ranks 169 | tensor_list = [ 170 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 171 | for _ in size_list 172 | ] 173 | dist.all_gather(tensor_list, tensor, group=group) 174 | 175 | data_list = [] 176 | for size, tensor in zip(size_list, tensor_list): 177 | buffer = tensor.cpu().numpy().tobytes()[:size] 178 | data_list.append(pickle.loads(buffer)) 179 | 180 | return data_list 181 | 182 | 183 | def gather(data, dst=0, group=None): 184 | """ 185 | Run gather on arbitrary picklable data (not necessarily tensors). 186 | 187 | Args: 188 | data: any picklable object 189 | dst (int): destination rank 190 | group: a torch process group. By default, will use a group which 191 | contains all ranks on gloo backend. 192 | 193 | Returns: 194 | list[data]: on dst, a list of data gathered from each rank. Otherwise, 195 | an empty list. 196 | """ 197 | if get_world_size() == 1: 198 | return [data] 199 | if group is None: 200 | group = _get_global_gloo_group() 201 | if dist.get_world_size(group=group) == 1: 202 | return [data] 203 | rank = dist.get_rank(group=group) 204 | 205 | tensor = _serialize_to_tensor(data, group) 206 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 207 | 208 | # receiving Tensor from all ranks 209 | if rank == dst: 210 | max_size = max(size_list) 211 | tensor_list = [ 212 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 213 | for _ in size_list 214 | ] 215 | dist.gather(tensor, tensor_list, dst=dst, group=group) 216 | 217 | data_list = [] 218 | for size, tensor in zip(size_list, tensor_list): 219 | buffer = tensor.cpu().numpy().tobytes()[:size] 220 | data_list.append(pickle.loads(buffer)) 221 | return data_list 222 | else: 223 | dist.gather(tensor, [], dst=dst, group=group) 224 | return [] 225 | 226 | 227 | def shared_random_seed(): 228 | """ 229 | Returns: 230 | int: a random number that is the same across all workers. 231 | If workers need a shared RNG, they can use this shared seed to 232 | create one. 
233 | 234 | All workers must call this function, otherwise it will deadlock. 235 | """ 236 | ints = np.random.randint(2 ** 31) 237 | all_ints = all_gather(ints) 238 | return all_ints[0] 239 | 240 | 241 | def reduce_dict(input_dict, average=True): 242 | """ 243 | Reduce the values in the dictionary from all processes so that process with rank 244 | 0 has the reduced results. 245 | 246 | Args: 247 | input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. 248 | average (bool): whether to do average or sum 249 | 250 | Returns: 251 | a dict with the same keys as input_dict, after reduction. 252 | """ 253 | world_size = get_world_size() 254 | if world_size < 2: 255 | return input_dict 256 | with torch.no_grad(): 257 | names = [] 258 | values = [] 259 | # sort the keys so that they are consistent across processes 260 | for k in sorted(input_dict.keys()): 261 | names.append(k) 262 | values.append(input_dict[k]) 263 | values = torch.stack(values, dim=0) 264 | dist.reduce(values, dst=0) 265 | if dist.get_rank() == 0 and average: 266 | # only main process gets accumulated, so only divide by 267 | # world_size in this case 268 | values /= world_size 269 | reduced_dict = {k: v for k, v in zip(names, values)} 270 | return reduced_dict 271 | -------------------------------------------------------------------------------- /flm/modules/flm_tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def get_corr_bi_attention_mask(mask, mask_r, span_corr_rate=0): 6 | """prepare the attention mask in reconstrctor""" 7 | bs, L, M, N = mask.shape 8 | org_bi_mask = torch.cat([mask, mask_r], dim=-1) 9 | bi_mask = org_bi_mask.detach().clone() 10 | bi_mask[:, :, torch.arange(1, N), torch.arange(1, N)] = -10000. 11 | bi_mask[:, :, torch.arange( 12 | 1, N), N + torch.arange(1, N)] = -10000. # [bs, L, L] 13 | text_len = (bi_mask != -10000.).sum(dim=3) + 1 14 | text_len[:, :, 0] = 1 15 | 16 | if span_corr_rate > 0: 17 | add_corr_rate = torch.maximum(torch.zeros_like( 18 | text_len), (text_len * span_corr_rate - 1.)/(text_len - 1 + 1e-5)) 19 | mask_num = torch.distributions.Binomial( 20 | text_len.float() - 1, add_corr_rate).sample().int() 21 | start_bias = mask_num // 2 + torch.bernoulli(mask_num/2 - mask_num//2) 22 | angle = torch.arange(0, N, device=mask.device).long() 23 | start = torch.maximum(angle - start_bias.long(), 0*angle) 24 | end = torch.minimum(start + N + mask_num, start.new_tensor(2*N-1)) 25 | start_step = angle[None, None].repeat(bs, L, 1) - start 26 | for i in range(torch.max(start_step[:, :, 1:])): 27 | bi_mask[torch.arange(bs).reshape(bs, 1, 1).repeat(1, L, N), torch.arange(L).reshape(1, L, 1).repeat( 28 | bs, 1, N), angle[None, None].repeat(bs, L, 1), torch.minimum(start+i, angle[None, None])] = -10000. 29 | 30 | end_step = end - angle[None, None].repeat(bs, L, 1) - N 31 | for i in range(torch.max(end_step[:, :, 1:])): 32 | bi_mask[torch.arange(bs).reshape(bs, 1, 1).repeat(1, L, N), torch.arange(L).reshape(1, L, 1).repeat( 33 | bs, 1, N), angle[None, None].repeat(bs, L, 1), torch.maximum(end-i, N + angle[None, None])] = -10000. 
34 | return torch.cat([org_bi_mask[:, :, :1], bi_mask[:, :, 1:]], dim=2) 35 | -------------------------------------------------------------------------------- /flm/modules/heads.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import pdb 5 | from transformers.models.bert.modeling_bert import BertPredictionHeadTransform 6 | 7 | 8 | class Pooler(nn.Module): 9 | def __init__(self, hidden_size): 10 | super().__init__() 11 | self.dense = nn.Linear(hidden_size, hidden_size) 12 | self.activation = nn.Tanh() 13 | 14 | def forward(self, hidden_states): 15 | first_token_tensor = hidden_states[:, 0] 16 | pooled_output = self.dense(first_token_tensor) 17 | pooled_output = self.activation(pooled_output) 18 | return pooled_output 19 | 20 | 21 | class ITMHead(nn.Module): 22 | def __init__(self, hidden_size): 23 | super().__init__() 24 | self.fc = nn.Linear(hidden_size, 2) 25 | 26 | def forward(self, x): 27 | x = self.fc(x) 28 | return x 29 | 30 | 31 | class MLMHead(nn.Module): 32 | def __init__(self, config, weight=None): 33 | super().__init__() 34 | self.transform = BertPredictionHeadTransform(config) 35 | self.decoder = nn.Linear( 36 | config.hidden_size, config.vocab_size, bias=False) 37 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 38 | if weight is not None: 39 | self.decoder.weight = weight 40 | 41 | def forward(self, x): 42 | x = self.transform(x) 43 | x = self.decoder(x) + self.bias 44 | return x 45 | -------------------------------------------------------------------------------- /flm/modules/meter_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | from transformers.optimization import AdamW 5 | from transformers import ( 6 | get_polynomial_decay_schedule_with_warmup, 7 | get_cosine_schedule_with_warmup, 8 | ) 9 | from .dist_utils import all_gather 10 | from .objectives import compute_irtr_recall, compute_caption 11 | from ..gadgets.my_metrics import Accuracy, VQAScore, Scalar 12 | 13 | 14 | def set_metrics(pl_module): 15 | for split in ["train", "val"]: 16 | for k, v in pl_module.hparams.config["loss_names"].items(): 17 | if v <= 0: 18 | continue 19 | if k == "vqa": 20 | setattr(pl_module, f"{split}_vqa_score", VQAScore()) 21 | setattr(pl_module, f"{split}_{k}_loss", Scalar()) 22 | elif k == "nlvr2": 23 | if split == "train": 24 | setattr(pl_module, f"train_{k}_accuracy", Accuracy()) 25 | setattr(pl_module, f"train_{k}_loss", Scalar()) 26 | else: 27 | setattr(pl_module, f"dev_{k}_accuracy", Accuracy()) 28 | setattr(pl_module, f"dev_{k}_loss", Scalar()) 29 | setattr(pl_module, f"test_{k}_accuracy", Accuracy()) 30 | setattr(pl_module, f"test_{k}_loss", Scalar()) 31 | elif k == "irtr": 32 | setattr(pl_module, f"{split}_irtr_loss", Scalar()) 33 | elif k == "mppd" or k == "mpfr": 34 | setattr(pl_module, f"{split}_{k}_loss", Scalar()) 35 | elif k == "itm": 36 | setattr(pl_module, f"{split}_{k}_accuracy", Accuracy()) 37 | setattr(pl_module, f"{split}_{k}_loss", Scalar()) 38 | else: 39 | setattr(pl_module, f"{split}_{k}_accuracy", Accuracy()) 40 | setattr(pl_module, f"{split}_{k}_loss", Scalar()) 41 | 42 | if 'flm' in k and pl_module.hparams.config["enable_flm_aux_lm_loss"]: 43 | setattr(pl_module, f"{split}_flma1_accuracy", Accuracy()) 44 | setattr(pl_module, f"{split}_flma2_accuracy", Accuracy()) 45 | setattr(pl_module, f"{split}_flma1_loss", Scalar()) 46 | setattr(pl_module, 
f"{split}_flma2_loss", Scalar()) 47 | 48 | 49 | def epoch_wrapup(pl_module): 50 | phase = "train" if pl_module.training else "val" 51 | the_metric = 0 52 | if pl_module.hparams.config["get_caption_metric"] and not pl_module.training: 53 | b4, m, c, s = compute_caption(pl_module) 54 | pl_module.logger.experiment.add_scalar( 55 | "caption/b4", b4, pl_module.global_step 56 | ) 57 | pl_module.logger.experiment.add_scalar( 58 | "caption/meter", m, pl_module.global_step 59 | ) 60 | pl_module.logger.experiment.add_scalar( 61 | "caption/cider", c, pl_module.global_step 62 | ) 63 | pl_module.logger.experiment.add_scalar( 64 | "caption/spice", s, pl_module.global_step 65 | ) 66 | the_metric += c + m 67 | 68 | # if pl_module.hparams.config["get_mlm_caption_metric"] and not pl_module.training: 69 | # b4, m, c, s = compute_mlm_caption(pl_module) 70 | # pl_module.logger.experiment.add_scalar( 71 | # "caption/b4", b4, pl_module.global_step 72 | # ) 73 | # pl_module.logger.experiment.add_scalar( 74 | # "caption/meter", m, pl_module.global_step 75 | # ) 76 | # pl_module.logger.experiment.add_scalar( 77 | # "caption/cider", c, pl_module.global_step 78 | # ) 79 | # pl_module.logger.experiment.add_scalar( 80 | # "caption/spice", s, pl_module.global_step 81 | # ) 82 | # the_metric += c + m 83 | 84 | if pl_module.hparams.config["get_recall_metric"] and not pl_module.training: 85 | (ir_r1, ir_r5, ir_r10, tr_r1, tr_r5, 86 | tr_r10) = compute_irtr_recall(pl_module) 87 | print((ir_r1, ir_r5, ir_r10, tr_r1, tr_r5, tr_r10), pl_module.global_step) 88 | pl_module.logger.experiment.add_scalar( 89 | "recalls/ir_r1", ir_r1, pl_module.global_step 90 | ) 91 | pl_module.logger.experiment.add_scalar( 92 | "recalls/ir_r5", ir_r5, pl_module.global_step 93 | ) 94 | pl_module.logger.experiment.add_scalar( 95 | "recalls/ir_r10", ir_r10, pl_module.global_step 96 | ) 97 | pl_module.logger.experiment.add_scalar( 98 | "recalls/tr_r1", tr_r1, pl_module.global_step 99 | ) 100 | pl_module.logger.experiment.add_scalar( 101 | "recalls/tr_r5", tr_r5, pl_module.global_step 102 | ) 103 | pl_module.logger.experiment.add_scalar( 104 | "recalls/tr_r10", tr_r10, pl_module.global_step 105 | ) 106 | the_metric += ir_r1.item() + tr_r1.item() 107 | 108 | for loss_name, v in pl_module.hparams.config["loss_names"].items(): 109 | if v <= 0: 110 | continue 111 | 112 | value = 0 113 | 114 | if loss_name == "vqa": 115 | value = getattr(pl_module, f"{phase}_{loss_name}_score").compute() 116 | pl_module.log(f"{loss_name}/{phase}/score_epoch", value) 117 | getattr(pl_module, f"{phase}_{loss_name}_score").reset() 118 | pl_module.log( 119 | f"{loss_name}/{phase}/loss_epoch", 120 | getattr(pl_module, f"{phase}_{loss_name}_loss").compute(), 121 | ) 122 | getattr(pl_module, f"{phase}_{loss_name}_loss").reset() 123 | elif loss_name == "nlvr2" or loss_name == 'snli': 124 | if phase == "train": 125 | value = getattr( 126 | pl_module, f"train_{loss_name}_accuracy").compute() 127 | pl_module.log(f"{loss_name}/train/accuracy_epoch", value) 128 | getattr(pl_module, f"train_{loss_name}_accuracy").reset() 129 | pl_module.log( 130 | f"{loss_name}/train/loss_epoch", 131 | getattr(pl_module, f"train_{loss_name}_loss").compute(), 132 | ) 133 | getattr(pl_module, f"train_{loss_name}_loss").reset() 134 | else: 135 | value = getattr( 136 | pl_module, f"test_{loss_name}_accuracy").compute() 137 | pl_module.log(f"{loss_name}/test/accuracy_epoch", value) 138 | getattr(pl_module, f"test_{loss_name}_accuracy").reset() 139 | pl_module.log( 140 | f"{loss_name}/test/loss_epoch", 141 | 
getattr(pl_module, f"test_{loss_name}_loss").compute(), 142 | ) 143 | getattr(pl_module, f"test_{loss_name}_loss").reset() 144 | 145 | value = getattr( 146 | pl_module, f"dev_{loss_name}_accuracy").compute() 147 | pl_module.log(f"{loss_name}/dev/accuracy_epoch", value) 148 | getattr(pl_module, f"dev_{loss_name}_accuracy").reset() 149 | pl_module.log( 150 | f"{loss_name}/dev/loss_epoch", 151 | getattr(pl_module, f"dev_{loss_name}_loss").compute(), 152 | ) 153 | getattr(pl_module, f"dev_{loss_name}_loss").reset() 154 | elif loss_name == 'wino': 155 | if phase == 'train': 156 | pass 157 | else: 158 | value = getattr( 159 | pl_module, f"test_{loss_name}_accuracy_img").compute() 160 | value_text = getattr( 161 | pl_module, f"test_{loss_name}_accuracy_text").compute() 162 | pl_module.log(f"{loss_name}/test/accuracy_img_epoch", value) 163 | pl_module.log( 164 | f"{loss_name}/test/accuracy_text_epoch", value_text) 165 | getattr(pl_module, f"test_{loss_name}_accuracy_img").reset() 166 | getattr(pl_module, f"test_{loss_name}_accuracy_text").reset() 167 | 168 | elif loss_name == "irtr": 169 | pl_module.log( 170 | f"{loss_name}/{phase}/irtr_loss_epoch", 171 | getattr(pl_module, f"{phase}_irtr_loss").compute(), 172 | ) 173 | getattr(pl_module, f"{phase}_irtr_loss").reset() 174 | elif loss_name == "mppd" or loss_name == "mpfr": 175 | pl_module.log( 176 | f"{loss_name}/{phase}/loss_epoch", 177 | getattr(pl_module, f"{phase}_{loss_name}_loss").compute(), 178 | ) 179 | getattr(pl_module, f"{phase}_{loss_name}_loss").reset() 180 | elif loss_name == "itm": 181 | value = getattr( 182 | pl_module, f"{phase}_{loss_name}_accuracy").compute() 183 | pl_module.log(f"{loss_name}/{phase}/accuracy_epoch", value) 184 | getattr(pl_module, f"{phase}_{loss_name}_accuracy").reset() 185 | pl_module.log( 186 | f"{loss_name}/{phase}/loss_epoch", 187 | getattr(pl_module, f"{phase}_{loss_name}_loss").compute(), 188 | ) 189 | getattr(pl_module, f"{phase}_{loss_name}_loss").reset() 190 | else: 191 | value = getattr( 192 | pl_module, f"{phase}_{loss_name}_accuracy").compute() 193 | pl_module.log(f"{loss_name}/{phase}/accuracy_epoch", value) 194 | getattr(pl_module, f"{phase}_{loss_name}_accuracy").reset() 195 | pl_module.log( 196 | f"{loss_name}/{phase}/loss_epoch", 197 | getattr(pl_module, f"{phase}_{loss_name}_loss").compute(), 198 | ) 199 | getattr(pl_module, f"{phase}_{loss_name}_loss").reset() 200 | 201 | the_metric += value 202 | 203 | pl_module.log(f"{phase}/the_metric", the_metric) 204 | 205 | 206 | def check_non_acc_grad(pl_module): 207 | if pl_module.token_type_embeddings.weight.grad is None: 208 | return True 209 | else: 210 | grad = pl_module.token_type_embeddings.weight.grad 211 | return (grad.sum() == 0).item() 212 | 213 | 214 | def set_task(pl_module): 215 | pl_module.current_tasks = [ 216 | k for k, v in pl_module.hparams.config["loss_names"].items() if v > 0 217 | ] 218 | return 219 | 220 | 221 | def get_grouped_parameters(pl_module, no_decay, head_names, cross_modal_names, 222 | wd, lr, lr_mult_head, lr_mult_cross_modal): 223 | optimizer_grouped_parameters = [ 224 | { 225 | "params": [ 226 | p 227 | for n, p in pl_module.named_parameters() 228 | if not any(nd in n for nd in no_decay) 229 | and not any(bb in n for bb in head_names) 230 | and not any(ht in n for ht in cross_modal_names) 231 | ], 232 | "weight_decay": wd, 233 | "lr": lr, 234 | }, 235 | { 236 | "params": [ 237 | p 238 | for n, p in pl_module.named_parameters() 239 | if any(nd in n for nd in no_decay) 240 | and not any(bb in n for bb in head_names) 
241 | and not any(ht in n for ht in cross_modal_names) 242 | ], 243 | "weight_decay": 0.0, 244 | "lr": lr, 245 | }, 246 | { 247 | "params": [ 248 | p 249 | for n, p in pl_module.named_parameters() 250 | if not any(nd in n for nd in no_decay) 251 | and any(bb in n for bb in head_names) 252 | and not any(ht in n for ht in cross_modal_names) 253 | ], 254 | "weight_decay": wd, 255 | "lr": lr * lr_mult_head, 256 | }, 257 | { 258 | "params": [ 259 | p 260 | for n, p in pl_module.named_parameters() 261 | if any(nd in n for nd in no_decay) and any(bb in n for bb in head_names) 262 | and not any(ht in n for ht in cross_modal_names) 263 | ], 264 | "weight_decay": 0.0, 265 | "lr": lr * lr_mult_head, 266 | }, 267 | { 268 | "params": [ 269 | p 270 | for n, p in pl_module.named_parameters() 271 | if not any(nd in n for nd in no_decay) 272 | and not any(bb in n for bb in head_names) 273 | and any(ht in n for ht in cross_modal_names) 274 | ], 275 | "weight_decay": wd, 276 | "lr": lr * lr_mult_cross_modal, 277 | }, 278 | { 279 | "params": [ 280 | p 281 | for n, p in pl_module.named_parameters() 282 | if any(nd in n for nd in no_decay) 283 | and not any(bb in n for bb in head_names) 284 | and any(ht in n for ht in cross_modal_names) 285 | ], 286 | "weight_decay": 0.0, 287 | "lr": lr * lr_mult_cross_modal, 288 | }, 289 | ] 290 | return optimizer_grouped_parameters 291 | 292 | 293 | def set_schedule(pl_module): 294 | lr = pl_module.hparams.config["learning_rate"] 295 | wd = pl_module.hparams.config["weight_decay"] 296 | 297 | no_decay = [ 298 | "bias", 299 | "LayerNorm.bias", 300 | "LayerNorm.weight", 301 | "norm.bias", 302 | "norm.weight", 303 | "norm1.bias", 304 | "norm1.weight", 305 | "norm2.bias", 306 | "norm2.weight", 307 | ] 308 | head_names = ["vqa_classifier", "nlvr2_classifier", "mlm_score", "itm_score", 309 | "snli_classifier", "lm_score", "flm_score", "cl_image", "cl_text"] 310 | cross_modal_names = ['cross_modal', 'fusion_layers'] 311 | lr_mult_head = pl_module.hparams.config["lr_mult_head"] 312 | lr_mult_cross_modal = pl_module.hparams.config["lr_mult_cross_modal"] 313 | end_lr = pl_module.hparams.config["end_lr"] 314 | decay_power = pl_module.hparams.config["decay_power"] 315 | optim_type = pl_module.hparams.config["optim_type"] 316 | 317 | optimizer_grouped_parameters = get_grouped_parameters( 318 | pl_module, no_decay, head_names, cross_modal_names, wd, lr, lr_mult_head, lr_mult_cross_modal) 319 | if optim_type == "adamw": 320 | optimizer = AdamW( 321 | optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.98) 322 | ) 323 | elif optim_type == "adam": 324 | optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=lr) 325 | elif optim_type == "sgd": 326 | optimizer = torch.optim.SGD( 327 | optimizer_grouped_parameters, lr=lr, momentum=0.9) 328 | 329 | if pl_module.trainer.max_steps is None: 330 | max_steps = ( 331 | len(pl_module.trainer.datamodule.train_dataloader()) 332 | * pl_module.trainer.max_epochs 333 | // pl_module.trainer.accumulate_grad_batches 334 | ) 335 | else: 336 | max_steps = pl_module.trainer.max_steps 337 | 338 | warmup_steps = pl_module.hparams.config["warmup_steps"] 339 | if isinstance(pl_module.hparams.config["warmup_steps"], float): 340 | warmup_steps = int(max_steps * warmup_steps) 341 | 342 | if decay_power == "cosine": 343 | scheduler = get_cosine_schedule_with_warmup( 344 | optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps, 345 | ) 346 | else: 347 | scheduler = get_polynomial_decay_schedule_with_warmup( 348 | optimizer, 349 | 
num_warmup_steps=warmup_steps, 350 | num_training_steps=max_steps, 351 | lr_end=end_lr, 352 | power=decay_power, 353 | ) 354 | 355 | sched = {"scheduler": scheduler, "interval": "step"} 356 | 357 | return ( 358 | [optimizer], 359 | [sched], 360 | ) 361 | -------------------------------------------------------------------------------- /flm/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import ( 2 | pixelbert_transform, 3 | pixelbert_transform_randaug, 4 | vit_transform, 5 | vit_transform_randaug, 6 | imagenet_transform, 7 | imagenet_transform_randaug, 8 | clip_transform, 9 | clip_transform_randaug, 10 | mae_transform_randaug, 11 | mae_transform, 12 | ) 13 | 14 | _transforms = { 15 | "pixelbert": pixelbert_transform, 16 | "pixelbert_randaug": pixelbert_transform_randaug, 17 | "vit": vit_transform, 18 | "vit_randaug": vit_transform_randaug, 19 | "imagenet": imagenet_transform, 20 | "imagenet_randaug": imagenet_transform_randaug, 21 | "clip": clip_transform, 22 | "clip_randaug": clip_transform_randaug, 23 | 'mae_randaug': mae_transform_randaug, 24 | 'mae': mae_transform, 25 | } 26 | 27 | 28 | def keys_to_transforms(keys: list, size=224): 29 | return [_transforms[key](size=size) for key in keys] 30 | -------------------------------------------------------------------------------- /flm/transforms/randaug.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from rpmcruz/autoaugment 2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 3 | import random 4 | 5 | import PIL 6 | import PIL.ImageOps 7 | import PIL.ImageEnhance 8 | import PIL.ImageDraw 9 | import numpy as np 10 | import torch 11 | from PIL import Image 12 | 13 | 14 | def ShearX(img, v): # [-0.3, 0.3] 15 | assert -0.3 <= v <= 0.3 16 | if random.random() > 0.5: 17 | v = -v 18 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 19 | 20 | 21 | def ShearY(img, v): # [-0.3, 0.3] 22 | assert -0.3 <= v <= 0.3 23 | if random.random() > 0.5: 24 | v = -v 25 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 26 | 27 | 28 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 29 | assert -0.45 <= v <= 0.45 30 | if random.random() > 0.5: 31 | v = -v 32 | v = v * img.size[0] 33 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 34 | 35 | 36 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 37 | assert 0 <= v 38 | if random.random() > 0.5: 39 | v = -v 40 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 41 | 42 | 43 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 44 | assert -0.45 <= v <= 0.45 45 | if random.random() > 0.5: 46 | v = -v 47 | v = v * img.size[1] 48 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 49 | 50 | 51 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 52 | assert 0 <= v 53 | if random.random() > 0.5: 54 | v = -v 55 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 56 | 57 | 58 | def Rotate(img, v): # [-30, 30] 59 | assert -30 <= v <= 30 60 | if random.random() > 0.5: 61 | v = -v 62 | return img.rotate(v) 63 | 64 | 65 | def AutoContrast(img, _): 66 | return PIL.ImageOps.autocontrast(img) 67 | 68 | 69 | def Invert(img, _): 70 | return PIL.ImageOps.invert(img) 71 | 72 | 73 | def Equalize(img, _): 74 | return PIL.ImageOps.equalize(img) 75 | 76 | 77 | def Flip(img, _): # 
not from the paper 78 | return PIL.ImageOps.mirror(img) 79 | 80 | 81 | def Solarize(img, v): # [0, 256] 82 | assert 0 <= v <= 256 83 | return PIL.ImageOps.solarize(img, v) 84 | 85 | 86 | def SolarizeAdd(img, addition=0, threshold=128): 87 | img_np = np.array(img).astype(np.int) 88 | img_np = img_np + addition 89 | img_np = np.clip(img_np, 0, 255) 90 | img_np = img_np.astype(np.uint8) 91 | img = Image.fromarray(img_np) 92 | return PIL.ImageOps.solarize(img, threshold) 93 | 94 | 95 | def Posterize(img, v): # [4, 8] 96 | v = int(v) 97 | v = max(1, v) 98 | return PIL.ImageOps.posterize(img, v) 99 | 100 | 101 | def Contrast(img, v): # [0.1,1.9] 102 | assert 0.1 <= v <= 1.9 103 | return PIL.ImageEnhance.Contrast(img).enhance(v) 104 | 105 | 106 | def Color(img, v): # [0.1,1.9] 107 | assert 0.1 <= v <= 1.9 108 | return PIL.ImageEnhance.Color(img).enhance(v) 109 | 110 | 111 | def Brightness(img, v): # [0.1,1.9] 112 | assert 0.1 <= v <= 1.9 113 | return PIL.ImageEnhance.Brightness(img).enhance(v) 114 | 115 | 116 | def Sharpness(img, v): # [0.1,1.9] 117 | assert 0.1 <= v <= 1.9 118 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 119 | 120 | 121 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 122 | assert 0.0 <= v <= 0.2 123 | if v <= 0.0: 124 | return img 125 | 126 | v = v * img.size[0] 127 | return CutoutAbs(img, v) 128 | 129 | 130 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 131 | # assert 0 <= v <= 20 132 | if v < 0: 133 | return img 134 | w, h = img.size 135 | x0 = np.random.uniform(w) 136 | y0 = np.random.uniform(h) 137 | 138 | x0 = int(max(0, x0 - v / 2.0)) 139 | y0 = int(max(0, y0 - v / 2.0)) 140 | x1 = min(w, x0 + v) 141 | y1 = min(h, y0 + v) 142 | 143 | xy = (x0, y0, x1, y1) 144 | color = (125, 123, 114) 145 | # color = (0, 0, 0) 146 | img = img.copy() 147 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 148 | return img 149 | 150 | 151 | def SamplePairing(imgs): # [0, 0.4] 152 | def f(img1, v): 153 | i = np.random.choice(len(imgs)) 154 | img2 = PIL.Image.fromarray(imgs[i]) 155 | return PIL.Image.blend(img1, img2, v) 156 | 157 | return f 158 | 159 | 160 | def Identity(img, v): 161 | return img 162 | 163 | 164 | def augment_list(): # 16 oeprations and their ranges 165 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 166 | # l = [ 167 | # (Identity, 0., 1.0), 168 | # (ShearX, 0., 0.3), # 0 169 | # (ShearY, 0., 0.3), # 1 170 | # (TranslateX, 0., 0.33), # 2 171 | # (TranslateY, 0., 0.33), # 3 172 | # (Rotate, 0, 30), # 4 173 | # (AutoContrast, 0, 1), # 5 174 | # (Invert, 0, 1), # 6 175 | # (Equalize, 0, 1), # 7 176 | # (Solarize, 0, 110), # 8 177 | # (Posterize, 4, 8), # 9 178 | # # (Contrast, 0.1, 1.9), # 10 179 | # (Color, 0.1, 1.9), # 11 180 | # (Brightness, 0.1, 1.9), # 12 181 | # (Sharpness, 0.1, 1.9), # 13 182 | # # (Cutout, 0, 0.2), # 14 183 | # # (SamplePairing(imgs), 0, 0.4), # 15 184 | # ] 185 | 186 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 187 | l = [ 188 | (AutoContrast, 0, 1), 189 | (Equalize, 0, 1), 190 | # (Invert, 0, 1), 191 | (Rotate, 0, 30), 192 | (Posterize, 0, 4), 193 | (Solarize, 0, 256), 194 | (SolarizeAdd, 0, 110), 195 | (Color, 0.1, 1.9), 196 | (Contrast, 0.1, 1.9), 197 | (Brightness, 0.1, 1.9), 198 | (Sharpness, 0.1, 1.9), 199 | (ShearX, 0.0, 0.3), 200 | (ShearY, 0.0, 0.3), 201 | # (CutoutAbs, 0, 40), 202 | (TranslateXabs, 0.0, 100), 203 | (TranslateYabs, 0.0, 100), 204 | ] 205 | 206 | return l 207 | 208 | 209 | class 
Lighting(object): 210 | """Lighting noise(AlexNet - style PCA - based noise)""" 211 | 212 | def __init__(self, alphastd, eigval, eigvec): 213 | self.alphastd = alphastd 214 | self.eigval = torch.Tensor(eigval) 215 | self.eigvec = torch.Tensor(eigvec) 216 | 217 | def __call__(self, img): 218 | if self.alphastd == 0: 219 | return img 220 | 221 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 222 | rgb = ( 223 | self.eigvec.type_as(img) 224 | .clone() 225 | .mul(alpha.view(1, 3).expand(3, 3)) 226 | .mul(self.eigval.view(1, 3).expand(3, 3)) 227 | .sum(1) 228 | .squeeze() 229 | ) 230 | 231 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 232 | 233 | 234 | class CutoutDefault(object): 235 | """ 236 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py 237 | """ 238 | 239 | def __init__(self, length): 240 | self.length = length 241 | 242 | def __call__(self, img): 243 | h, w = img.size(1), img.size(2) 244 | mask = np.ones((h, w), np.float32) 245 | y = np.random.randint(h) 246 | x = np.random.randint(w) 247 | 248 | y1 = np.clip(y - self.length // 2, 0, h) 249 | y2 = np.clip(y + self.length // 2, 0, h) 250 | x1 = np.clip(x - self.length // 2, 0, w) 251 | x2 = np.clip(x + self.length // 2, 0, w) 252 | 253 | mask[y1:y2, x1:x2] = 0.0 254 | mask = torch.from_numpy(mask) 255 | mask = mask.expand_as(img) 256 | img *= mask 257 | return img 258 | 259 | 260 | class RandAugment: 261 | def __init__(self, n, m): 262 | self.n = n 263 | self.m = m # [0, 30] 264 | self.augment_list = augment_list() 265 | 266 | def __call__(self, img): 267 | ops = random.choices(self.augment_list, k=self.n) 268 | for op, minval, maxval in ops: 269 | val = (float(self.m) / 30) * float(maxval - minval) + minval 270 | img = op(img, val) 271 | 272 | return img 273 | -------------------------------------------------------------------------------- /flm/transforms/transform.py: -------------------------------------------------------------------------------- 1 | from .utils import ( 2 | inception_normalize, 3 | imagenet_normalize, 4 | MinMaxResize, 5 | ) 6 | from PIL import Image 7 | from torchvision import transforms 8 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 9 | from .randaug import RandAugment 10 | 11 | 12 | def pixelbert_transform(size=800): 13 | longer = int((1333 / 800) * size) 14 | return transforms.Compose( 15 | [ 16 | MinMaxResize(shorter=size, longer=longer), 17 | transforms.ToTensor(), 18 | inception_normalize, 19 | ] 20 | ) 21 | 22 | 23 | def pixelbert_transform_randaug(size=800): 24 | longer = int((1333 / 800) * size) 25 | trs = transforms.Compose( 26 | [ 27 | MinMaxResize(shorter=size, longer=longer), 28 | transforms.ToTensor(), 29 | inception_normalize, 30 | ] 31 | ) 32 | trs.transforms.insert(0, RandAugment(2, 9)) 33 | return trs 34 | 35 | 36 | def imagenet_transform(size=800): 37 | return transforms.Compose( 38 | [ 39 | Resize(size, interpolation=Image.BICUBIC), 40 | CenterCrop(size), 41 | transforms.ToTensor(), 42 | imagenet_normalize, 43 | ] 44 | ) 45 | 46 | 47 | def imagenet_transform_randaug(size=800): 48 | trs = transforms.Compose( 49 | [ 50 | Resize(size, interpolation=Image.BICUBIC), 51 | CenterCrop(size), 52 | transforms.ToTensor(), 53 | imagenet_normalize, 54 | ] 55 | ) 56 | trs.transforms.insert(0, RandAugment(2, 9)) 57 | return trs 58 | 59 | 60 | def vit_transform(size=800): 61 | return transforms.Compose( 62 | [ 63 | Resize(size, interpolation=Image.BICUBIC), 64 | CenterCrop(size), 65 | transforms.ToTensor(), 66 | inception_normalize, 67 | 
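# note: same resize/center-crop pipeline as imagenet_transform above, except that the
# ViT-style inception normalization (mean=std=0.5, see transforms/utils.py) is used
# instead of the ImageNet statistics.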
] 68 | ) 69 | 70 | 71 | def vit_transform_randaug(size=800): 72 | trs = transforms.Compose( 73 | [ 74 | Resize(size, interpolation=Image.BICUBIC), 75 | CenterCrop(size), 76 | transforms.ToTensor(), 77 | inception_normalize, 78 | ] 79 | ) 80 | trs.transforms.insert(0, RandAugment(2, 9)) 81 | return trs 82 | 83 | 84 | def clip_transform(size): 85 | return Compose([ 86 | Resize(size, interpolation=Image.BICUBIC), 87 | CenterCrop(size), 88 | lambda image: image.convert("RGB"), 89 | ToTensor(), 90 | Normalize((0.48145466, 0.4578275, 0.40821073), 91 | (0.26862954, 0.26130258, 0.27577711)), 92 | ]) 93 | 94 | 95 | def clip_transform_randaug(size): 96 | trs = Compose([ 97 | Resize(size, interpolation=Image.BICUBIC), 98 | CenterCrop(size), 99 | lambda image: image.convert("RGB"), 100 | ToTensor(), 101 | Normalize((0.48145466, 0.4578275, 0.40821073), 102 | (0.26862954, 0.26130258, 0.27577711)), 103 | ]) 104 | trs.transforms.insert(0, lambda image: image.convert('RGBA')) 105 | trs.transforms.insert(0, RandAugment(2, 9)) 106 | trs.transforms.insert(0, lambda image: image.convert('RGB')) 107 | return trs 108 | 109 | 110 | def mae_transform_randaug(size): 111 | trs = Compose([ 112 | transforms.RandomResizedCrop(size, scale=( 113 | 0.2, 1.0), interpolation=3), # 3 is bicubic 114 | transforms.RandomHorizontalFlip(), 115 | lambda image: image.convert("RGB"), 116 | transforms.ToTensor(), 117 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[ 118 | 0.229, 0.224, 0.225]) 119 | ]) 120 | trs.transforms.insert(0, lambda image: image.convert('RGBA')) 121 | trs.transforms.insert(0, RandAugment(2, 9)) 122 | trs.transforms.insert(0, lambda image: image.convert('RGB')) 123 | return trs 124 | 125 | 126 | def mae_transform(size): 127 | trs = Compose([ 128 | Resize(size, interpolation=Image.BICUBIC), 129 | CenterCrop(size), 130 | lambda image: image.convert("RGB"), 131 | ToTensor(), 132 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[ 133 | 0.229, 0.224, 0.225]) 134 | ]) 135 | return trs 136 | -------------------------------------------------------------------------------- /flm/transforms/utils.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image 3 | 4 | 5 | class MinMaxResize: 6 | def __init__(self, shorter=800, longer=1333): 7 | self.min = shorter 8 | self.max = longer 9 | 10 | def __call__(self, x): 11 | w, h = x.size 12 | scale = self.min / min(w, h) 13 | if h < w: 14 | newh, neww = self.min, scale * w 15 | else: 16 | newh, neww = scale * h, self.min 17 | 18 | if max(newh, neww) > self.max: 19 | scale = self.max / max(newh, neww) 20 | newh = newh * scale 21 | neww = neww * scale 22 | 23 | newh, neww = int(newh + 0.5), int(neww + 0.5) 24 | newh, neww = newh // 32 * 32, neww // 32 * 32 25 | 26 | return x.resize((neww, newh), resample=Image.BICUBIC) 27 | 28 | 29 | class UnNormalize(object): 30 | def __init__(self, mean, std): 31 | self.mean = mean 32 | self.std = std 33 | 34 | def __call__(self, tensor): 35 | """ 36 | Args: 37 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 38 | Returns: 39 | Tensor: Normalized image. 
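(strictly, the de-normalized image: each channel is multiplied by its std and
shifted by its mean in place, undoing transforms.Normalize)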
40 | """ 41 | for t, m, s in zip(tensor, self.mean, self.std): 42 | t.mul_(s).add_(m) 43 | # The normalize code -> t.sub_(m).div_(s) 44 | return tensor 45 | 46 | 47 | # This is simple maximum entropy normalization performed in Inception paper 48 | inception_normalize = transforms.Compose( 49 | [transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 50 | ) 51 | 52 | # ViT uses simple non-biased inception normalization 53 | # https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py#L132 54 | inception_unnormalize = transforms.Compose( 55 | [UnNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 56 | ) 57 | 58 | # ImageNet normalize 59 | imagenet_normalize = transforms.Compose( 60 | [transforms.Normalize(mean=[0.485, 0.456, 0.406], 61 | std=[0.229, 0.224, 0.225])] 62 | ) 63 | -------------------------------------------------------------------------------- /flm/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/FLM/bd8b19d9f3a00ac6d4e58c9766957032036bffe8/flm/utils/__init__.py -------------------------------------------------------------------------------- /flm/utils/find_newest_ckpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | 5 | 6 | save_folder = sys.argv[1] 7 | exp_name = sys.argv[2] 8 | is_last = True if sys.argv[3] == 'choose_last' else False 9 | 10 | # exp_name = '37_cl_causalflm_scratch_lr5e5_nobias_t0002_NEW_GPU32' 11 | target = '{}/{}_seed*_from*/version_*/checkpoints/epoch*-step*.ckpt'.format( 12 | save_folder, exp_name) 13 | if is_last: 14 | target = '{}/{}_seed*_from*/version_*/checkpoints/last.ckpt'.format( 15 | save_folder, exp_name) 16 | out = glob.glob(target) 17 | 18 | 19 | def get_info(p): 20 | p = p.rstrip('.ckpt') 21 | version = float(p.split('/')[-3].split('_')[-1]) 22 | try: 23 | epoch = float(p.split('/')[-1].split('-')[0].split('_')[1]) 24 | except: 25 | epoch = None 26 | try: 27 | score = float(p.split('/')[-1].split('-')[-1].split('_')[-1]) 28 | except: 29 | score = None 30 | 31 | if score is None: 32 | score = -10000. 
33 | 34 | return score, epoch, version 35 | 36 | 37 | out = sorted(out, key=get_info, reverse=True) 38 | 39 | if len(out): 40 | print(out[0]) 41 | -------------------------------------------------------------------------------- /flm/utils/glossary.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | contractions = { 4 | "aint": "ain't", 5 | "arent": "aren't", 6 | "cant": "can't", 7 | "couldve": "could've", 8 | "couldnt": "couldn't", 9 | "couldn'tve": "couldn't've", 10 | "couldnt've": "couldn't've", 11 | "didnt": "didn't", 12 | "doesnt": "doesn't", 13 | "dont": "don't", 14 | "hadnt": "hadn't", 15 | "hadnt've": "hadn't've", 16 | "hadn'tve": "hadn't've", 17 | "hasnt": "hasn't", 18 | "havent": "haven't", 19 | "hed": "he'd", 20 | "hed've": "he'd've", 21 | "he'dve": "he'd've", 22 | "hes": "he's", 23 | "howd": "how'd", 24 | "howll": "how'll", 25 | "hows": "how's", 26 | "Id've": "I'd've", 27 | "I'dve": "I'd've", 28 | "Im": "I'm", 29 | "Ive": "I've", 30 | "isnt": "isn't", 31 | "itd": "it'd", 32 | "itd've": "it'd've", 33 | "it'dve": "it'd've", 34 | "itll": "it'll", 35 | "let's": "let's", 36 | "maam": "ma'am", 37 | "mightnt": "mightn't", 38 | "mightnt've": "mightn't've", 39 | "mightn'tve": "mightn't've", 40 | "mightve": "might've", 41 | "mustnt": "mustn't", 42 | "mustve": "must've", 43 | "neednt": "needn't", 44 | "notve": "not've", 45 | "oclock": "o'clock", 46 | "oughtnt": "oughtn't", 47 | "ow's'at": "'ow's'at", 48 | "'ows'at": "'ow's'at", 49 | "'ow'sat": "'ow's'at", 50 | "shant": "shan't", 51 | "shed've": "she'd've", 52 | "she'dve": "she'd've", 53 | "she's": "she's", 54 | "shouldve": "should've", 55 | "shouldnt": "shouldn't", 56 | "shouldnt've": "shouldn't've", 57 | "shouldn'tve": "shouldn't've", 58 | "somebody'd": "somebodyd", 59 | "somebodyd've": "somebody'd've", 60 | "somebody'dve": "somebody'd've", 61 | "somebodyll": "somebody'll", 62 | "somebodys": "somebody's", 63 | "someoned": "someone'd", 64 | "someoned've": "someone'd've", 65 | "someone'dve": "someone'd've", 66 | "someonell": "someone'll", 67 | "someones": "someone's", 68 | "somethingd": "something'd", 69 | "somethingd've": "something'd've", 70 | "something'dve": "something'd've", 71 | "somethingll": "something'll", 72 | "thats": "that's", 73 | "thered": "there'd", 74 | "thered've": "there'd've", 75 | "there'dve": "there'd've", 76 | "therere": "there're", 77 | "theres": "there's", 78 | "theyd": "they'd", 79 | "theyd've": "they'd've", 80 | "they'dve": "they'd've", 81 | "theyll": "they'll", 82 | "theyre": "they're", 83 | "theyve": "they've", 84 | "twas": "'twas", 85 | "wasnt": "wasn't", 86 | "wed've": "we'd've", 87 | "we'dve": "we'd've", 88 | "weve": "we've", 89 | "werent": "weren't", 90 | "whatll": "what'll", 91 | "whatre": "what're", 92 | "whats": "what's", 93 | "whatve": "what've", 94 | "whens": "when's", 95 | "whered": "where'd", 96 | "wheres": "where's", 97 | "whereve": "where've", 98 | "whod": "who'd", 99 | "whod've": "who'd've", 100 | "who'dve": "who'd've", 101 | "wholl": "who'll", 102 | "whos": "who's", 103 | "whove": "who've", 104 | "whyll": "why'll", 105 | "whyre": "why're", 106 | "whys": "why's", 107 | "wont": "won't", 108 | "wouldve": "would've", 109 | "wouldnt": "wouldn't", 110 | "wouldnt've": "wouldn't've", 111 | "wouldn'tve": "wouldn't've", 112 | "yall": "y'all", 113 | "yall'll": "y'all'll", 114 | "y'allll": "y'all'll", 115 | "yall'd've": "y'all'd've", 116 | "y'alld've": "y'all'd've", 117 | "y'all'dve": "y'all'd've", 118 | "youd": "you'd", 119 | "youd've": "you'd've", 120 | 
"you'dve": "you'd've", 121 | "youll": "you'll", 122 | "youre": "you're", 123 | "youve": "you've", 124 | } 125 | 126 | manual_map = { 127 | "none": "0", 128 | "zero": "0", 129 | "one": "1", 130 | "two": "2", 131 | "three": "3", 132 | "four": "4", 133 | "five": "5", 134 | "six": "6", 135 | "seven": "7", 136 | "eight": "8", 137 | "nine": "9", 138 | "ten": "10", 139 | } 140 | articles = ["a", "an", "the"] 141 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 142 | comma_strip = re.compile("(\d)(\,)(\d)") 143 | punct = [ 144 | ";", 145 | r"/", 146 | "[", 147 | "]", 148 | '"', 149 | "{", 150 | "}", 151 | "(", 152 | ")", 153 | "=", 154 | "+", 155 | "\\", 156 | "_", 157 | "-", 158 | ">", 159 | "<", 160 | "@", 161 | "`", 162 | ",", 163 | "?", 164 | "!", 165 | ] 166 | 167 | 168 | def normalize_word(token): 169 | _token = token 170 | for p in punct: 171 | if (p + " " in token or " " + p in token) or ( 172 | re.search(comma_strip, token) != None 173 | ): 174 | _token = _token.replace(p, "") 175 | else: 176 | _token = _token.replace(p, " ") 177 | token = period_strip.sub("", _token, re.UNICODE) 178 | 179 | _token = [] 180 | temp = token.lower().split() 181 | for word in temp: 182 | word = manual_map.setdefault(word, word) 183 | if word not in articles: 184 | _token.append(word) 185 | for i, word in enumerate(_token): 186 | if word in contractions: 187 | _token[i] = contractions[word] 188 | token = " ".join(_token) 189 | token = token.replace(",", "") 190 | return token 191 | -------------------------------------------------------------------------------- /flm/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from flm.modules import heads, objectives, meter_utils 4 | 5 | 6 | @torch.no_grad() 7 | def adapt_vocab_size(state_dict, new_vocab_size): 8 | 9 | for name in state_dict.keys(): 10 | if 'embeddings.word_embeddings.weight' in name or 'fusion_token_embedding.word_embeddings.weight' in name: 11 | expand_vocab(name, state_dict, new_vocab_size) 12 | 13 | # value = state_dict[name] 14 | # old_vocab_size, old_embed_dim = value.shape 15 | # if old_vocab_size != new_vocab_size: 16 | # assert new_vocab_size > old_vocab_size 17 | # new_embeddings = nn.Embedding(new_vocab_size, old_embed_dim) 18 | # new_embeddings.apply(objectives.init_weights) 19 | # new_embeddings.weight[:old_vocab_size] = value 20 | # print(' replace vocab size of {} from {} to {}'.format(name ,old_vocab_size, new_vocab_size)) 21 | # state_dict[name] = new_embeddings.weight 22 | 23 | output_params = ['mlm_score', 'lm_score', 'lm_score_r', 'lm_score_f'] 24 | 25 | for p in output_params: 26 | weight_name = p + '.decoder.weight' 27 | bias_name = p + '.bias' 28 | if weight_name in name or bias_name in name: 29 | expand_vocab(name, state_dict, new_vocab_size) 30 | 31 | return state_dict 32 | 33 | 34 | def expand_vocab(name, state_dict, new_vocab_size): 35 | value = state_dict[name] 36 | if value.shape[0] != new_vocab_size: 37 | state_dict[name] = expand_tensor(value, new_vocab_size) 38 | print(' replace vocab size of {} from {} to {}'.format( 39 | name, value.shape[0], new_vocab_size)) 40 | 41 | 42 | def expand_tensor(value, new_vocab_size): 43 | if value.ndim == 1: 44 | old_vocab_size = value.shape[0] 45 | new_embeddings = torch.zeros(new_vocab_size) 46 | else: 47 | old_vocab_size, old_embed_dim = value.shape 48 | new_embeddings = torch.zeros(new_vocab_size, old_embed_dim) 49 | assert new_vocab_size > old_vocab_size 50 | 51 | 
new_embeddings.data.normal_(mean=0.0, std=0.02) 52 | 53 | new_embeddings[:old_vocab_size] = value 54 | return new_embeddings 55 | -------------------------------------------------------------------------------- /flm/utils/whole_word_masking.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | from dataclasses import dataclass 4 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union 5 | 6 | import torch 7 | from torch.nn.utils.rnn import pad_sequence 8 | 9 | # from ..file_utils import PaddingStrategy 10 | # from ..modeling_utils import PreTrainedModel 11 | from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase 12 | 13 | from transformers import ( 14 | DataCollatorForLanguageModeling) 15 | 16 | 17 | class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): 18 | """ 19 | Data collator used for language modeling. 20 | 21 | - collates batches of tensors, honoring their tokenizer's pad_token 22 | - preprocesses batches for masked language modeling 23 | """ 24 | 25 | def __call__( 26 | self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]] 27 | ) -> Dict[str, torch.Tensor]: 28 | if isinstance(examples[0], (dict, BatchEncoding)): 29 | input_ids = [e["input_ids"] for e in examples] 30 | else: 31 | input_ids = examples 32 | examples = [{"input_ids": e} for e in examples] 33 | 34 | batch_input = _collate_batch(input_ids, self.tokenizer) 35 | 36 | mask_labels = [] 37 | for e in examples: 38 | ref_tokens = [] 39 | for id in tolist(e["input_ids"]): 40 | token = self.tokenizer._convert_id_to_token(id) 41 | if id == self.tokenizer.convert_tokens_to_ids(''): 42 | token = '' 43 | if id == self.tokenizer.convert_tokens_to_ids(''): 44 | token = '' 45 | ref_tokens.append(token) 46 | 47 | # For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢] 48 | if "chinese_ref" in e: 49 | ref_pos = tolist(e["chinese_ref"]) 50 | len_seq = len(e["input_ids"]) 51 | for i in range(len_seq): 52 | if i in ref_pos: 53 | ref_tokens[i] = "##" + ref_tokens[i] 54 | mask_labels.append(self._whole_word_mask(ref_tokens)) 55 | batch_mask = _collate_batch(mask_labels, self.tokenizer) 56 | inputs, labels = self.mask_tokens(batch_input, batch_mask) 57 | return {"input_ids": inputs, "labels": labels} 58 | 59 | def _whole_word_mask(self, input_tokens: List[str], max_predictions=512): 60 | """ 61 | Get 0/1 labels for masked tokens with whole word mask proxy 62 | """ 63 | 64 | cand_indexes = [] 65 | 66 | for (i, token) in enumerate(input_tokens): 67 | if token == "[CLS]" or token == "[SEP]": 68 | continue 69 | 70 | if len(cand_indexes) >= 1 and token.startswith("##"): 71 | cand_indexes[-1].append(i) 72 | else: 73 | cand_indexes.append([i]) 74 | 75 | random.shuffle(cand_indexes) 76 | num_to_predict = min(max_predictions, max( 77 | 1, int(round(len(input_tokens) * self.mlm_probability)))) 78 | masked_lms = [] 79 | covered_indexes = set() 80 | for index_set in cand_indexes: 81 | if len(masked_lms) >= num_to_predict: 82 | break 83 | # If adding a whole-word mask would exceed the maximum number of 84 | # predictions, then just skip this candidate. 
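# illustration: with mlm_probability=0.15 and a 20-token sequence,
# num_to_predict = min(512, max(1, round(20 * 0.15))) = 3; since "##"-prefixed word
# pieces were grouped with their head token above, whole words are masked together.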
85 | if len(masked_lms) + len(index_set) > num_to_predict: 86 | continue 87 | is_any_index_covered = False 88 | for index in index_set: 89 | if index in covered_indexes: 90 | is_any_index_covered = True 91 | break 92 | if is_any_index_covered: 93 | continue 94 | for index in index_set: 95 | covered_indexes.add(index) 96 | masked_lms.append(index) 97 | 98 | assert len(covered_indexes) == len(masked_lms) 99 | mask_labels = [ 100 | 1 if i in covered_indexes else 0 for i in range(len(input_tokens))] 101 | return mask_labels 102 | 103 | def mask_tokens(self, inputs: torch.Tensor, mask_labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 104 | """ 105 | Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set 106 | 'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to it's ref. 107 | """ 108 | 109 | if self.tokenizer.mask_token is None: 110 | raise ValueError( 111 | "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer." 112 | ) 113 | labels = inputs.clone() 114 | # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) 115 | 116 | probability_matrix = mask_labels 117 | 118 | special_tokens_mask = [ 119 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() 120 | ] 121 | probability_matrix.masked_fill_(torch.tensor( 122 | special_tokens_mask, dtype=torch.bool), value=0.0) 123 | if self.tokenizer._pad_token is not None: 124 | padding_mask = labels.eq(self.tokenizer.pad_token_id) 125 | probability_matrix.masked_fill_(padding_mask, value=0.0) 126 | 127 | masked_indices = probability_matrix.bool() 128 | labels[~masked_indices] = -100 # We only compute loss on masked tokens 129 | 130 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 131 | indices_replaced = torch.bernoulli(torch.full( 132 | labels.shape, 0.8)).bool() & masked_indices 133 | inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids( 134 | self.tokenizer.mask_token) 135 | 136 | # 10% of the time, we replace masked input tokens with random word 137 | indices_random = torch.bernoulli(torch.full( 138 | labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced 139 | random_words = torch.randint( 140 | len(self.tokenizer), labels.shape, dtype=torch.long) 141 | inputs[indices_random] = random_words[indices_random] 142 | 143 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 144 | return inputs, labels 145 | 146 | 147 | def _collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None): 148 | """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.""" 149 | # Tensorize if necessary. 150 | if isinstance(examples[0], (list, tuple)): 151 | examples = [torch.tensor(e, dtype=torch.long) for e in examples] 152 | 153 | # Check if padding is necessary. 154 | length_of_first = examples[0].size(0) 155 | are_tensors_same_length = all( 156 | x.size(0) == length_of_first for x in examples) 157 | if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0): 158 | return torch.stack(examples, dim=0) 159 | 160 | # If yes, check if we have a `pad_token`. 
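# note: without a pad token there is nothing to fill the shorter sequences with, hence
# the error below; otherwise a [len(examples), max_length] tensor pre-filled with
# pad_token_id is built (max_length optionally rounded up to pad_to_multiple_of) and
# each example is copied in on the side given by tokenizer.padding_side.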
161 | if tokenizer._pad_token is None: 162 | raise ValueError( 163 | "You are attempting to pad samples but the tokenizer you are using" 164 | f" ({tokenizer.__class__.__name__}) does not have a pad token." 165 | ) 166 | 167 | # Creating the full tensor and filling it with our data. 168 | max_length = max(x.size(0) for x in examples) 169 | if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): 170 | max_length = ((max_length // pad_to_multiple_of) + 1) * \ 171 | pad_to_multiple_of 172 | result = examples[0].new_full( 173 | [len(examples), max_length], tokenizer.pad_token_id) 174 | for i, example in enumerate(examples): 175 | if tokenizer.padding_side == "right": 176 | result[i, : example.shape[0]] = example 177 | else: 178 | result[i, -example.shape[0]:] = example 179 | return result 180 | 181 | 182 | def tolist(x: Union[List[Any], torch.Tensor]): 183 | return x.tolist() if isinstance(x, torch.Tensor) else x 184 | -------------------------------------------------------------------------------- /flm/utils/write_coco_karpathy.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pandas as pd 4 | import pyarrow as pa 5 | import random 6 | 7 | from tqdm import tqdm 8 | from glob import glob 9 | from collections import defaultdict 10 | 11 | 12 | def path2rest(path, iid2captions, iid2split): 13 | name = path.split("/")[-1] 14 | with open(path, "rb") as fp: 15 | binary = fp.read() 16 | captions = iid2captions[name] 17 | split = iid2split[name] 18 | return [binary, captions, name, split] 19 | 20 | 21 | def make_arrow(root, dataset_root): 22 | with open(f"{root}/karpathy/dataset_coco.json", "r") as fp: 23 | captions = json.load(fp) 24 | 25 | captions = captions["images"] 26 | 27 | iid2captions = defaultdict(list) 28 | iid2split = dict() 29 | 30 | for cap in tqdm(captions): 31 | filename = cap["filename"] 32 | iid2split[filename] = cap["split"] 33 | for c in cap["sentences"]: 34 | iid2captions[filename].append(c["raw"]) 35 | 36 | paths = list(glob(f"{root}/train2014/*.jpg")) + \ 37 | list(glob(f"{root}/val2014/*.jpg")) 38 | random.shuffle(paths) 39 | caption_paths = [path for path in paths if path.split( 40 | "/")[-1] in iid2captions] 41 | 42 | if len(paths) == len(caption_paths): 43 | print("all images have caption annotations") 44 | else: 45 | print("not all images have caption annotations") 46 | print( 47 | len(paths), len(caption_paths), len(iid2captions), 48 | ) 49 | 50 | bs = [path2rest(path, iid2captions, iid2split) 51 | for path in tqdm(caption_paths)] 52 | 53 | for split in ["train", "val", "restval", "test"]: 54 | batches = [b for b in bs if b[-1] == split] 55 | 56 | dataframe = pd.DataFrame( 57 | batches, columns=["image", "caption", "image_id", "split"], 58 | ) 59 | 60 | table = pa.Table.from_pandas(dataframe) 61 | os.makedirs(dataset_root, exist_ok=True) 62 | with pa.OSFile( 63 | f"{dataset_root}/coco_caption_karpathy_{split}.arrow", "wb" 64 | ) as sink: 65 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 66 | writer.write_table(table) 67 | -------------------------------------------------------------------------------- /flm/utils/write_conceptual_caption.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import gc 5 | import random 6 | import os 7 | 8 | from tqdm import tqdm 9 | from glob import glob 10 | 11 | 12 | def path2rest(path, iid2captions): 13 | split, _, name = 
path.split("/")[-3:] 14 | split = split.split("_")[-1] 15 | iid = name 16 | 17 | with open(path, "rb") as fp: 18 | binary = fp.read() 19 | 20 | captions = iid2captions[iid] 21 | 22 | return [ 23 | binary, 24 | captions, 25 | iid, 26 | split, 27 | ] 28 | 29 | 30 | def make_arrow(root, dataset_root): 31 | for split in ["val", "train"]: 32 | with open(f"{root}/{split}_annot.json", "r") as fp: 33 | captions = json.load(fp) 34 | 35 | iid2captions = dict() 36 | for cap in tqdm(captions): 37 | iid = cap[0].split("/")[-1] 38 | iid2captions[iid] = [cap[1]] 39 | 40 | paths = list(glob(f"{root}/images_{split}/*/*")) 41 | random.shuffle(paths) 42 | caption_paths = [path for path in paths if path.split( 43 | "/")[-1] in iid2captions] 44 | if len(paths) == len(caption_paths): 45 | print("all images have caption annotations") 46 | else: 47 | print("not all images have caption annotations") 48 | print( 49 | len(paths), len(caption_paths), len(iid2captions), 50 | ) 51 | arrow_path = "{dataset_root}/conceptual_caption_{split}_{sub}.arrow" 52 | write_split(caption_paths, iid2captions, 53 | dataset_root, arrow_path, split) 54 | 55 | 56 | def write_split(caption_paths, iid2captions, dataset_root, arrow_path, split): 57 | sub_len = int(len(caption_paths) // 100000) 58 | subs = list(range(sub_len + 1)) 59 | for sub in subs: 60 | sub_paths = caption_paths[sub * 100000: (sub + 1) * 100000] 61 | bs = [path2rest(path, iid2captions) for path in tqdm(sub_paths)] 62 | dataframe = pd.DataFrame( 63 | bs, columns=["image", "caption", "image_id", "split"], 64 | ) 65 | 66 | table = pa.Table.from_pandas(dataframe) 67 | 68 | with pa.OSFile( 69 | arrow_path.format(**{'dataset_root': dataset_root, 70 | 'split': split, 71 | 'sub': sub}), "wb") as sink: 72 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 73 | writer.write_table(table) 74 | del dataframe 75 | del table 76 | del bs 77 | gc.collect() 78 | -------------------------------------------------------------------------------- /flm/utils/write_conceptual_caption12M_cloud.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import gc 5 | import random 6 | import os 7 | 8 | from tqdm import tqdm 9 | from glob import glob 10 | 11 | 12 | def path2rest(path, iid2captions, data_dir, split): 13 | # split, _, name = path.split("/")[-3:] 14 | # split = split.split("_")[-1] 15 | # iid = name 16 | iid = path 17 | 18 | with open(_get_video_path(path, data_dir, split)[0], "rb") as fp: 19 | binary = fp.read() 20 | 21 | captions = iid2captions[iid] 22 | 23 | return [ 24 | binary, 25 | captions, 26 | iid, 27 | split, 28 | ] 29 | 30 | 31 | def _get_caption(sample): 32 | return sample[0] 33 | 34 | 35 | def _get_video_path(file_name, data_dir, split): 36 | # conceptual captions uses this hashing to create the filename 37 | rel_dir = '.' 
38 | # if split != 'train': 39 | # rel_dir = 'validation' 40 | rel_fp = os.path.join(rel_dir, file_name) 41 | return os.path.join(data_dir, rel_fp), rel_fp 42 | 43 | 44 | def make_arrow(dataset_root, save_folder, split='train', chunk_id=0, chunk_num=1): 45 | 46 | metadata_dir = os.path.join(dataset_root, 'metadata') 47 | split_files = { 48 | 'train': 'train.tsv', 49 | 'val': 'val.tsv', # there is no test 50 | } 51 | split_folders = {'train': 'training', 52 | 'val': 'validation', # there is no tes 53 | } 54 | 55 | # for split in ["val", "train"]: 56 | if True: 57 | target_split_fp = split_files[split] 58 | metadata = pd.read_csv(os.path.join( 59 | metadata_dir, target_split_fp), sep='\t') 60 | 61 | # meta_data_path = f"{root}/metadata/cc3m_{split_files}_success_full.tsv" 62 | # metadata = pd.read_csv(os.path.join(metadata_dir, target_split_fp), sep='\t') 63 | 64 | # with open(, "r") as fp: 65 | # captions = json.load(fp) 66 | 67 | # iid2captions = dict() 68 | # for cap in tqdm(captions): 69 | # iid = cap[0].split("/")[-1] 70 | # iid2captions[iid] = [cap[1]] 71 | 72 | if True: 73 | chunk_size = metadata.shape[0] // chunk_num + 1 74 | start, end = chunk_id * chunk_size, (chunk_id + 1) * chunk_size 75 | print('chunk number: {}, current chunk_id: {}, chunk_size: {}'.format( 76 | chunk_num, chunk_id, chunk_size)) 77 | 78 | iid2captions = dict() 79 | for item in tqdm(range(metadata.shape[0])): 80 | if item not in range(start, end): 81 | continue 82 | sample = metadata.iloc[item] 83 | caption = _get_caption(sample) 84 | iid = sample[1] 85 | iid2captions[iid] = caption 86 | 87 | # paths = list(glob(f"{dataset_root}/{split_folders[split]}/*")) 88 | # random.shuffle(paths) 89 | 90 | # caption_paths = [path for path in paths if path.split("/")[-1] in iid2captions] 91 | caption_paths = list(iid2captions.keys()) 92 | # random.shuffle(caption_paths) 93 | 94 | # if len(paths) == len(caption_paths): 95 | # print("all images have caption annotations") 96 | # else: 97 | # print("not all images have caption annotations") 98 | # print( 99 | # len(paths), len(caption_paths), len(iid2captions), 100 | # ) 101 | 102 | sub_len = int(len(caption_paths) // 100000) 103 | subs = list(range(sub_len + 1)) 104 | print('split number: {}, split_len: {}'.format(sub_len, 100000)) 105 | for sub in tqdm(subs): 106 | if sub > 0: 107 | continue 108 | print('current split id: {}'.format(sub)) 109 | sub_paths = caption_paths[sub * 100000: (sub + 1) * 100000] 110 | bs = [path2rest(path, iid2captions, dataset_root, split) 111 | for path in tqdm(sub_paths)] 112 | 113 | dataframe = pd.DataFrame( 114 | bs, columns=["image", "caption", "image_id", "split"], 115 | ) 116 | 117 | table = pa.Table.from_pandas(dataframe) 118 | 119 | os.makedirs(save_folder, exist_ok=True) 120 | dst_arrow_file = f"{save_folder}/conceptual_caption12M_{split}_{chunk_id}_{sub}.arrow" 121 | with pa.OSFile( 122 | dst_arrow_file, "wb" 123 | ) as sink: 124 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 125 | writer.write_table(table) 126 | del dataframe 127 | del table 128 | del bs 129 | gc.collect() 130 | -------------------------------------------------------------------------------- /flm/utils/write_conceptual_caption_cloud.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import gc 5 | import random 6 | import os 7 | 8 | from tqdm import tqdm 9 | from glob import glob 10 | 11 | 12 | def path2rest(path, iid2captions, data_dir, split): 13 | # 
split, _, name = path.split("/")[-3:] 14 | # split = split.split("_")[-1] 15 | # iid = name 16 | iid = path 17 | 18 | with open(_get_video_path(path, data_dir, split)[0], "rb") as fp: 19 | binary = fp.read() 20 | 21 | captions = iid2captions[iid] 22 | 23 | return [ 24 | binary, 25 | captions, 26 | iid, 27 | split, 28 | ] 29 | 30 | 31 | def _get_caption(sample): 32 | return sample[0] 33 | 34 | 35 | def _get_video_path(file_name, data_dir, split): 36 | # conceptual captions uses this hashing to create the filename 37 | rel_dir = 'training' 38 | if split != 'train': 39 | rel_dir = 'validation' 40 | rel_fp = os.path.join(rel_dir, file_name) 41 | return os.path.join(data_dir, rel_fp), rel_fp 42 | 43 | 44 | def make_arrow(dataset_root, save_folder, split='train'): 45 | metadata_dir = os.path.join(dataset_root, 'metadata') 46 | split_files = { 47 | 'train': 'cc3m_training_success_full.tsv', 48 | 'val': 'cc3m_validation_success_full.tsv', # there is no test 49 | } 50 | split_folders = {'train': 'training', 51 | 'val': 'validation', # there is no tes 52 | } 53 | 54 | # for split in ["val", "train"]: 55 | if True: 56 | target_split_fp = split_files[split] 57 | metadata = pd.read_csv(os.path.join( 58 | metadata_dir, target_split_fp), sep='\t') 59 | 60 | # meta_data_path = f"{root}/metadata/cc3m_{split_files}_success_full.tsv" 61 | # metadata = pd.read_csv(os.path.join(metadata_dir, target_split_fp), sep='\t') 62 | 63 | # with open(, "r") as fp: 64 | # captions = json.load(fp) 65 | 66 | # iid2captions = dict() 67 | # for cap in tqdm(captions): 68 | # iid = cap[0].split("/")[-1] 69 | # iid2captions[iid] = [cap[1]] 70 | 71 | iid2captions = dict() 72 | for item in range(metadata.shape[0]): 73 | sample = metadata.iloc[item] 74 | caption = _get_caption(sample) 75 | iid = sample[1] 76 | iid2captions[iid] = caption 77 | 78 | # paths = list(glob(f"{dataset_root}/{split_folders[split]}/*")) 79 | # random.shuffle(paths) 80 | 81 | # caption_paths = [path for path in paths if path.split("/")[-1] in iid2captions] 82 | caption_paths = list(iid2captions.keys()) 83 | random.shuffle(caption_paths) 84 | 85 | # if len(paths) == len(caption_paths): 86 | # print("all images have caption annotations") 87 | # else: 88 | # print("not all images have caption annotations") 89 | # print( 90 | # len(paths), len(caption_paths), len(iid2captions), 91 | # ) 92 | 93 | sub_len = int(len(caption_paths) // 100000) 94 | subs = list(range(sub_len + 1)) 95 | for sub in subs: 96 | sub_paths = caption_paths[sub * 100000: (sub + 1) * 100000] 97 | bs = [path2rest(path, iid2captions, dataset_root, split) 98 | for path in tqdm(sub_paths)] 99 | 100 | dataframe = pd.DataFrame( 101 | bs, columns=["image", "caption", "image_id", "split"], 102 | ) 103 | 104 | table = pa.Table.from_pandas(dataframe) 105 | 106 | os.makedirs(save_folder, exist_ok=True) 107 | with pa.OSFile( 108 | f"{save_folder}/conceptual_caption_{split}_{sub}.arrow", "wb" 109 | ) as sink: 110 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 111 | writer.write_table(table) 112 | del dataframe 113 | del table 114 | del bs 115 | gc.collect() 116 | -------------------------------------------------------------------------------- /flm/utils/write_f30k_karpathy.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import random 5 | import os 6 | 7 | from tqdm import tqdm 8 | from glob import glob 9 | from collections import defaultdict 10 | 11 | 12 | def path2rest(path, 
iid2captions, iid2split): 13 | name = path.split("/")[-1] 14 | 15 | with open(path, "rb") as fp: 16 | binary = fp.read() 17 | 18 | captions = iid2captions[name] 19 | split = iid2split[name] 20 | 21 | return [binary, captions, name, split] 22 | 23 | 24 | def make_arrow(root, dataset_root): 25 | with open(f"{root}/karpathy/dataset_flickr30k.json", "r") as fp: 26 | captions = json.load(fp) 27 | 28 | captions = captions["images"] 29 | 30 | iid2captions = defaultdict(list) 31 | iid2split = dict() 32 | 33 | for cap in tqdm(captions): 34 | filename = cap["filename"] 35 | iid2split[filename] = cap["split"] 36 | for c in cap["sentences"]: 37 | iid2captions[filename].append(c["raw"]) 38 | 39 | paths = list(glob(f"{root}/flickr30k-images/*.jpg")) 40 | random.shuffle(paths) 41 | caption_paths = [path for path in paths if path.split( 42 | "/")[-1] in iid2captions] 43 | 44 | if len(paths) == len(caption_paths): 45 | print("all images have caption annotations") 46 | else: 47 | print("not all images have caption annotations") 48 | print( 49 | len(paths), len(caption_paths), len(iid2captions), 50 | ) 51 | 52 | bs = [path2rest(path, iid2captions, iid2split) 53 | for path in tqdm(caption_paths)] 54 | 55 | for split in ["train", "val", "test"]: 56 | batches = [b for b in bs if b[-1] == split] 57 | 58 | dataframe = pd.DataFrame( 59 | batches, columns=["image", "caption", "image_id", "split"], 60 | ) 61 | 62 | table = pa.Table.from_pandas(dataframe) 63 | 64 | os.makedirs(dataset_root, exist_ok=True) 65 | with pa.OSFile( 66 | f"{dataset_root}/f30k_caption_karpathy_{split}.arrow", "wb" 67 | ) as sink: 68 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 69 | writer.write_table(table) 70 | -------------------------------------------------------------------------------- /flm/utils/write_nlvr2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import os 5 | 6 | from tqdm import tqdm 7 | from collections import defaultdict 8 | 9 | 10 | def process(root, iden, row): 11 | texts = [r["sentence"] for r in row] 12 | labels = [r["label"] for r in row] 13 | 14 | split = iden.split("-")[0] 15 | 16 | if iden.startswith("train"): 17 | directory = row[0]["directory"] 18 | path = f"{root}/images/train/{directory}/{iden}" 19 | else: 20 | path = f"{root}/{split}/{iden}" 21 | 22 | with open(f"{path}-img0.png", "rb") as fp: 23 | img0 = fp.read() 24 | with open(f"{path}-img1.png", "rb") as fp: 25 | img1 = fp.read() 26 | 27 | return [img0, img1, texts, labels, iden] 28 | 29 | 30 | def make_arrow(root, dataset_root): 31 | train_data = list( 32 | map(json.loads, open(f"{root}/nlvr2/data/train.json").readlines()) 33 | ) 34 | test1_data = list( 35 | map(json.loads, open(f"{root}/nlvr2/data/test1.json").readlines()) 36 | ) 37 | dev_data = list(map(json.loads, open( 38 | f"{root}/nlvr2/data/dev.json").readlines())) 39 | 40 | balanced_test1_data = list( 41 | map( 42 | json.loads, 43 | open(f"{root}/nlvr2/data/balanced/balanced_test1.json").readlines(), 44 | ) 45 | ) 46 | balanced_dev_data = list( 47 | map( 48 | json.loads, 49 | open(f"{root}/nlvr2/data/balanced/balanced_dev.json").readlines(), 50 | ) 51 | ) 52 | 53 | unbalanced_test1_data = list( 54 | map( 55 | json.loads, 56 | open(f"{root}/nlvr2/data/unbalanced/unbalanced_test1.json").readlines(), 57 | ) 58 | ) 59 | unbalanced_dev_data = list( 60 | map( 61 | json.loads, 62 | open(f"{root}/nlvr2/data/unbalanced/unbalanced_dev.json").readlines(), 63 | ) 64 | ) 65 | 66 | splits 
= [ 67 | "train", 68 | "dev", 69 | "test1", 70 | "balanced_dev", 71 | "balanced_test1", 72 | "unbalanced_dev", 73 | "unbalanced_test1", 74 | ] 75 | 76 | datas = [ 77 | train_data, 78 | dev_data, 79 | test1_data, 80 | balanced_dev_data, 81 | balanced_test1_data, 82 | unbalanced_dev_data, 83 | unbalanced_test1_data, 84 | ] 85 | 86 | annotations = dict() 87 | 88 | for split, data in zip(splits, datas): 89 | _annot = defaultdict(list) 90 | for row in tqdm(data): 91 | _annot["-".join(row["identifier"].split("-")[:-1])].append(row) 92 | annotations[split] = _annot 93 | 94 | for split in splits: 95 | bs = [ 96 | process(root, iden, row) for iden, row in tqdm(annotations[split].items()) 97 | ] 98 | 99 | dataframe = pd.DataFrame( 100 | bs, columns=["image_0", "image_1", 101 | "questions", "answers", "identifier"], 102 | ) 103 | 104 | table = pa.Table.from_pandas(dataframe) 105 | 106 | os.makedirs(dataset_root, exist_ok=True) 107 | with pa.OSFile(f"{dataset_root}/nlvr2_{split}.arrow", "wb") as sink: 108 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 109 | writer.write_table(table) 110 | -------------------------------------------------------------------------------- /flm/utils/write_sbu.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import gc 5 | import random 6 | import os 7 | 8 | from tqdm import tqdm 9 | from glob import glob 10 | from .write_conceptual_caption import write_split 11 | 12 | 13 | def path2rest(path, iid2captions): 14 | split, _, name = path.split("/")[-3:] 15 | split = split.split("_")[-1] 16 | iid = name 17 | 18 | with open(path, "rb") as fp: 19 | binary = fp.read() 20 | 21 | captions = iid2captions[iid] 22 | 23 | return [ 24 | binary, 25 | captions, 26 | iid, 27 | split, 28 | ] 29 | 30 | 31 | def make_arrow(root, dataset_root): 32 | with open(f"{root}/annot.json", "r") as fp: 33 | captions = json.load(fp) 34 | 35 | iid2captions = dict() 36 | for cap in tqdm(captions): 37 | iid = cap[0].split("/")[-1] 38 | iid2captions[iid] = [cap[1]] 39 | 40 | paths = list(glob(f"{root}/images_train/*/*")) 41 | random.shuffle(paths) 42 | caption_paths = [path for path in paths if path.split( 43 | "/")[-1] in iid2captions] 44 | if len(paths) == len(caption_paths): 45 | print("all images have caption annotations") 46 | else: 47 | print("not all images have caption annotations") 48 | print( 49 | len(paths), len(caption_paths), len(iid2captions), 50 | ) 51 | 52 | arrow_path = "{dataset_root}/sbu_{sub}.arrow" 53 | write_split(caption_paths, iid2captions, 54 | dataset_root, arrow_path, split=None) 55 | -------------------------------------------------------------------------------- /flm/utils/write_snli.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import os 5 | 6 | from tqdm import tqdm 7 | from collections import defaultdict 8 | 9 | 10 | label2id = {'contradiction': 0, 'neutral': 1, 'entailment': 2} 11 | 12 | 13 | def process(root, imgid, ann): 14 | with open(f"{root}/Flickr30K/images/{imgid}.jpg", "rb") as fp: 15 | img = fp.read() 16 | 17 | sentences = ann['sentences'] 18 | 19 | labels = ann['labels'] 20 | 21 | return [img, sentences, labels] 22 | 23 | 24 | def make_arrow(root, dataset_root): 25 | train_data = list( 26 | map(json.loads, open(f"{root}/snli_ve_train.jsonl").readlines()) 27 | ) 28 | test_data = list( 29 | map(json.loads, 
open(f"{root}/snli_ve_test.jsonl").readlines()) 30 | ) 31 | dev_data = list( 32 | map(json.loads, open(f"{root}/snli_ve_dev.jsonl").readlines()) 33 | ) 34 | 35 | splits = [ 36 | "train", 37 | "dev", 38 | "test", 39 | ] 40 | 41 | annotations = dict() 42 | annotations['train'] = train_data 43 | annotations['dev'] = dev_data 44 | annotations['test'] = test_data 45 | annots = dict() 46 | for split in splits: 47 | annots[split] = {} 48 | for line in annotations[split]: 49 | imgid = line['Flickr30K_ID'] 50 | if not imgid in annots[split]: 51 | annots[split][imgid] = {} 52 | annots[split][imgid]['sentences'] = [] 53 | annots[split][imgid]['labels'] = [] 54 | annots[split][imgid]['sentences'].append( 55 | [line['sentence1'], line['sentence2']]) 56 | annots[split][imgid]['labels'].append(label2id[line['gold_label']]) 57 | 58 | for split in splits: 59 | bs = [process(root, imgid, annots[split][imgid]) 60 | for imgid in tqdm(annots[split])] 61 | 62 | dataframe = pd.DataFrame( 63 | bs, columns=["image", "sentences", "labels"] 64 | ) 65 | 66 | table = pa.Table.from_pandas(dataframe) 67 | 68 | os.makedirs(dataset_root, exist_ok=True) 69 | with pa.OSFile(f"{dataset_root}/snli_{split}.arrow", "wb") as sink: 70 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 71 | writer.write_table(table) 72 | -------------------------------------------------------------------------------- /flm/utils/write_vg.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import random 5 | import os 6 | 7 | from tqdm import tqdm 8 | from glob import glob 9 | from collections import defaultdict 10 | 11 | 12 | def path2rest(path, iid2captions): 13 | name = path.split("/")[-1] 14 | iid = int(name[:-4]) 15 | 16 | with open(path, "rb") as fp: 17 | binary = fp.read() 18 | 19 | cdicts = iid2captions[iid] 20 | captions = [c["phrase"] for c in cdicts] 21 | widths = [c["width"] for c in cdicts] 22 | heights = [c["height"] for c in cdicts] 23 | xs = [c["x"] for c in cdicts] 24 | ys = [c["y"] for c in cdicts] 25 | 26 | return [ 27 | binary, 28 | captions, 29 | widths, 30 | heights, 31 | xs, 32 | ys, 33 | str(iid), 34 | ] 35 | 36 | 37 | def make_arrow(root, dataset_root): 38 | with open(f"{root}/annotations/region_descriptions.json", "r") as fp: 39 | captions = json.load(fp) 40 | 41 | iid2captions = defaultdict(list) 42 | for cap in tqdm(captions): 43 | cap = cap["regions"] 44 | for c in cap: 45 | iid2captions[c["image_id"]].append(c) 46 | 47 | paths = list(glob(f"{root}/images/VG_100K/*.jpg")) + list( 48 | glob(f"{root}/images/VG_100K_2/*.jpg") 49 | ) 50 | random.shuffle(paths) 51 | caption_paths = [ 52 | path for path in paths if int(path.split("/")[-1][:-4]) in iid2captions 53 | ] 54 | 55 | if len(paths) == len(caption_paths): 56 | print("all images have caption annotations") 57 | else: 58 | print("not all images have caption annotations") 59 | print( 60 | len(paths), len(caption_paths), len(iid2captions), 61 | ) 62 | 63 | bs = [path2rest(path, iid2captions) for path in tqdm(caption_paths)] 64 | dataframe = pd.DataFrame( 65 | bs, columns=["image", "caption", "width", 66 | "height", "x", "y", "image_id"], 67 | ) 68 | table = pa.Table.from_pandas(dataframe) 69 | 70 | os.makedirs(dataset_root, exist_ok=True) 71 | with pa.OSFile(f"{dataset_root}/vg.arrow", "wb") as sink: 72 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 73 | writer.write_table(table) 74 | 
-------------------------------------------------------------------------------- /flm/utils/write_vqa.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import random 5 | import os 6 | 7 | from tqdm import tqdm 8 | from glob import glob 9 | from collections import defaultdict, Counter 10 | from .glossary import normalize_word 11 | 12 | 13 | def get_score(occurences): 14 | if occurences == 0: 15 | return 0.0 16 | elif occurences == 1: 17 | return 0.3 18 | elif occurences == 2: 19 | return 0.6 20 | elif occurences == 3: 21 | return 0.9 22 | else: 23 | return 1.0 24 | 25 | 26 | def path2rest(path, split, annotations, label2ans): 27 | iid = int(path.split("/")[-1].split("_")[-1][:-4]) 28 | 29 | with open(path, "rb") as fp: 30 | binary = fp.read() 31 | 32 | _annot = annotations[split][iid] 33 | _annot = list(_annot.items()) 34 | qids, qas = [a[0] for a in _annot], [a[1] for a in _annot] 35 | questions = [qa[0] for qa in qas] 36 | answers = [qa[1] for qa in qas] if "test" not in split else list(list()) 37 | answer_labels = ( 38 | [a["labels"] for a in answers] if "test" not in split else list(list()) 39 | ) 40 | answer_scores = ( 41 | [a["scores"] for a in answers] if "test" not in split else list(list()) 42 | ) 43 | answers = ( 44 | [[label2ans[l] for l in al] for al in answer_labels] 45 | if "test" not in split 46 | else list(list()) 47 | ) 48 | 49 | return [binary, questions, answers, answer_labels, answer_scores, iid, qids, split] 50 | 51 | 52 | def make_arrow(root, dataset_root): 53 | with open(f"{root}/v2_OpenEnded_mscoco_train2014_questions.json", "r") as fp: 54 | questions_train2014 = json.load(fp)["questions"] 55 | with open(f"{root}/v2_OpenEnded_mscoco_val2014_questions.json", "r") as fp: 56 | questions_val2014 = json.load(fp)["questions"] 57 | with open(f"{root}/v2_OpenEnded_mscoco_test2015_questions.json", "r") as fp: 58 | questions_test2015 = json.load(fp)["questions"] 59 | with open(f"{root}/v2_OpenEnded_mscoco_test-dev2015_questions.json", "r") as fp: 60 | questions_test_dev2015 = json.load(fp)["questions"] 61 | 62 | with open(f"{root}/v2_mscoco_train2014_annotations.json", "r") as fp: 63 | annotations_train2014 = json.load(fp)["annotations"] 64 | with open(f"{root}/v2_mscoco_val2014_annotations.json", "r") as fp: 65 | annotations_val2014 = json.load(fp)["annotations"] 66 | 67 | annotations = dict() 68 | 69 | for split, questions in zip( 70 | ["train", "val", "test", "test-dev"], 71 | [ 72 | questions_train2014, 73 | questions_val2014, 74 | questions_test2015, 75 | questions_test_dev2015, 76 | ], 77 | ): 78 | _annot = defaultdict(dict) 79 | for q in tqdm(questions): 80 | _annot[q["image_id"]][q["question_id"]] = [q["question"]] 81 | 82 | annotations[split] = _annot 83 | 84 | all_major_answers = list() 85 | 86 | for split, annots in zip( 87 | ["train", "val"], [annotations_train2014, annotations_val2014], 88 | ): 89 | _annot = annotations[split] 90 | for q in tqdm(annots): 91 | all_major_answers.append(q["multiple_choice_answer"]) 92 | 93 | all_major_answers = [normalize_word(word) 94 | for word in tqdm(all_major_answers)] 95 | counter = {k: v for k, v in Counter(all_major_answers).items() if v >= 9} 96 | ans2label = {k: i for i, k in enumerate(counter.keys())} 97 | label2ans = list(counter.keys()) 98 | 99 | for split, annots in zip( 100 | ["train", "val"], [annotations_train2014, annotations_val2014], 101 | ): 102 | _annot = annotations[split] 103 | for q in tqdm(annots): 104 | 
answers = q["answers"] 105 | answer_count = {} 106 | for answer in answers: 107 | answer_ = answer["answer"] 108 | answer_count[answer_] = answer_count.get(answer_, 0) + 1 109 | 110 | labels = [] 111 | scores = [] 112 | for answer in answer_count: 113 | if answer not in ans2label: 114 | continue 115 | labels.append(ans2label[answer]) 116 | score = get_score(answer_count[answer]) 117 | scores.append(score) 118 | 119 | _annot[q["image_id"]][q["question_id"]].append( 120 | {"labels": labels, "scores": scores, } 121 | ) 122 | 123 | for split in ["train", "val"]: 124 | filtered_annot = dict() 125 | for ik, iv in annotations[split].items(): 126 | new_q = dict() 127 | for qk, qv in iv.items(): 128 | if len(qv[1]["labels"]) != 0: 129 | new_q[qk] = qv 130 | if len(new_q) != 0: 131 | filtered_annot[ik] = new_q 132 | annotations[split] = filtered_annot 133 | 134 | for split in [ 135 | "train", 136 | "val", 137 | "test", 138 | "test-dev", 139 | ]: 140 | annot = annotations[split] 141 | split_name = { 142 | "train": "train2014", 143 | "val": "val2014", 144 | "test": "test2015", 145 | "test-dev": "test2015", 146 | }[split] 147 | paths = list(glob(f"{root}/{split_name}/*.jpg")) 148 | random.shuffle(paths) 149 | annot_paths = [ 150 | path 151 | for path in paths 152 | if int(path.split("/")[-1].split("_")[-1][:-4]) in annot 153 | ] 154 | 155 | if len(paths) == len(annot_paths): 156 | print("all images have caption annotations") 157 | else: 158 | print("not all images have caption annotations") 159 | print( 160 | len(paths), len(annot_paths), len(annot), 161 | ) 162 | 163 | bs = [ 164 | path2rest(path, split, annotations, label2ans) for path in tqdm(annot_paths) 165 | ] 166 | 167 | dataframe = pd.DataFrame( 168 | bs, 169 | columns=[ 170 | "image", 171 | "questions", 172 | "answers", 173 | "answer_labels", 174 | "answer_scores", 175 | "image_id", 176 | "question_id", 177 | "split", 178 | ], 179 | ) 180 | 181 | table = pa.Table.from_pandas(dataframe) 182 | 183 | os.makedirs(dataset_root, exist_ok=True) 184 | with pa.OSFile(f"{dataset_root}/vqav2_{split}.arrow", "wb") as sink: 185 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 186 | writer.write_table(table) 187 | 188 | table = pa.ipc.RecordBatchFileReader( 189 | pa.memory_map(f"{dataset_root}/vqav2_val.arrow", "r") 190 | ).read_all() 191 | 192 | pdtable = table.to_pandas() 193 | 194 | df1 = pdtable[:-1000] 195 | df2 = pdtable[-1000:] 196 | 197 | df1 = pa.Table.from_pandas(df1) 198 | df2 = pa.Table.from_pandas(df2) 199 | 200 | with pa.OSFile(f"{dataset_root}/vqav2_trainable_val.arrow", "wb") as sink: 201 | with pa.RecordBatchFileWriter(sink, df1.schema) as writer: 202 | writer.write_table(df1) 203 | 204 | with pa.OSFile(f"{dataset_root}/vqav2_rest_val.arrow", "wb") as sink: 205 | with pa.RecordBatchFileWriter(sink, df2.schema) as writer: 206 | writer.write_table(df2) 207 | -------------------------------------------------------------------------------- /flm/utils/write_winoground.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import os 5 | 6 | from tqdm import tqdm 7 | from collections import defaultdict 8 | 9 | 10 | def process(root, iden, row): 11 | text0 = row[0]["caption_0"] 12 | text1 = row[0]["caption_1"] 13 | img0_name = row[0]["image_0"] 14 | img1_name = row[0]["image_1"] 15 | img0_path = f"{root}/data/images/{img0_name}.png" 16 | img1_path = f"{root}/data/images/{img1_name}.png" 17 | # collapsed_tag = row[0]["collapsed_tag"] 18 | 
with open(img0_path, "rb") as fp: 19 | img0 = fp.read() 20 | with open(img1_path, "rb") as fp: 21 | img1 = fp.read() 22 | 23 | # texts = [r["sentence"] for r in row] 24 | # labels = [r["label"] for r in row] 25 | 26 | # split = iden.split("-")[0] 27 | 28 | # if iden.startswith("train"): 29 | # directory = row[0]["directory"] 30 | # path = f"{root}/images/train/{directory}/{iden}" 31 | # else: 32 | # path = f"{root}/{split}/{iden}" 33 | 34 | # with open(f"{path}-img0.png", "rb") as fp: 35 | # img0 = fp.read() 36 | # with open(f"{path}-img1.png", "rb") as fp: 37 | # img1 = fp.read() 38 | 39 | return [img0, img1, text0, text1, iden] 40 | 41 | 42 | def make_arrow(root, dataset_root): 43 | # train_data = list( 44 | # map(json.loads, open(f"{root}/data/examples.jsonl").readlines()) 45 | # ) 46 | test1_data = list( 47 | map(json.loads, open(f"{root}/data/examples.jsonl").readlines()) 48 | ) 49 | # dev_data = list(map(json.loads, open(f"{root}/nlvr2/data/dev.json").readlines())) 50 | 51 | # balanced_test1_data = list( 52 | # map( 53 | # json.loads, 54 | # open(f"{root}/nlvr2/data/balanced/balanced_test1.json").readlines(), 55 | # ) 56 | # ) 57 | # balanced_dev_data = list( 58 | # map( 59 | # json.loads, 60 | # open(f"{root}/nlvr2/data/balanced/balanced_dev.json").readlines(), 61 | # ) 62 | # ) 63 | 64 | # unbalanced_test1_data = list( 65 | # map( 66 | # json.loads, 67 | # open(f"{root}/nlvr2/data/unbalanced/unbalanced_test1.json").readlines(), 68 | # ) 69 | # ) 70 | # unbalanced_dev_data = list( 71 | # map( 72 | # json.loads, 73 | # open(f"{root}/nlvr2/data/unbalanced/unbalanced_dev.json").readlines(), 74 | # ) 75 | # ) 76 | splits = ['test'] 77 | datas = [test1_data] 78 | 79 | # splits = [ 80 | # "train", 81 | # "dev", 82 | # "test1", 83 | # "balanced_dev", 84 | # "balanced_test1", 85 | # "unbalanced_dev", 86 | # "unbalanced_test1", 87 | # ] 88 | 89 | # datas = [ 90 | # train_data, 91 | # dev_data, 92 | # test1_data, 93 | # balanced_dev_data, 94 | # balanced_test1_data, 95 | # unbalanced_dev_data, 96 | # unbalanced_test1_data, 97 | # ] 98 | 99 | annotations = dict() 100 | 101 | for split, data in zip(splits, datas): 102 | _annot = defaultdict(list) 103 | for row in tqdm(data): 104 | _annot[row["id"]].append(row) 105 | annotations[split] = _annot 106 | 107 | for split in splits: 108 | bs = [ 109 | process(root, iden, row) for iden, row in tqdm(annotations[split].items()) 110 | ] 111 | 112 | dataframe = pd.DataFrame( 113 | bs, columns=["image_0", "image_1", "text0", "text1", "identifier"], 114 | ) 115 | 116 | table = pa.Table.from_pandas(dataframe) 117 | 118 | os.makedirs(dataset_root, exist_ok=True) 119 | with pa.OSFile(f"{dataset_root}/winoground_{split}.arrow", "wb") as sink: 120 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 121 | writer.write_table(table) 122 | 123 | 124 | make_arrow('/group/30042/wybertwang/dataset/winoground', 125 | '/group/30042/wybertwang/dataset/METER_task_arrow') 126 | -------------------------------------------------------------------------------- /imgs/LMs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/FLM/bd8b19d9f3a00ac6d4e58c9766957032036bffe8/imgs/LMs.png -------------------------------------------------------------------------------- /imgs/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/FLM/bd8b19d9f3a00ac6d4e58c9766957032036bffe8/imgs/pipeline.png 
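Note that `write_winoground.py` above ends with a module-level `make_arrow` call bound to hardcoded cluster paths. A more portable invocation is sketched below; the argument names are illustrative, and the sketch assumes the hardcoded call at the bottom of the file has been removed so that importing the module has no side effects.

```python
# Hypothetical CLI wrapper around flm/utils/write_winoground.py.
# Assumes the hardcoded make_arrow(...) call at module level has been removed.
import argparse

from flm.utils.write_winoground import make_arrow

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Winoground to .arrow")
    parser.add_argument("--root", required=True,
                        help="directory containing data/examples.jsonl and data/images/")
    parser.add_argument("--dataset_root", required=True,
                        help="output directory for winoground_test.arrow")
    args = parser.parse_args()
    make_arrow(args.root, args.dataset_root)
```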
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch_lightning==1.4.0 2 | torch==1.7.1 3 | torchvision==0.8.2 4 | transformers==4.6.0 5 | Pillow==8.1.0 6 | tqdm==4.56.0 7 | ipdb==0.13.4 8 | numpy==1.19.5 9 | einops==0.3.0 10 | pyarrow 11 | sacred==0.8.2 12 | pandas==1.1.5 13 | # timm==0.4.12 14 | timm==0.3.2 15 | ftfy 16 | pycocoevalcap 17 | pycocotools 18 | webdataset 19 | nltk 20 | huggingface_hub 21 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import json 4 | import torch 5 | import pytorch_lightning as pl 6 | from flm.modules import FLMTransformerSS 7 | from flm.datamodules.multitask_datamodule import MTDataModule 8 | from flm.config import ex 9 | 10 | 11 | def args_checker(config): 12 | if config['enable_flm_aux_lm_loss']: 13 | assert config['loss_names']['flm'] > 0 14 | assert config['flm_backbone'] 15 | assert config['is_causal_mask'] 16 | assert config["hidden_size"] == config["hidden_size_for_fusion"], \ 17 | "only support hidden_size_for_fusion=hidden_size" 18 | 19 | 20 | @ex.automain 21 | def run(_config): 22 | config = copy.deepcopy(_config) 23 | args_checker(config) 24 | # print(os.environ) 25 | world_size = int(os.environ.get('WORLD_SIZE', 1)) 26 | rank = int(os.environ.get('RANK', 0)) 27 | local_rank = int(os.environ.get('LOCAL_RANK', 0)) 28 | nnodes = int(os.environ.get('NNODES', 1)) 29 | config["world_size"] = world_size 30 | config["rank"] = rank 31 | config["nnodes"] = nnodes 32 | config["num_nodes"] = nnodes 33 | config["local_rank"] = local_rank 34 | 35 | device = torch.device(f'cuda:{local_rank}') 36 | torch.cuda.set_device(device) 37 | 38 | pl.seed_everything(config["seed"]) 39 | dm = MTDataModule(config, dist=True) 40 | exp_name = f'{config["exp_name"]}' 41 | 42 | os.makedirs(config["log_dir"], exist_ok=True) 43 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 44 | dirpath=None, # use logger's path 45 | save_top_k=config["ckpt_save_top_k"], 46 | verbose=True, 47 | monitor="val/the_metric", 48 | mode="max", 49 | save_last=True, 50 | filename='epoch_{epoch:0>3d}-step_{step:0>6d}-val_score_{val/the_metric:.3f}', 51 | auto_insert_metric_name=False, 52 | ) 53 | 54 | version = 0 if config['fix_exp_version'] else None 55 | 56 | logger = pl.loggers.TensorBoardLogger( 57 | config["log_dir"], 58 | name=f'{exp_name}_seed{config["seed"]}_from_{config["load_path"].split("/")[-1][:-5]}', 59 | version=version, 60 | ) 61 | config['exp_path'] = logger.root_dir 62 | 63 | lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step") 64 | callbacks = [checkpoint_callback, lr_callback] 65 | 66 | num_gpus = ( 67 | config["num_gpus"] 68 | if isinstance(config["num_gpus"], int) 69 | else len(config["num_gpus"]) 70 | ) 71 | 72 | print(config) 73 | available_batch_size = config["per_gpu_batchsize"] * \ 74 | num_gpus * config["num_nodes"] 75 | grad_steps = max(config["batch_size"] // (available_batch_size), 1) 76 | 77 | max_steps = config["max_steps"] if config["max_steps"] is not None else None 78 | 79 | if local_rank == 0: 80 | # print(os.environ) 81 | print( 82 | f' Node Num: {num_gpus}, Total GPU Numbers: {num_gpus * config["num_nodes"]}') 83 | print( 84 | f' Total Batch Size: {config["batch_size"]}, \ 85 | Available Batch Size: {available_batch_size}, \ 86 | Per GPU Batch 
Size: {config["per_gpu_batchsize"]},\ 87 | Grad Steps: {grad_steps}') 88 | print(f' Resume_from: {config["resume_from"]}') 89 | print(f' Load_path: {config["load_path"]}') 90 | print(' All configs: \n', json.dumps( 91 | _config, sort_keys=True, indent=4, separators=(',', ':'))) 92 | 93 | model = FLMTransformerSS(config) 94 | 95 | trainer = pl.Trainer( 96 | gpus=config["num_gpus"], 97 | num_nodes=config["num_nodes"], 98 | precision=config["precision"], 99 | accelerator="ddp", 100 | benchmark=True, 101 | deterministic=True, 102 | max_epochs=config["max_epoch"] if max_steps is None else 1000, 103 | max_steps=max_steps, 104 | callbacks=callbacks, 105 | logger=logger, 106 | prepare_data_per_node=config["prepare_data_per_node"], 107 | replace_sampler_ddp=False, 108 | accumulate_grad_batches=grad_steps, 109 | log_every_n_steps=100, 110 | flush_logs_every_n_steps=100, 111 | resume_from_checkpoint=config["resume_from"], 112 | weights_summary="top", 113 | fast_dev_run=config["fast_dev_run"], 114 | val_check_interval=config["val_check_interval"], 115 | # progress_bar_refresh_rate= 5 if config['debug'] else 200, 116 | num_sanity_val_steps=config['num_sanity_val_steps'], 117 | ) 118 | 119 | if not config["test_only"]: 120 | trainer.fit(model, datamodule=dm) 121 | else: 122 | trainer.test(model, datamodule=dm) 123 | --------------------------------------------------------------------------------