├── .gitignore
├── .gitmodules
├── README.md
├── data
│   └── coco_caption
│       └── captions_val2014.json
├── flm
│   ├── __init__.py
│   ├── config.py
│   ├── datamodules
│   │   ├── __init__.py
│   │   ├── coco_caption_karpathy_datamodule.py
│   │   ├── conceptual_caption12m_datamodule.py
│   │   ├── conceptual_caption_datamodule.py
│   │   ├── datamodule_base.py
│   │   ├── f30k_caption_karpathy_datamodule.py
│   │   ├── laion100m_datamodule.py
│   │   ├── laion_datamodule.py
│   │   ├── multitask_datamodule.py
│   │   ├── nlvr2_datamodule.py
│   │   ├── sbu_datamodule.py
│   │   ├── snli_datamodule.py
│   │   ├── vg_caption_datamodule.py
│   │   └── vqav2_datamodule.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── base_dataset.py
│   │   ├── coco_caption_karpathy_dataset.py
│   │   ├── conceptual_caption12m_dataset.py
│   │   ├── conceptual_caption_dataset.py
│   │   ├── f30k_caption_karpathy_dataset.py
│   │   ├── laion100m_dataset.py
│   │   ├── laion_dataset.py
│   │   ├── nlvr2_dataset.py
│   │   ├── sbu_caption_dataset.py
│   │   ├── snli_dataset.py
│   │   ├── vg_caption_dataset.py
│   │   └── vqav2_dataset.py
│   ├── gadgets
│   │   ├── __init__.py
│   │   └── my_metrics.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── bert_model.py
│   │   ├── clip_model.py
│   │   ├── dist_utils.py
│   │   ├── flm_module.py
│   │   ├── flm_tools.py
│   │   ├── heads.py
│   │   ├── meter_utils.py
│   │   └── objectives.py
│   ├── transforms
│   │   ├── __init__.py
│   │   ├── randaug.py
│   │   ├── transform.py
│   │   └── utils.py
│   └── utils
│       ├── __init__.py
│       ├── find_newest_ckpt.py
│       ├── glossary.py
│       ├── utils.py
│       ├── whole_word_masking.py
│       ├── write_coco_karpathy.py
│       ├── write_conceptual_caption.py
│       ├── write_conceptual_caption12M_cloud.py
│       ├── write_conceptual_caption_cloud.py
│       ├── write_f30k_karpathy.py
│       ├── write_nlvr2.py
│       ├── write_sbu.py
│       ├── write_snli.py
│       ├── write_vg.py
│       ├── write_vqa.py
│       └── write_winoground.py
├── imgs
│   ├── LMs.png
│   └── pipeline.png
├── requirements.txt
└── run.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "flm/pycocoevalcap"]
2 | path = flm/pycocoevalcap
3 | url = https://github.com/salaniz/pycocoevalcap.git
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FLM
2 | Official code for "Accelerating Vision-Language Pretraining with Free Language Modeling" (CVPR 2023)
3 |
4 | Paper: https://arxiv.org/abs/2303.14038
5 |
6 |
7 | ## Introduction
8 |
9 |
10 | 
11 | State-of-the-art vision-language pretraining (VLP) models achieve exemplary performance but suffer from high training costs caused by slow convergence and long training times, especially on large-scale web datasets. An essential obstacle to training efficiency lies in the entangled prediction rate (the percentage of tokens used for reconstruction) and corruption rate (the percentage of corrupted tokens) in masked language modeling (MLM): a proper corruption rate is achieved at the cost of a large portion of output tokens being excluded from the prediction loss.
12 |
13 | Free language modeling (FLM) is a new language modeling method that enables a 100% prediction rate with an arbitrary corruption rate. FLM decouples the prediction rate from the corruption rate while allowing the corruption spans to be customized for each token to be predicted. Given the same GPU time, FLM-trained models are encouraged to learn better and faster by exploiting bidirectional contexts more flexibly.
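
Below is a toy sketch (ours, not the repository implementation) contrasting the two target schemes: MLM predicts only the corrupted positions, so its prediction rate is tied to the corruption rate, whereas FLM assigns every token a target together with its own corrupted view of the context.

```python
import torch

torch.manual_seed(0)
tokens = torch.arange(12)        # a toy 12-token sentence
corruption_rate = 0.4

# MLM: a single corruption pattern shared by all targets; labels exist only at
# corrupted positions, so prediction rate == corruption rate (~40% here).
corrupted = torch.rand(12) < corruption_rate
mlm_labels = torch.where(corrupted, tokens, torch.full_like(tokens, -100))
print("MLM predicts", int((mlm_labels != -100).sum()), "of", len(tokens), "tokens")

# FLM: every token is a target; target i sees its own corrupted view of the
# context (row i of a visibility mask), so all 12 tokens are predicted
# regardless of the corruption rate.
visibility = torch.rand(12, 12) > corruption_rate
flm_labels = tokens.clone()
print("FLM predicts", len(flm_labels), "of", len(tokens), "tokens")
```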
14 |
15 |
16 |
17 |
18 |
19 | ## Install
20 | ```
21 | pip install -r requirements.txt
22 | ```
23 | ## Dataset Preparation
24 | We follow [ViLT](https://github.com/dandelin/ViLT) and use `pyarrow` to serialize the datasets. See this [link](https://github.com/dandelin/ViLT/blob/master/DATA.md) for details.
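
As a quick sanity check, you can open one of the serialized tables with `pyarrow`, mirroring how `flm/datasets/base_dataset.py` reads them (the file name below follows the ViLT naming convention and is only an example):

```python
import pyarrow as pa

# memory-map the .arrow file and load it as a pyarrow Table
table = pa.ipc.RecordBatchFileReader(
    pa.memory_map("data/coco_caption_karpathy_train.arrow", "r")
).read_all()
print(table.num_rows, table.column_names)   # e.g. image and caption columns
```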
25 |
26 | ## Pretraining
27 | ```bash
28 | export MASTER_ADDR=$DIST_0_IP
29 | export MASTER_PORT=$DIST_0_PORT
30 | export NODE_RANK=$DIST_RANK
31 |
32 | python run.py with data_root= exp_name="pretrain_FLM_4m" \
33 | num_gpus=8 resume_from=None fix_exp_version=True \
34 | flm text_roberta image_size=288 clip32 causal_flm \
35 | precision=16 max_steps=30000 learning_rate=0.00008 \
36 | batch_size=4096 per_gpu_batchsize=64 warmup_steps=0.05
37 | ```
38 | #### Pretrained Checkpoints
39 | FLM-CLIP32-RoBERTa (resolution: 288^2) pre-trained on GCC+SBU+COCO+VG [link](https://github.com/TencentARC/FLM/releases/download/checkpoints/pretrain_4m.ckpt)
40 |
41 | FLM-CLIP32-RoBERTa finetuned on VQAv2 (resolution: 576^2) [link](https://github.com/TencentARC/FLM/releases/download/checkpoints/pretrain_4M_ft_vqa.ckpt)
42 |
43 | FLM-CLIP32-RoBERTa finetuned on NLVR2 (resolution: 288^2) [link](https://github.com/TencentARC/FLM/releases/download/checkpoints/pretrain_4m_ft_nlvr2.ckpt)
44 |
45 | ## Evaluation on Downstream Tasks
46 | #### Visual Question Answering (VQA v2)
47 | ```bash
48 | # training: 4 gpu
49 | python run.py with data_root= exp_name="pretrain_FLM_4m_ft_vqa_train" \
50 | num_gpus=4 resume_from=None fix_exp_version=True load_path="pretrain_4m.ckpt" \
51 | ft_vqa text_roberta image_size=576 clip32 causal_flm \
52 | learning_rate=0.000005 batch_size=512 per_gpu_batchsize=32 log_dir='result_ft' clip_randaug
53 |
54 | # testing: 4 gpu
55 | python run.py with data_root= exp_name="pretrain_FLM_4m_ft_vqa_test" \
56 | num_gpus=4 load_path="pretrain_4M_ft_vqa.ckpt" \
57 | ft_vqa text_roberta image_size=576 clip32 causal_flm \
58 | per_gpu_batchsize=32 log_dir='result_ft' test_only=True skip_test_step=True
59 | ```
60 |
61 | #### Natural Language for Visual Reasoning (NLVR2)
62 | ```bash
63 | # training: 1 gpu
64 | python run.py with data_root= exp_name="pretrain_FLM_4m_ft_nlvr2_train" \
65 | num_gpus=1 resume_from=None fix_exp_version=True load_path="pretrain_4m.ckpt" \
66 | ft_nlvr2 text_roberta image_size=288 clip32 causal_flm \
67 | learning_rate=0.00001 batch_size=256 per_gpu_batchsize=32 log_dir='result_ft' clip_randaug
68 |
69 | # testing: 1 gpu
70 | python run.py with data_root= exp_name="pretrain_FLM_4m_ft_nlvr2_test" \
71 | num_gpus=1 load_path="pretrain_4M_ft_nlvr2.ckpt" \
72 | ft_nlvr2 text_roberta image_size=288 clip32 causal_flm \
73 | per_gpu_batchsize=32 log_dir='result_ft' test_only=True skip_test_step=True
74 | ```
75 |
76 | #### Image Captioning
77 | ```bash
78 | # training: 4 gpu
79 | python run.py with data_root= exp_name="pretrain_FLM_4m_ft_cap_coco_train" \
80 | num_gpus=4 resume_from=None fix_exp_version=True load_path="pretrain_4m.ckpt" \
81 | ft_cap_coco text_roberta image_size=288 clip32 causal_flm \
82 | learning_rate=0.000003 batch_size=256 per_gpu_batchsize=64 log_dir='result_ft' clip_randaug
83 |
84 | # testing: 4 gpu
85 | python run.py with data_root= exp_name="pretrain_FLM_4m_ft_cap_coco_test" \
86 | num_gpus=4 load_path="pretrain_4M_ft_cap.ckpt" \
87 | ft_cap_coco text_roberta image_size=384 clip32 causal_flm \
88 | per_gpu_batchsize=64 log_dir='result_ft' test_only=True skip_test_step=True
89 | ```
90 |
91 | #### Image-Text Retrieval
92 |
93 | ```bash
94 | # training: 8 gpu
95 | python run.py with data_root= exp_name="pretrain_FLM_4m_ft_irtr_f30k_train" \
96 | num_gpus=8 resume_from=None fix_exp_version=True load_path="pretrain_4m.ckpt" \
97 | ft_irtr_f30k text_roberta image_size=384 clip32 causal_flm precision=16 \
98 | learning_rate=0.000005 batch_size=512 per_gpu_batchsize=8 log_dir='result_ft' clip_randaug
99 |
100 | # testing: 8 gpu
101 | python run.py with data_root= exp_name="pretrain_FLM_4m_ft_irtr_f30k_test" \
102 | num_gpus=8 load_path="pretrain_4M_ft_irtr_f30k.ckpt" \
103 | ft_irtr_f30k text_roberta image_size=384 clip32 causal_flm \
104 | per_gpu_batchsize=8 log_dir='result_ft' test_only=True skip_test_step=True
105 | ```
106 |
107 |
108 | ## Citation
109 | ```
110 | @misc{wang2023accelerating,
111 | title={Accelerating Vision-Language Pretraining with Free Language Modeling},
112 | author={Teng Wang and Yixiao Ge and Feng Zheng and Ran Cheng and Ying Shan and Xiaohu Qie and Ping Luo},
113 | year={2023},
114 | eprint={2303.14038},
115 | archivePrefix={arXiv},
116 | primaryClass={cs.CV}
117 | }
118 | ```
119 |
120 | ## Acknowledgements
121 | The code is heavily based on [METER](https://github.com/zdou0830/METER) and [ViLT](https://github.com/dandelin/ViLT).
--------------------------------------------------------------------------------
/flm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/FLM/bd8b19d9f3a00ac6d4e58c9766957032036bffe8/flm/__init__.py
--------------------------------------------------------------------------------
/flm/config.py:
--------------------------------------------------------------------------------
1 | from sacred import Experiment
2 |
3 | ex = Experiment("FLM")
4 |
5 |
6 | def _loss_names(d):
7 | ret = {
8 | "itm": 0,
9 | "mlm": 0, # used for pretraining MLM-based models
10 | "ar": 0, # used for pretraining AR-based models or finetuning on captioning tasks
11 | "flm": 0, # used for pretraining FLM-based models
12 | "vqa": 0,
13 | "nlvr2": 0,
14 | "irtr": 0,
15 | }
16 | ret.update(d)
17 | return ret
18 |
19 |
20 | @ex.config
21 | def config():
22 | only_use_cls_for_flm = False
23 |
24 | debug = False
25 | log_path = ""
26 | is_causal_mask = False
27 |
28 | causal_mask_w_post_cls = False
29 | get_caption_metric = False
30 | get_mlm_caption_metric = False
31 | get_cl_recall_metric = False
32 | get_cl_itm_recall_metric = False
33 |
34 | skip_test_step = False
35 |
36 | flm_backbone = False
37 | temperature = 0.05
38 | random_flm_mask = False
39 | disable_flm_shuffle = False
40 | flm_mask_prob = 0.
41 | text_encoder_from_scratch = False
42 | full_att_mask_for_eval = False
43 | full_att_mask = False
44 | enable_flm_aux_lm_loss = False
45 | flm_aux_lm_loss_l2r_weight = 1.0
46 | flm_aux_lm_loss_r2l_weight = 1.0
47 |
48 | span_corruption_rate = 0
49 |
50 | share_lm_scorer_weights = True
51 |
52 | max_dataset_len = -1
53 |
54 | hidden_size_for_fusion = 768
55 |
56 | caption_prompt = None
57 | add_new_bos_token = False
58 | prepend_bos_token = False
59 | append_eos_token = False
60 |
61 | # webdataset
62 | allow_val_webdataset = False
63 |
64 | # adaptive top bottom layer number for flm
65 | num_reconstructor_bottom_layer = 6
66 | num_reconstructor_top_layer = 6
67 | num_bottom_layer = 6
68 |
69 | # enable_prefix_LM=False
70 | prefix_lm_alpha = 1.0
71 | flm_prediction_rate = 1.0
72 |
73 | # exp name
74 | exp_name = "flm"
75 | seed = 2022
76 | datasets = ["coco", "vg", "sbu", "gcc"]
77 | loss_names = _loss_names({"itm": 1, "mlm": 1})
78 | # hloss_weights = _hloss_weights({'lmcl': 0.1})
79 | # this is the desired total batch size; the PL trainer will accumulate gradients when the per-step batch is smaller.
80 | batch_size = 4096
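# Worked example (illustrative, assuming a single node): with the README
# pretraining setting batch_size=4096, per_gpu_batchsize=64 and num_gpus=8,
# gradients are accumulated over 4096 / (64 * 8) = 8 steps per update.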
81 |
82 | prepare_data_per_node = True
83 | # Image setting
84 | train_transform_keys = ["clip"]
85 | val_transform_keys = ["clip"]
86 | image_size = 224
87 | patch_size = 32
88 | draw_false_image = 1
89 | image_only = False
90 | resolution_before = 224
91 |
92 | # Text Setting
93 | vqav2_label_size = 3129
94 | max_text_len = 50
95 | tokenizer = ".cache/bert-base-uncased"
96 | vocab_size = 30522
97 | whole_word_masking = False # note that whole_word_masking does not work for RoBERTa
98 | mlm_prob = 0.15
99 | draw_false_text = 0
100 |
101 | # Transformer Setting
102 | num_top_layer = 6
103 | input_image_embed_size = 768
104 | input_text_embed_size = 768
105 | vit = 'ViT-B/32'
106 | hidden_size = 768
107 | num_heads = 12
108 | num_heads_fusion = 12
109 | num_layers = 6
110 | mlp_ratio = 4
111 | drop_rate = 0.1
112 | # truncate_bottom_text_encoder_layer = False
113 |
114 | # Optimizer Setting
115 | optim_type = "adamw"
116 | learning_rate = 1e-5
117 | weight_decay = 0.01
118 | decay_power = 1
119 | max_epoch = 100
120 | max_steps = 100000
121 | warmup_steps = 10000
122 | end_lr = 0
123 | lr_mult_head = 5 # multiply lr for downstream heads
124 | lr_mult_cross_modal = 5 # multiply lr for the cross-modal module
125 |
126 | # Downstream Setting
127 | get_recall_metric = False
128 |
129 | # PL Trainer Setting
130 | resume_from = None
131 | fast_dev_run = False
132 | val_check_interval = 0.2
133 | num_sanity_val_steps = 2
134 | test_only = False
135 | ckpt_save_top_k = 1
136 |
137 | # the params below vary with the environment
138 | data_root = ""
139 | log_dir = "result"
140 | per_gpu_batchsize = 0 # you should define this manually with per_gpu_batchsize=#
141 | num_gpus = 8
142 | # num_nodes = 1
143 | load_path = ""
144 | fix_exp_version = False
145 | num_workers = 8
146 | precision = 32
147 |
148 |
149 | @ex.named_config
150 | def causal_flm():
151 | is_causal_mask = True
152 | causal_mask_w_post_cls = True
153 | flm_backbone = True
154 |
155 |
156 | @ex.named_config
157 | def causal_lm():
158 | is_causal_mask = True
159 | causal_mask_w_post_cls = True
160 |
161 |
162 | @ex.named_config
163 | def mlm():
164 | exp_name = "mlm"
165 | # datasets = ["gcc"]
166 | loss_names = _loss_names({"mlm": 1})
167 | batch_size = 4096
168 | max_epoch = 10
169 | max_steps = 100000
170 | warmup_steps = 0.1
171 | whole_word_masking = True
172 |
173 |
174 | @ex.named_config
175 | def ar():
176 | exp_name = "ar"
177 | # datasets = ["gcc"]
178 | loss_names = _loss_names({"ar": 1})
179 | batch_size = 4096
180 | max_epoch = 10
181 | max_steps = 100000
182 | warmup_steps = 0.1
183 | whole_word_masking = True
184 |
185 |
186 | @ex.named_config
187 | def flm():
188 | exp_name = "flm"
189 | # datasets = ["gcc"]
190 | loss_names = _loss_names({"flm": 1})
191 | batch_size = 4096
192 | max_epoch = 10
193 | max_steps = 100000
194 | warmup_steps = 0.1
195 | whole_word_masking = True
196 |
197 | is_causal_mask = True
198 | causal_mask_w_post_cls = True
199 | # disable_cross_modal_image_layer=True
200 | # cross_modal_layer='text_only'
201 | flm_backbone = True
202 | enable_flm_aux_lm_loss = True
203 |
204 |
205 | @ex.named_config
206 | def flm_itm():
207 | exp_name = "flm_itm"
208 | # datasets = ["gcc"]
209 | loss_names = _loss_names({"flm": 1, "itm": 1})
210 | batch_size = 4096
211 | max_epoch = 10
212 | max_steps = 100000
213 | warmup_steps = 0.1
214 | whole_word_masking = True
215 | enable_flm_aux_lm_loss = True
216 |
217 |
218 | @ex.named_config
219 | def ft_nlvr2():
220 | exp_name = "finetune_nlvr2"
221 | datasets = ["nlvr2"]
222 | loss_names = _loss_names({"nlvr2": 1})
223 | batch_size = 256
224 | max_epoch = 10
225 | max_steps = None
226 | warmup_steps = 0.1
227 | draw_false_image = 0
228 | learning_rate = 1e-5
229 | lr_mult_head = 10
230 | lr_mult_cross_modal = 5
231 | tokenizer = ".cache/bert-base-uncased"
232 | max_text_len = 50
233 | input_text_embed_size = 768
234 | vit = 'ViT-B/32'
235 | train_transform_keys = ["clip"]
236 | val_transform_keys = ["clip"]
237 | input_image_embed_size = 768
238 | image_size = 288
239 |
240 |
241 | @ex.named_config
242 | def ft_vqa():
243 | exp_name = "finetune_vqa"
244 | datasets = ["vqa"]
245 | loss_names = _loss_names({"vqa": 1})
246 | batch_size = 512
247 | max_epoch = 10
248 | max_steps = None
249 | warmup_steps = 0.1
250 | draw_false_image = 0
251 | learning_rate = 5e-6
252 | val_check_interval = 0.5
253 | lr_mult_head = 50
254 | lr_mult_cross_modal = 5
255 | tokenizer = ".cache/bert-base-uncased"
256 | max_text_len = 50
257 | input_text_embed_size = 768
258 | vit = 'ViT-B/32'
259 | train_transform_keys = ["clip"]
260 | val_transform_keys = ["clip"]
261 | input_image_embed_size = 768
262 | image_size = 576
263 |
264 |
265 | @ex.named_config
266 | def ft_irtr_coco():
267 | exp_name = "finetune_irtr_coco"
268 | datasets = ["coco"]
269 | loss_names = _loss_names({"itm": 0.5, "irtr": 1})
270 | batch_size = 512
271 | max_epoch = 10
272 | max_steps = None
273 | warmup_steps = 0.1
274 | get_recall_metric = True
275 | draw_false_text = 15
276 | learning_rate = 5e-6
277 | lr_mult_head = 5
278 | lr_mult_cross_modal = 5
279 | tokenizer = ".cache/bert-base-uncased"
280 | input_text_embed_size = 768
281 | vit = 'ViT-B/32'
282 | train_transform_keys = ["clip"]
283 | val_transform_keys = ["clip"]
284 | input_image_embed_size = 768
285 | image_size = 384
286 |
287 |
288 | @ex.named_config
289 | def ft_cap_coco():
290 | exp_name = "finetune_caption_coco"
291 |
292 | loss_names = _loss_names({"ar": 0.5})
293 | batch_size = 256
294 | max_epoch = 20
295 | max_steps = None
296 | warmup_steps = 0.1
297 | get_caption_metric = True
298 | get_mlm_caption_metric = False
299 | get_recall_metric = False
300 | draw_false_text = 0
301 | learning_rate = 3e-5
302 | lr_mult_head = 5
303 | lr_mult_cross_modal = 5
304 | tokenizer = ".cache/bert-base-uncased"
305 | input_text_embed_size = 768
306 | vit = 'ViT-B/32'
307 | train_transform_keys = ["clip"]
308 | val_transform_keys = ["clip"]
309 | input_image_embed_size = 768
310 | image_size = 384
311 |
312 | caption_prompt = ''
313 | add_new_bos_token = True
314 | prepend_bos_token = True
315 | append_eos_token = True
316 | datasets = ["coco"]
317 | per_gpu_batchsize = 64
318 |
319 |
320 | # @ex.named_config
321 | # def add_bos_eos_tokens():
322 | # add_new_bos_token=True
323 | # prepend_bos_token=True
324 | # append_eos_token=True
325 |
326 | @ex.named_config
327 | def zs_irtr_coco():
328 | test_only = True
329 | skip_test_step = True
330 | get_recall_metric = True
331 | get_cl_recall_metric = False
332 |
333 | exp_name = "zs_irtr_coco"
334 | datasets = ["coco"]
335 | loss_names = _loss_names({"itm": 0.5, "irtr": 1})
336 | batch_size = 512
337 | max_epoch = 10
338 | max_steps = None
339 | warmup_steps = 0.1
340 | get_recall_metric = True
341 | draw_false_text = 15
342 | learning_rate = 5e-6
343 | lr_mult_head = 5
344 | lr_mult_cross_modal = 5
345 | tokenizer = ".cache/bert-base-uncased"
346 | input_text_embed_size = 768
347 | vit = 'ViT-B/32'
348 | train_transform_keys = ["clip"]
349 | val_transform_keys = ["clip"]
350 | input_image_embed_size = 768
351 | image_size = 384
352 |
353 |
354 | @ex.named_config
355 | def ft_irtr_f30k():
356 | exp_name = "finetune_irtr_f30k"
357 | datasets = ["f30k"]
358 | loss_names = _loss_names({"itm": 0.5, "irtr": 1})
359 | batch_size = 512
360 | max_epoch = 10
361 | max_steps = None
362 | warmup_steps = 0.1
363 | get_recall_metric = True
364 | draw_false_text = 15
365 | learning_rate = 5e-6
366 | lr_mult_head = 5
367 | lr_mult_cross_modal = 5
368 | tokenizer = ".cache/bert-base-uncased"
369 | input_text_embed_size = 768
370 | vit = 'ViT-B/32'
371 | train_transform_keys = ["clip"]
372 | val_transform_keys = ["clip"]
373 | input_image_embed_size = 768
374 | image_size = 384
375 |
376 |
377 | @ex.named_config
378 | def ft_cl_itm_irtr_f30k():
379 | exp_name = "finetune_irtr_f30k"
380 | datasets = ["f30k"]
381 | loss_names = _loss_names({"itm": 0.5, "irtr": 1, "cl": 1})
382 | batch_size = 512
383 | max_epoch = 10
384 | max_steps = None
385 | warmup_steps = 0.1
386 | get_recall_metric = False
387 | get_cl_itm_recall_metric = True
388 | draw_false_text = 15
389 | learning_rate = 5e-6
390 | lr_mult_head = 5
391 | lr_mult_cross_modal = 5
392 | tokenizer = ".cache/bert-base-uncased"
393 | input_text_embed_size = 768
394 | vit = 'ViT-B/32'
395 | train_transform_keys = ["clip"]
396 | val_transform_keys = ["clip"]
397 | input_image_embed_size = 768
398 | image_size = 384
399 |
400 |
401 | @ex.named_config
402 | def zs_irtr_f30k():
403 | test_only = True
404 | skip_test_step = True
405 | get_recall_metric = True
406 | get_cl_recall_metric = False
407 |
408 | exp_name = "zeroshot_irtr_f30k"
409 | datasets = ["f30k"]
410 | loss_names = _loss_names({"itm": 0.5, "irtr": 1})
411 | batch_size = 512
412 | max_epoch = 10
413 | max_steps = None
414 | warmup_steps = 0.1
415 | get_recall_metric = True
416 | draw_false_text = 15
417 | learning_rate = 5e-6
418 | lr_mult_head = 5
419 | lr_mult_cross_modal = 5
420 | tokenizer = ".cache/bert-base-uncased"
421 | input_text_embed_size = 768
422 | vit = 'ViT-B/32'
423 | train_transform_keys = ["clip"]
424 | val_transform_keys = ["clip"]
425 | input_image_embed_size = 768
426 | image_size = 384
427 |
428 |
429 | @ex.named_config
430 | def ft_cl_irtr_f30k():
431 | exp_name = "finetune_cl_irtr_f30k"
432 | datasets = ["f30k"]
433 | loss_names = _loss_names({"cl": 1.0})
434 | batch_size = 512
435 | max_epoch = 10
436 | max_steps = None
437 | warmup_steps = 0.1
438 | get_recall_metric = False
439 | get_cl_recall_metric = True
440 | draw_false_text = 15
441 | learning_rate = 5e-6
442 | lr_mult_head = 5
443 | lr_mult_cross_modal = 5
444 | tokenizer = ".cache/bert-base-uncased"
445 | input_text_embed_size = 768
446 | vit = 'ViT-B/32'
447 | train_transform_keys = ["clip"]
448 | val_transform_keys = ["clip"]
449 | input_image_embed_size = 768
450 | image_size = 384
451 |
452 |
453 | @ex.named_config
454 | def zs_cl_irtr_f30k():
455 | test_only = True
456 | skip_test_step = True
457 | get_recall_metric = False
458 | get_cl_recall_metric = True
459 |
460 | exp_name = "zs_cl_irtr_f30k"
461 | datasets = ["f30k"]
462 | loss_names = _loss_names({"cl": 1.0})
463 | batch_size = 512
464 | max_epoch = 10
465 | max_steps = None
466 | warmup_steps = 0.1
467 | get_recall_metric = False
468 | get_cl_recall_metric = True
469 | draw_false_text = 15
470 | learning_rate = 5e-6
471 | lr_mult_head = 5
472 | lr_mult_cross_modal = 5
473 | tokenizer = ".cache/bert-base-uncased"
474 | input_text_embed_size = 768
475 | vit = 'ViT-B/32'
476 | train_transform_keys = ["clip"]
477 | val_transform_keys = ["clip"]
478 | input_image_embed_size = 768
479 | image_size = 384
480 |
481 |
482 | @ex.named_config
483 | def zs_cl_irtr_coco():
484 | test_only = True
485 | skip_test_step = True
486 | get_recall_metric = False
487 | get_cl_recall_metric = True
488 |
489 | exp_name = "zs_cl_irtr_coco"
490 | datasets = ["coco"]
491 | loss_names = _loss_names({"cl": 0.5})
492 | batch_size = 512
493 | max_epoch = 10
494 | max_steps = None
495 | warmup_steps = 0.1
496 | get_recall_metric = False
497 | get_cl_recall_metric = True
498 | draw_false_text = 15
499 | learning_rate = 5e-6
500 | lr_mult_head = 5
501 | lr_mult_cross_modal = 5
502 | tokenizer = ".cache/bert-base-uncased"
503 | input_text_embed_size = 768
504 | vit = 'ViT-B/32'
505 | train_transform_keys = ["clip"]
506 | val_transform_keys = ["clip"]
507 | input_image_embed_size = 768
508 | image_size = 384
509 |
510 |
511 | @ex.named_config
512 | def ft_snli_clip_bert():
513 | exp_name = "finetune_snli"
514 | datasets = ["snli"]
515 | loss_names = _loss_names({"snli": 1})
516 | batch_size = 64
517 | max_epoch = 5
518 | max_steps = None
519 | warmup_steps = 0.1
520 | draw_false_image = 0
521 | learning_rate = 2e-6
522 | lr_mult_head = 10
523 | lr_mult_cross_modal = 5
524 | tokenizer = ".cache/bert-base-uncased"
525 | max_text_len = 50
526 | input_text_embed_size = 768
527 | vit = 'ViT-B/32'
528 | train_transform_keys = ["clip"]
529 | val_transform_keys = ["clip"]
530 | input_image_embed_size = 768
531 | image_size = 384
532 |
533 |
534 | # Named configs for "etc", which are orthogonal to "env" and "task"; they need to be added at the end
535 |
536 | # vision encoder
537 | @ex.named_config
538 | def swin32_base224():
539 | vit = "swin_base_patch4_window7_224_in22k"
540 | patch_size = 32
541 | image_size = 224
542 | train_transform_keys = ["imagenet"]
543 | val_transform_keys = ["imagenet"]
544 | input_image_embed_size = 1024
545 | resolution_before = 224
546 |
547 |
548 | @ex.named_config
549 | def swin32_base384():
550 | vit = "swin_base_patch4_window12_384_in22k"
551 | patch_size = 32
552 | image_size = 384
553 | train_transform_keys = ["imagenet"]
554 | val_transform_keys = ["imagenet"]
555 | input_image_embed_size = 1024
556 | resolution_before = 384
557 |
558 |
559 | @ex.named_config
560 | def swin32_large384():
561 | vit = "swin_large_patch4_window12_384_in22k"
562 | patch_size = 32
563 | image_size = 384
564 | train_transform_keys = ["imagenet"]
565 | val_transform_keys = ["imagenet"]
566 | input_image_embed_size = 1536
567 | resolution_before = 384
568 |
569 |
570 | @ex.named_config
571 | def clip32():
572 | vit = 'ViT-B/32'
573 | patch_size = 32
574 | train_transform_keys = ["clip"]
575 | val_transform_keys = ["clip"]
576 | input_image_embed_size = 768
577 |
578 |
579 | @ex.named_config
580 | def clip16():
581 | vit = 'ViT-B/16'
582 | patch_size = 16
583 | train_transform_keys = ["clip"]
584 | val_transform_keys = ["clip"]
585 | input_image_embed_size = 768
586 |
587 |
588 | @ex.named_config
589 | def clip14():
590 | vit = 'ViT-L/14'
591 | patch_size = 14
592 | train_transform_keys = ["clip"]
593 | val_transform_keys = ["clip"]
594 | input_image_embed_size = 1024
595 |
596 |
597 | @ex.named_config
598 | def clip14_336():
599 | vit = 'ViT-L/14@336px'
600 | image_size = 336
601 | patch_size = 14
602 | train_transform_keys = ["clip"]
603 | val_transform_keys = ["clip"]
604 | input_image_embed_size = 1024
605 |
606 |
607 | @ex.named_config
608 | def mae_vit_huge_patch14():
609 | vit = 'mae_vit_huge_patch14'
610 | image_size = 224
611 | patch_size = 14
612 | train_transform_keys = ["mae"]
613 | val_transform_keys = ["mae"]
614 |
615 |
616 | @ex.named_config
617 | def mae_vit_large_patch16():
618 | vit = 'mae_vit_large_patch16'
619 | image_size = 224
620 | patch_size = 16
621 | train_transform_keys = ["mae"]
622 | val_transform_keys = ["mae"]
623 |
624 |
625 | @ex.named_config
626 | def mae_vit_base_patch16():
627 | vit = 'mae_vit_base_patch16'
628 | image_size = 224
629 | patch_size = 16
630 | train_transform_keys = ["mae"]
631 | val_transform_keys = ["mae"]
632 |
633 | # text encoder
634 |
635 |
636 | @ex.named_config
637 | def text_roberta():
638 | tokenizer = ".cache/roberta-base"
639 | vocab_size = 50265
640 | input_text_embed_size = 768
641 |
642 |
643 | # @ex.named_config
644 | # def text_clip():
645 | # tokenizer = ".cache/roberta-base"
646 | # vocab_size = 50265
647 | # input_text_embed_size = 768
648 |
649 | @ex.named_config
650 | def text_roberta_large():
651 | tokenizer = ".cache/roberta-large"
652 | vocab_size = 50265
653 | input_text_embed_size = 1024
654 |
655 |
656 | # random augmentation
657 | @ex.named_config
658 | def imagenet_randaug():
659 | train_transform_keys = ["imagenet_randaug"]
660 |
661 |
662 | @ex.named_config
663 | def clip_randaug():
664 | train_transform_keys = ["clip_randaug"]
665 |
666 |
667 | @ex.named_config
668 | def mae_randaug():
669 | train_transform_keys = ["mae_randaug"]
670 |
--------------------------------------------------------------------------------
/flm/datamodules/__init__.py:
--------------------------------------------------------------------------------
1 | from .vg_caption_datamodule import VisualGenomeCaptionDataModule
2 | from .f30k_caption_karpathy_datamodule import F30KCaptionKarpathyDataModule
3 | from .coco_caption_karpathy_datamodule import CocoCaptionKarpathyDataModule
4 | from .conceptual_caption_datamodule import ConceptualCaptionDataModule
5 | from .sbu_datamodule import SBUCaptionDataModule
6 | from .vqav2_datamodule import VQAv2DataModule
7 | from .nlvr2_datamodule import NLVR2DataModule
8 | from .snli_datamodule import SNLIDataModule
9 | from .conceptual_caption12m_datamodule import ConceptualCaption12mDataModule
10 | # from .conceptual_caption8m_datamodule import ConceptualCaption8mDataModule
11 | from .laion_datamodule import LaionDataModule
12 | from .laion100m_datamodule import Laion100mDataModule
13 | # from .wino_datamodule import WinoDataModule
14 |
15 | _datamodules = {
16 | "vg": VisualGenomeCaptionDataModule,
17 | "f30k": F30KCaptionKarpathyDataModule,
18 | "coco": CocoCaptionKarpathyDataModule,
19 | "gcc": ConceptualCaptionDataModule,
20 | "sbu": SBUCaptionDataModule,
21 | "vqa": VQAv2DataModule,
22 | "nlvr2": NLVR2DataModule,
23 | "snli": SNLIDataModule,
24 | "gcc12m": ConceptualCaption12mDataModule,
25 | # "gcc8m": ConceptualCaption8mDataModule,
26 | "laion": LaionDataModule,
27 | "laion100m": Laion100mDataModule,
28 | # "wino": WinoDataModule
29 | }
30 |
--------------------------------------------------------------------------------
/flm/datamodules/coco_caption_karpathy_datamodule.py:
--------------------------------------------------------------------------------
1 | from ..datasets import CocoCaptionKarpathyDataset
2 | from .datamodule_base import BaseDataModule
3 |
4 |
5 | # COCO Caption datamodule
6 | class CocoCaptionKarpathyDataModule(BaseDataModule):
7 | def __init__(self, *args, **kwargs):
8 | super().__init__(*args, **kwargs)
9 |
10 | @property
11 | def dataset_cls(self):
12 | return CocoCaptionKarpathyDataset
13 |
14 | @property
15 | def dataset_cls_no_false(self):
16 | return CocoCaptionKarpathyDataset
17 |
18 | @property
19 | def dataset_name(self):
20 | return "coco"
21 |
--------------------------------------------------------------------------------
/flm/datamodules/conceptual_caption12m_datamodule.py:
--------------------------------------------------------------------------------
1 | from ..datasets import ConceptualCaption12mDataset
2 | from .datamodule_base import BaseDataModule
3 |
4 |
5 | # Conceptual Caption 12M datamodule
6 | class ConceptualCaption12mDataModule(BaseDataModule):
7 | def __init__(self, *args, **kwargs):
8 | super().__init__(*args, **kwargs)
9 |
10 | @property
11 | def dataset_cls(self):
12 | return ConceptualCaption12mDataset
13 |
14 | @property
15 | def dataset_name(self):
16 | return "gcc"
17 |
--------------------------------------------------------------------------------
/flm/datamodules/conceptual_caption_datamodule.py:
--------------------------------------------------------------------------------
1 | from ..datasets import ConceptualCaptionDataset
2 | from .datamodule_base import BaseDataModule
3 |
4 |
5 | # Conceptual Caption 3M datamodule
6 | class ConceptualCaptionDataModule(BaseDataModule):
7 | def __init__(self, *args, **kwargs):
8 | super().__init__(*args, **kwargs)
9 |
10 | @property
11 | def dataset_cls(self):
12 | return ConceptualCaptionDataset
13 |
14 | @property
15 | def dataset_name(self):
16 | return "gcc"
17 |
--------------------------------------------------------------------------------
/flm/datamodules/datamodule_base.py:
--------------------------------------------------------------------------------
1 | from random import shuffle
2 | import torch
3 | import functools
4 | from pytorch_lightning import LightningDataModule
5 | from torch.utils.data import DataLoader
6 | from transformers import (
7 | DataCollatorForLanguageModeling,
8 | # DataCollatorForWholeWordMask,
9 | BertTokenizer,
10 | RobertaTokenizer,
11 | )
12 |
13 | from flm.utils.whole_word_masking import DataCollatorForWholeWordMask
14 |
15 |
16 | class text_preprocessor():
17 | """prepend or append special tokens"""
18 |
19 | def __init__(self, config) -> None:
20 | self.prepend_bos = config['add_new_bos_token'] and config['prepend_bos_token']
21 | self.append_eos = config['add_new_bos_token'] and config['append_eos_token']
22 |
23 | def __call__(self, text):
24 | text = text.rstrip().rstrip('.').rstrip() + '.'
25 | if self.prepend_bos:
26 | text = '' + ' ' + text
27 | if self.append_eos:
28 | text = text + ' ' + ''
29 | return text
30 |
31 |
32 | def flm_collator(attention_mask, mask_ratio, disable_shuffle=True, label_strategy='none'):
33 | """get flm masks and labels"""
34 | text_len = attention_mask.sum(1)
35 | bs, max_len = attention_mask.size()
36 | flm_masks = -10000. * torch.ones(bs, max_len, max_len)
37 | # attention_mask.unsqueeze(dim=2) * attention_mask.unsqueeze(dim=1)
38 | flm_random_ids = []
39 | mask_num = torch.distributions.Binomial(
40 | text_len.float() - 1, mask_ratio).sample().int()
41 | for i in range(len(text_len)):
42 | flm_random_id = torch.randperm(text_len[i] - 1) + 1
43 | flm_random_id = flm_random_id[:text_len[i] - 1 - mask_num[i]]
44 | if disable_shuffle:
45 | flm_random_id = torch.sort(flm_random_id)[0]
46 | flm_random_ids.append(flm_random_id)
47 | # print(flm_random_id)
48 | for j in range(len(flm_random_id)):
49 | if flm_random_id[j] < 0:
50 | break
51 | else:
52 | flm_masks[i,
53 | flm_random_id[j:j + 1].repeat(j+1),
54 | flm_random_id[:j+1]] = 0
55 |
56 | flm_label = None
57 | if label_strategy == 'none':
58 | pass
59 | else:
60 |
61 | if label_strategy == 'object':
62 | pass
63 | elif label_strategy == 'concrete':
64 | pass
65 | return flm_random_ids, flm_masks, flm_label
66 |
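# Illustrative usage of `flm_collator` above (comment added for clarity, not in
# the original file):
#   att = torch.tensor([[1, 1, 1, 1, 0], [1, 1, 1, 0, 0]])
#   ids, masks, labels = flm_collator(att, mask_ratio=0.3)
# `ids` holds, per sample, the non-corrupted token positions (excluding position
# 0) in the chosen order; `masks` is a (batch, max_len, max_len) additive
# attention bias with 0 where attention is allowed and -10000 where it is
# blocked; `labels` is None for label_strategy='none'.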
67 |
68 | def sep_collator(flatten_encodings, mlm_collator, mask_ratio, pred_corr_ratio):
69 | if pred_corr_ratio > 1:
70 | repeat_num = int(pred_corr_ratio)
71 | group_mlms = [[] for i in range(repeat_num)]
72 | mlms = mlm_collator(flatten_encodings)
73 | # print('mlms', mlms)
74 | for idx, flatten_encoding in enumerate(flatten_encodings):
75 | token_num = len(flatten_encoding['attention_mask'])
76 | chunk_size = token_num // repeat_num + 1
77 | org_input_id = torch.tensor(flatten_encoding['input_ids'])
78 | mlm_input_id = mlms['input_ids'][idx]
79 | mlm_labels = mlms['labels'][idx]
80 | ava_mask_reg = torch.tensor(flatten_encoding['attention_mask']) * (
81 | 1 - torch.tensor(flatten_encoding['special_tokens_mask']))
82 | perm = torch.randperm(token_num)
83 | groups = perm.split(chunk_size)
84 | assert len(groups) == repeat_num
85 | for i in range(repeat_num):
86 | group_mask = torch.zeros(token_num).long()
87 | group_mask[groups[i]] = 1
88 | group_input_id = org_input_id * \
89 | (1-group_mask) + mlm_input_id * group_mask
90 | group_label = -100 * torch.ones(token_num).long()
91 | group_label[group_mask.bool()] = mlm_labels[group_mask.bool()]
92 | group_mlm = {'input_ids': group_input_id,
93 | 'labels': group_label}
94 | group_mlms[i].append(group_mlm)
95 | # print(group_mask)
96 | for i in range(repeat_num):
97 | group_mlms[i] = {'input_ids': torch.stack([_['input_ids'] for _ in group_mlms[i]]),
98 | 'labels': torch.stack([_['labels'] for _ in group_mlms[i]])}
99 | return group_mlms
100 |
101 | elif pred_corr_ratio < 1:
102 | mlms = mlm_collator(flatten_encodings)
103 | group_labels = []
104 | # print('mlms', mlms)
105 | for idx, flatten_encoding in enumerate(flatten_encodings):
106 | token_num = len(flatten_encoding['attention_mask'])
107 | mlm_input_id = mlms['input_ids'][idx]
108 | mlm_labels = mlms['labels'][idx]
109 | perm = torch.randperm(token_num)[:int(token_num * pred_corr_ratio)]
110 | group_label = -100 * torch.ones(token_num).long()
111 | group_label[perm] = mlm_labels[perm]
112 | group_labels.append(group_label)
113 |
114 | group_mlm = {'input_ids': mlms['input_ids'],
115 | 'labels': torch.stack(group_labels, dim=0)}
116 | return group_mlm
117 |
118 |
119 | def get_pretrained_tokenizer(from_pretrained):
120 | if torch.distributed.is_initialized():
121 | if torch.distributed.get_rank() == 0:
122 | if 'roberta' in from_pretrained:
123 | RobertaTokenizer.from_pretrained(from_pretrained)
124 | else:
125 | BertTokenizer.from_pretrained(
126 | from_pretrained, do_lower_case="uncased" in from_pretrained
127 | )
128 | torch.distributed.barrier()
129 |
130 | if 'roberta' in from_pretrained:
131 | return RobertaTokenizer.from_pretrained(from_pretrained)
132 | elif 'gpt2' in from_pretrained:
133 | from transformers import GPT2Tokenizer, GPT2Model
134 | return GPT2Tokenizer.from_pretrained('gpt2')
135 | return BertTokenizer.from_pretrained(
136 | from_pretrained, do_lower_case="uncased" in from_pretrained
137 | )
138 |
139 |
140 | class BaseDataModule(LightningDataModule):
141 | def __init__(self, _config):
142 | super().__init__()
143 | self.data_dir = _config["data_root"]
144 |
145 | self.num_workers = _config["num_workers"]
146 | self.batch_size = _config["per_gpu_batchsize"]
147 | self.eval_batch_size = self.batch_size
148 |
149 | self.image_size = _config["image_size"]
150 | self.max_text_len = _config["max_text_len"]
151 | self.draw_false_image = _config["draw_false_image"]
152 | self.draw_false_text = _config["draw_false_text"]
153 | self.image_only = _config["image_only"]
154 |
155 | self.train_transform_keys = (
156 | ["default_train"]
157 | if len(_config["train_transform_keys"]) == 0
158 | else _config["train_transform_keys"]
159 | )
160 |
161 | self.val_transform_keys = (
162 | ["default_val"]
163 | if len(_config["val_transform_keys"]) == 0
164 | else _config["val_transform_keys"]
165 | )
166 |
167 | tokenizer = _config["tokenizer"]
168 | self.tokenizer = get_pretrained_tokenizer(tokenizer)
169 | if _config['add_new_bos_token']:
170 | self.tokenizer.add_tokens(['', ''])
171 | self.vocab_size = self.tokenizer.vocab_size
172 |
173 | collator = (
174 | DataCollatorForWholeWordMask
175 | if _config["whole_word_masking"]
176 | else DataCollatorForLanguageModeling
177 | )
178 |
179 | self.mlm_collator = {'mlm_collator':
180 | collator(tokenizer=self.tokenizer,
181 | mlm=True,
182 | mlm_probability=_config["mlm_prob"]),
183 | "flm_collator":
184 | functools.partial(
185 | flm_collator,
186 | mask_ratio=_config["flm_mask_prob"],
187 | disable_shuffle=_config["disable_flm_shuffle"]),
188 | }
189 |
190 | self.text_preprocessor = text_preprocessor(_config)
191 | self.setup_flag = False
192 | self.max_dataset_len = _config.get('max_dataset_len', -1)
193 |
194 | @property
195 | def dataset_cls(self):
196 | raise NotImplementedError("return tuple of dataset class")
197 |
198 | @property
199 | def dataset_name(self):
200 | raise NotImplementedError("return name of dataset")
201 |
202 | def set_train_dataset(self):
203 | self.train_dataset = self.dataset_cls(
204 | self.data_dir,
205 | self.train_transform_keys,
206 | split="train",
207 | image_size=self.image_size,
208 | max_text_len=self.max_text_len,
209 | draw_false_image=self.draw_false_image,
210 | draw_false_text=self.draw_false_text,
211 | image_only=self.image_only,
212 | tokenizer=self.tokenizer,
213 | disable_sep_mlm=False,
214 | text_preprocessor=self.text_preprocessor,
215 | max_dataset_len=self.max_dataset_len
216 | )
217 |
218 | def set_val_dataset(self):
219 | self.val_dataset = self.dataset_cls(
220 | self.data_dir,
221 | self.val_transform_keys,
222 | split="val",
223 | image_size=self.image_size,
224 | max_text_len=self.max_text_len,
225 | draw_false_image=self.draw_false_image,
226 | draw_false_text=self.draw_false_text,
227 | image_only=self.image_only,
228 | tokenizer=self.tokenizer,
229 | text_preprocessor=self.text_preprocessor,
230 | max_dataset_len=self.max_dataset_len
231 | )
232 |
233 | if hasattr(self, "dataset_cls_no_false"):
234 | self.val_dataset_no_false = self.dataset_cls_no_false(
235 | self.data_dir,
236 | self.val_transform_keys,
237 | split="val",
238 | image_size=self.image_size,
239 | max_text_len=self.max_text_len,
240 | draw_false_image=0,
241 | draw_false_text=0,
242 | image_only=self.image_only,
243 | tokenizer=self.tokenizer,
244 | text_preprocessor=self.text_preprocessor,
245 | max_dataset_len=self.max_dataset_len
246 | )
247 |
248 | def make_no_false_val_dset(self, image_only=False):
249 | return self.dataset_cls_no_false(
250 | self.data_dir,
251 | self.val_transform_keys,
252 | split="val",
253 | image_size=self.image_size,
254 | max_text_len=self.max_text_len,
255 | draw_false_image=0,
256 | draw_false_text=0,
257 | image_only=image_only,
258 | tokenizer=self.tokenizer,
259 | text_preprocessor=self.text_preprocessor,
260 | max_dataset_len=self.max_dataset_len
261 | )
262 |
263 | def set_test_dataset(self):
264 | self.test_dataset = self.dataset_cls(
265 | self.data_dir,
266 | self.val_transform_keys,
267 | split="test",
268 | image_size=self.image_size,
269 | max_text_len=self.max_text_len,
270 | draw_false_image=self.draw_false_image,
271 | draw_false_text=self.draw_false_text,
272 | image_only=self.image_only,
273 | tokenizer=self.tokenizer,
274 | text_preprocessor=self.text_preprocessor,
275 | max_dataset_len=self.max_dataset_len
276 | )
277 |
278 | def setup(self, stage):
279 | if not self.setup_flag:
280 | self.set_train_dataset()
281 | self.set_val_dataset()
282 | self.set_test_dataset()
283 |
284 | self.train_dataset.tokenizer = self.tokenizer
285 | self.val_dataset.tokenizer = self.tokenizer
286 | self.test_dataset.tokenizer = self.tokenizer
287 |
288 | self.setup_flag = True
289 |
290 | def train_dataloader(self):
291 | loader = DataLoader(
292 | self.train_dataset,
293 | batch_size=self.batch_size,
294 | shuffle=True,
295 | num_workers=self.num_workers,
296 | pin_memory=True,
297 | collate_fn=self.train_dataset.collate,
298 | )
299 | return loader
300 |
301 | def val_dataloader(self):
302 | loader = DataLoader(
303 | self.val_dataset,
304 | batch_size=self.eval_batch_size,
305 | shuffle=False,
306 | num_workers=self.num_workers,
307 | pin_memory=True,
308 | collate_fn=self.val_dataset.collate,
309 | )
310 | return loader
311 |
312 | def test_dataloader(self):
313 | loader = DataLoader(
314 | self.test_dataset,
315 | batch_size=self.eval_batch_size,
316 | shuffle=False,
317 | num_workers=self.num_workers,
318 | pin_memory=True,
319 | collate_fn=self.test_dataset.collate,
320 | )
321 | return loader
322 |
--------------------------------------------------------------------------------
/flm/datamodules/f30k_caption_karpathy_datamodule.py:
--------------------------------------------------------------------------------
1 | from ..datasets import F30KCaptionKarpathyDataset
2 | from .datamodule_base import BaseDataModule
3 | from torch.utils.data import DataLoader
3 |
4 |
5 | # Flickr30K datamodule
6 | class F30KCaptionKarpathyDataModule(BaseDataModule):
7 | def __init__(self, *args, **kwargs):
8 | super().__init__(*args, **kwargs)
9 |
10 | @property
11 | def dataset_cls(self):
12 | return F30KCaptionKarpathyDataset
13 |
14 | @property
15 | def dataset_cls_no_false(self):
16 | return F30KCaptionKarpathyDataset
17 |
18 | @property
19 | def dataset_name(self):
20 | return "f30k"
21 |
22 | def train_dataloader(self):
23 | loader = DataLoader(
24 | self.train_dataset,
25 | batch_size=self.batch_size,
26 | shuffle=True,
27 | num_workers=0,
28 | pin_memory=True,
29 | collate_fn=self.train_dataset.collate,
30 | )
31 | return loader
32 |
33 | def val_dataloader(self):
34 | loader = DataLoader(
35 | self.val_dataset,
36 | batch_size=self.eval_batch_size,
37 | shuffle=False,
38 | num_workers=0,
39 | pin_memory=True,
40 | collate_fn=self.val_dataset.collate,
41 | )
42 | return loader
43 |
44 | def test_dataloader(self):
45 | loader = DataLoader(
46 | self.test_dataset,
47 | batch_size=self.eval_batch_size,
48 | shuffle=False,
49 | num_workers=0,
50 | pin_memory=True,
51 | collate_fn=self.test_dataset.collate,
52 | )
53 | return loader
54 |
--------------------------------------------------------------------------------
/flm/datamodules/laion100m_datamodule.py:
--------------------------------------------------------------------------------
1 | from ..datasets import Laion100mDataset
2 | from .datamodule_base import BaseDataModule
3 |
4 |
5 | # LAION-100M datamodule, a random subset of LAION-400M
6 | class Laion100mDataModule(BaseDataModule):
7 | def __init__(self, *args, **kwargs):
8 | super().__init__(*args, **kwargs)
9 |
10 | @property
11 | def dataset_cls(self):
12 | return Laion100mDataset
13 |
14 | @property
15 | def dataset_name(self):
16 | return "laion"
17 |
--------------------------------------------------------------------------------
/flm/datamodules/laion_datamodule.py:
--------------------------------------------------------------------------------
1 | from ..datasets import LaionDataset
2 | from .datamodule_base import BaseDataModule
3 |
4 |
5 | # LAION-400M datamodule
6 | class LaionDataModule(BaseDataModule):
7 | def __init__(self, *args, **kwargs):
8 | super().__init__(*args, **kwargs)
9 |
10 | @property
11 | def dataset_cls(self):
12 | return LaionDataset
13 |
14 | @property
15 | def dataset_name(self):
16 | return "laion"
17 |
--------------------------------------------------------------------------------
/flm/datamodules/multitask_datamodule.py:
--------------------------------------------------------------------------------
1 | from builtins import hasattr
2 | import functools
3 |
4 | from pytorch_lightning import LightningDataModule
5 | from torch.utils.data import DataLoader
6 | from torch.utils.data.dataset import ConcatDataset
7 | from torch.utils.data.distributed import DistributedSampler
8 |
9 | from . import _datamodules
10 | import webdataset as wds
11 |
12 |
13 | # datamodule for multiple datasets
14 | class MTDataModule(LightningDataModule):
15 | def __init__(self, _config, dist=False):
16 | datamodule_keys = _config["datasets"]
17 | assert len(datamodule_keys) > 0
18 |
19 | super().__init__()
20 |
21 | self.dm_keys = datamodule_keys
22 | self.dm_dicts = {key: _datamodules[key](
23 | _config) for key in datamodule_keys}
24 | self.dms = [v for k, v in self.dm_dicts.items()]
25 |
26 | self.batch_size = self.dms[0].batch_size
27 | self.vocab_size = self.dms[0].vocab_size
28 | self.num_workers = self.dms[0].num_workers
29 |
30 | self.dist = dist
31 | self.allow_val_webdataset = _config['allow_val_webdataset']
32 |
33 | def prepare_data(self):
34 | for dm in self.dms:
35 | dm.prepare_data()
36 |
37 | def setup(self, stage):
38 | def check_webdataset(dataset):
39 | if hasattr(dataset, 'inner_dataset'):
40 | return True
41 |
42 | for dm in self.dms:
43 | dm.setup(stage)
44 |
45 | if check_webdataset(self.dms[0].train_dataset):
46 | assert len(
47 | self.dms) == 1, 'does not support more than one webdataset instance'
48 | self.train_dataset = self.dms[0].train_dataset.inner_dataset
49 | # self.train_dataset.append(wds.batched(self.batch_size))
50 | else:
51 | self.train_dataset = ConcatDataset(
52 | [dm.train_dataset for dm in self.dms])
53 |
54 | if check_webdataset(self.dms[0].val_dataset) and self.allow_val_webdataset:
55 | self.val_dataset = self.dms[0].val_dataset.inner_dataset
56 | # self.val_dataset.append(wds.batched(self.batch_size))
57 | else:
58 | self.val_dataset = ConcatDataset(
59 | [dm.val_dataset for dm in self.dms])
60 |
61 | if check_webdataset(self.dms[0].test_dataset) and self.allow_val_webdataset:
62 | self.test_dataset = self.dms[0].test_dataset.inner_dataset
63 | # self.test_dataset.append(wds.batched(self.batch_size))
64 | else:
65 | self.test_dataset = ConcatDataset(
66 | [dm.test_dataset for dm in self.dms])
67 |
68 | self.tokenizer = self.dms[0].tokenizer
69 |
70 | self.train_collate = functools.partial(
71 | self.dms[0].train_dataset.collate, mlm_collator=self.dms[0].mlm_collator
72 | )
73 | self.val_collate = functools.partial(
74 | self.dms[0].val_dataset.collate, mlm_collator=self.dms[0].mlm_collator
75 | )
76 | self.test_collate = functools.partial(
77 | self.dms[0].test_dataset.collate, mlm_collator=self.dms[0].mlm_collator
78 | )
79 |
80 | if self.dist:
81 | if isinstance(self.train_dataset, wds.DataPipeline):
82 | self.train_sampler = None
83 | else:
84 | self.train_sampler = DistributedSampler(
85 | self.train_dataset, shuffle=True)
86 | if isinstance(self.val_dataset, wds.DataPipeline) and self.allow_val_webdataset:
87 | self.val_sampler = None
88 | else:
89 | self.val_sampler = DistributedSampler(
90 | self.val_dataset, shuffle=True)
91 | if isinstance(self.test_dataset, wds.DataPipeline) and self.allow_val_webdataset:
92 | self.test_sampler = None
93 | else:
94 | self.test_sampler = DistributedSampler(
95 | self.test_dataset, shuffle=False)
96 |
97 | else:
98 | self.train_sampler = None
99 | self.val_sampler = None
100 | self.test_sampler = None
101 |
102 | def train_dataloader(self):
103 | loader = DataLoader(
104 | self.train_dataset,
105 | batch_size=self.batch_size,
106 | sampler=self.train_sampler,
107 | num_workers=self.num_workers,
108 | collate_fn=self.train_collate,
109 | )
110 | return loader
111 |
112 | def val_dataloader(self, batch_size=None):
113 | loader = DataLoader(
114 | self.val_dataset,
115 | batch_size=batch_size if batch_size is not None else self.batch_size,
116 | sampler=self.val_sampler,
117 | num_workers=self.num_workers,
118 | collate_fn=self.val_collate,
119 | )
120 | return loader
121 |
122 | def test_dataloader(self):
123 | loader = DataLoader(
124 | self.test_dataset,
125 | batch_size=self.batch_size,
126 | sampler=self.test_sampler,
127 | num_workers=self.num_workers,
128 | collate_fn=self.test_collate,
129 | )
130 | return loader
131 |
--------------------------------------------------------------------------------
/flm/datamodules/nlvr2_datamodule.py:
--------------------------------------------------------------------------------
1 | from ..datasets import NLVR2Dataset
2 | from .datamodule_base import BaseDataModule
3 |
4 |
5 | # NLVR2 datamodule
6 | class NLVR2DataModule(BaseDataModule):
7 | def __init__(self, *args, **kwargs):
8 | super().__init__(*args, **kwargs)
9 |
10 | @property
11 | def dataset_cls(self):
12 | return NLVR2Dataset
13 |
14 | @property
15 | def dataset_name(self):
16 | return "nlvr2"
17 |
--------------------------------------------------------------------------------
/flm/datamodules/sbu_datamodule.py:
--------------------------------------------------------------------------------
1 | from ..datasets import SBUCaptionDataset
2 | from .datamodule_base import BaseDataModule
3 |
4 |
5 | # SBU Caption datamodule
6 | class SBUCaptionDataModule(BaseDataModule):
7 | def __init__(self, *args, **kwargs):
8 | super().__init__(*args, **kwargs)
9 |
10 | @property
11 | def dataset_cls(self):
12 | return SBUCaptionDataset
13 |
14 | @property
15 | def dataset_name(self):
16 | return "sbu"
17 |
--------------------------------------------------------------------------------
/flm/datamodules/snli_datamodule.py:
--------------------------------------------------------------------------------
1 | from ..datasets import SNLIDataset
2 | from .datamodule_base import BaseDataModule
3 | from collections import defaultdict
4 |
5 |
6 | # SNLI datamodule
7 | class SNLIDataModule(BaseDataModule):
8 | def __init__(self, *args, **kwargs):
9 | super().__init__(*args, **kwargs)
10 |
11 | @property
12 | def dataset_cls(self):
13 | return SNLIDataset
14 |
15 | @property
16 | def dataset_name(self):
17 | return "snli"
18 |
--------------------------------------------------------------------------------
/flm/datamodules/vg_caption_datamodule.py:
--------------------------------------------------------------------------------
1 | from ..datasets import VisualGenomeCaptionDataset
2 | from .datamodule_base import BaseDataModule
3 |
4 |
5 | # VisualGenome datamodule
6 | class VisualGenomeCaptionDataModule(BaseDataModule):
7 | def __init__(self, *args, **kwargs):
8 | super().__init__(*args, **kwargs)
9 |
10 | @property
11 | def dataset_cls(self):
12 | return VisualGenomeCaptionDataset
13 |
14 | @property
15 | def dataset_name(self):
16 | return "vg"
17 |
--------------------------------------------------------------------------------
/flm/datamodules/vqav2_datamodule.py:
--------------------------------------------------------------------------------
1 | from ..datasets import VQAv2Dataset
2 | from .datamodule_base import BaseDataModule
3 | from collections import defaultdict
4 |
5 |
6 | # VQAv2 datamodule
7 | class VQAv2DataModule(BaseDataModule):
8 | def __init__(self, *args, **kwargs):
9 | super().__init__(*args, **kwargs)
10 |
11 | @property
12 | def dataset_cls(self):
13 | return VQAv2Dataset
14 |
15 | @property
16 | def dataset_name(self):
17 | return "vqa"
18 |
19 | def setup(self, stage):
20 | super().setup(stage)
21 |
22 | train_answers = self.train_dataset.table["answers"].to_pandas(
23 | ).tolist()
24 | val_answers = self.val_dataset.table["answers"].to_pandas().tolist()
25 | train_labels = self.train_dataset.table["answer_labels"].to_pandas(
26 | ).tolist()
27 | val_labels = self.val_dataset.table["answer_labels"].to_pandas(
28 | ).tolist()
29 |
30 | all_answers = [c for c in train_answers + val_answers if c is not None]
31 | all_answers = [l for lll in all_answers for ll in lll for l in ll]
32 | all_labels = [c for c in train_labels + val_labels if c is not None]
33 | all_labels = [l for lll in all_labels for ll in lll for l in ll]
34 |
35 | self.answer2id = {k: v for k, v in zip(all_answers, all_labels)}
36 | sorted_a2i = sorted(self.answer2id.items(), key=lambda x: x[1])
37 | self.num_class = max(self.answer2id.values()) + 1
38 |
39 | self.id2answer = defaultdict(lambda: "unknown")
40 | for k, v in sorted_a2i:
41 | self.id2answer[v] = k
42 |
--------------------------------------------------------------------------------
/flm/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .vg_caption_dataset import VisualGenomeCaptionDataset
2 | from .coco_caption_karpathy_dataset import CocoCaptionKarpathyDataset
3 | from .f30k_caption_karpathy_dataset import F30KCaptionKarpathyDataset
4 | from .conceptual_caption_dataset import ConceptualCaptionDataset
5 | from .conceptual_caption12m_dataset import ConceptualCaption12mDataset
6 | from .sbu_caption_dataset import SBUCaptionDataset
7 | from .vqav2_dataset import VQAv2Dataset
8 | from .nlvr2_dataset import NLVR2Dataset
9 | from .snli_dataset import SNLIDataset
10 | from .laion_dataset import LaionDataset
11 | from .laion100m_dataset import Laion100mDataset
12 |
--------------------------------------------------------------------------------
/flm/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import io
4 | import pyarrow as pa
5 | import os
6 | import pdb
7 | from PIL import Image
8 | from ..transforms import keys_to_transforms
9 | import pdb
10 | import copy
11 |
12 |
13 | class BaseDataset(torch.utils.data.Dataset):
14 | def __init__(
15 | self,
16 | data_dir: str,
17 | transform_keys: list,
18 | image_size: int,
19 | names: list,
20 | text_column_name: str = "",
21 | remove_duplicate=True,
22 | max_text_len=40,
23 | max_dataset_len=-1,
24 | draw_false_image=0,
25 | draw_false_text=0,
26 | image_only=False,
27 | tokenizer=None,
28 | disable_sep_mlm=True,
29 | text_preprocessor=None,
30 | ):
31 | """
32 | data_dir : where dataset file *.arrow lives; existence should be guaranteed via DataModule.prepare_data
33 | transform_keys : keys for generating augmented views of images
34 | text_column_name : pyarrow table column name that has list of strings as elements
35 | """
36 | assert len(transform_keys) >= 1
37 | super().__init__()
38 |
39 | self.transforms = keys_to_transforms(transform_keys, size=image_size)
40 | self.clip_transform = False
41 | for transform_key in transform_keys:
42 | if 'clip' in transform_key:
43 | self.clip_transform = True
44 | break
45 | self.text_column_name = text_column_name
46 | self.names = names
47 | self.max_text_len = max_text_len
48 | self.draw_false_image = draw_false_image
49 | self.draw_false_text = draw_false_text
50 | self.image_only = image_only
51 | self.data_dir = data_dir
52 | self.disable_sep_mlm = disable_sep_mlm
53 | self.text_preprocessor = text_preprocessor
54 |
55 | if len(names) != 0:
56 | tables = [
57 | pa.ipc.RecordBatchFileReader(
58 | pa.memory_map(f"{data_dir}/{name}.arrow", "r")
59 | ).read_all()
60 | for name in names
61 | if os.path.isfile(f"{data_dir}/{name}.arrow")
62 | ]
63 | self.table_names = list()
64 | for i, name in enumerate(names):
65 | self.table_names += [name] * len(tables[i])
66 |
67 | if max_dataset_len != -1:
68 | self.table = pa.concat_tables(tables, promote=True)[
69 | :max_dataset_len]
70 | print(' truncated the dataset to length: {}'.format(max_dataset_len))
71 | else:
72 | self.table = pa.concat_tables(tables, promote=True)
73 |
74 | if text_column_name != "":
75 | self.text_column_name = text_column_name
76 | self.all_texts = self.table[text_column_name].to_pandas(
77 | ).tolist()
78 | if type(self.all_texts[0][0]) == str:
79 | if type(self.all_texts[0]) == str:
80 | self.all_texts = [
81 | [self.text_preprocessor(text)] for text in self.all_texts]
82 | else:
83 | self.all_texts = (
84 | [list(set([self.text_preprocessor(text) for text in texts]))
85 | for texts in self.all_texts]
86 | if remove_duplicate
87 | else self.all_texts
88 | )
89 | else: # snli
90 | self.all_texts = (
91 | [[t[1].strip() for t in texts]
92 | for texts in self.all_texts]
93 | )
94 | else:
95 | self.all_texts = list()
96 |
97 | self.index_mapper = dict()
98 | if text_column_name != "" and not self.image_only:
99 | j = 0
100 | for i, texts in enumerate(self.all_texts):
101 | for _j in range(len(texts)):
102 | self.index_mapper[j] = (i, _j)
103 | j += 1
104 | else:
105 | for i in range(len(self.table)):
106 | self.index_mapper[i] = (i, None)
107 | # print(' Dataset length', len(self.index_mapper))
108 |
109 | else:
110 | self.index_mapper = dict()
111 | self.all_texts = list()
112 |
113 | @property
114 | def corpus(self):
115 | return [text for texts in self.all_texts for text in texts]
116 |
117 | def __len__(self):
118 | return len(self.index_mapper)
119 |
120 | def get_raw_image(self, index, image_key="image"):
121 | index, caption_index = self.index_mapper[index]
122 | image_bytes = io.BytesIO(self.table[image_key][index].as_py())
123 | image_bytes.seek(0)
124 | if self.clip_transform:
125 | return Image.open(image_bytes).convert("RGBA")
126 | else:
127 | return Image.open(image_bytes).convert("RGB")
128 |
129 | def get_image(self, index, image_key="image"):
130 | image = self.get_raw_image(index, image_key=image_key)
131 | image_tensor = [tr(image) for tr in self.transforms]
132 | return {
133 | "image": image_tensor,
134 | "img_index": self.index_mapper[index][0],
135 | "cap_index": self.index_mapper[index][1],
136 | "raw_index": index,
137 | }
138 |
139 | def get_false_image(self, rep, image_key="image"):
140 | """get false images for image-text matching loss"""
141 | random_index = random.randint(0, len(self.index_mapper) - 1)
142 | image = self.get_raw_image(random_index, image_key=image_key)
143 | image_tensor = [tr(image) for tr in self.transforms]
144 | return {f"false_image_{rep}": image_tensor}
145 |
146 | def get_text(self, raw_index):
147 | index, caption_index = self.index_mapper[raw_index]
148 |
149 | text = self.all_texts[index][caption_index]
150 | encoding = self.tokenizer(
151 | text,
152 | padding="max_length",
153 | truncation=True,
154 | max_length=self.max_text_len,
155 | return_special_tokens_mask=True,
156 | )
157 | return {
158 | "text": (text, encoding),
159 | "img_index": index,
160 | "cap_index": caption_index,
161 | "raw_index": raw_index,
162 | }
163 |
164 | def get_false_text(self, rep):
165 | """get false text for image-text matching loss"""
166 | random_index = random.randint(0, len(self.index_mapper) - 1)
167 |
168 | index, caption_index = self.index_mapper[random_index]
169 | text = self.all_texts[index][caption_index]
170 | encoding = self.tokenizer(
171 | text,
172 | truncation=True,
173 | max_length=self.max_text_len,
174 | return_special_tokens_mask=True,
175 | )
176 | return {f"false_text_{rep}": (text, encoding)}
177 |
178 | def get_suite(self, index):
179 | result = None
180 | while result is None:
181 | try:
182 | ret = dict()
183 | ret.update(self.get_image(index))
184 | if not self.image_only:
185 | txt = self.get_text(index)
186 | ret.update(
187 | {"replica": True if txt["cap_index"] > 0 else False})
188 | ret.update(txt)
189 |
190 | for i in range(self.draw_false_image):
191 | ret.update(self.get_false_image(i))
192 | for i in range(self.draw_false_text):
193 | ret.update(self.get_false_text(i))
194 | result = True
195 | except Exception as e:
196 | print(
197 | f"Error while read file idx {index} in {self.names[0]} -> {e}")
198 | index = random.randint(0, len(self.index_mapper) - 1)
199 | return ret
200 |
201 | def collate(self, batch, mlm_collator):
202 | batch_size = len(batch)
203 | keys = set([key for b in batch for key in b.keys()])
204 | raw_dict_batch = {
205 | k: [dic[k] if k in dic else None for dic in batch] for k in keys}
206 |
207 | img_keys = [k for k in list(raw_dict_batch.keys()) if "image" in k]
208 | img_sizes = list()
209 |
210 | for img_key in img_keys:
211 | img = raw_dict_batch[img_key]
212 | img_sizes += [ii.shape for i in img if i is not None for ii in i]
213 |
214 | for size in img_sizes:
215 | assert (
216 | len(size) == 3
217 | ), f"Collate error, an image should be in shape of (3, H, W), instead of given {size}"
218 |
219 | if len(img_keys) != 0:
220 | max_height = max([i[1] for i in img_sizes])
221 | max_width = max([i[2] for i in img_sizes])
222 |
223 | for img_key in img_keys:
224 | img = raw_dict_batch[img_key]
225 | view_size = len(img[0])
226 |
227 | new_images = [
228 | torch.zeros(batch_size, 3, max_height, max_width)
229 | for _ in range(view_size)
230 | ]
231 |
232 | for bi in range(batch_size):
233 | orig_batch = img[bi]
234 | for vi in range(view_size):
235 | if orig_batch is None:
236 | new_images[vi][bi] = None
237 | else:
238 | orig = img[bi][vi]
239 | new_images[vi][bi, :, : orig.shape[1],
240 | : orig.shape[2]] = orig
241 |
242 | raw_dict_batch[img_key] = new_images
243 |
244 | txt_keys = [k for k in list(raw_dict_batch.keys()) if "text" in k]
245 |
246 | if len(txt_keys) != 0:
247 | texts = [[d[0] for d in raw_dict_batch[txt_key]]
248 | for txt_key in txt_keys]
249 | encodings = [[d[1] for d in raw_dict_batch[txt_key]]
250 | for txt_key in txt_keys]
251 | flatten_encodings = [e for encoding in encodings for e in encoding]
252 | flatten_mlms = mlm_collator['mlm_collator'](flatten_encodings)
253 | is_sep_mlm = type(
254 | flatten_mlms) == list and not self.disable_sep_mlm
255 | flatten_mlms_all = flatten_mlms if type(
256 | flatten_mlms) == list else [flatten_mlms]
257 |
258 | dict_batch_sep_mlm = {'batch': []}
259 | for flatten_mlms in flatten_mlms_all:
260 | dict_batch = copy.deepcopy(raw_dict_batch)
261 | for i, txt_key in enumerate(txt_keys):
262 | texts, encodings = (
263 | [d[0] for d in dict_batch[txt_key]],
264 | [d[1] for d in dict_batch[txt_key]],
265 | )
266 |
267 | mlm_ids, mlm_labels = (
268 | flatten_mlms["input_ids"][batch_size *
269 | (i): batch_size * (i + 1)],
270 | flatten_mlms["labels"][batch_size *
271 | (i): batch_size * (i + 1)],
272 | )
273 |
274 | input_ids = torch.zeros_like(mlm_ids)
275 | attention_mask = torch.zeros_like(mlm_ids)
276 | for _i, encoding in enumerate(encodings):
277 | _input_ids, _attention_mask = (
278 | torch.tensor(encoding["input_ids"]),
279 | torch.tensor(encoding["attention_mask"]),
280 | )
281 | input_ids[_i, : len(_input_ids)] = _input_ids
282 | attention_mask[_i, : len(
283 | _attention_mask)] = _attention_mask
284 |
285 | lm_labels = input_ids[:, 1:]
286 |
287 | if 'prefixLM_collator' in mlm_collator:
288 | plm_att_mask, prefix_lm_labels = mlm_collator['prefixLM_collator'](
289 | attention_mask, input_ids)
290 | lm_labels = prefix_lm_labels[:, 1:]
291 | dict_batch[f"{txt_key}_prefixlm_masks"] = plm_att_mask
292 |
293 | dict_batch[txt_key] = texts
294 | dict_batch[f"{txt_key}_ids"] = input_ids
295 | dict_batch[f"{txt_key}_labels"] = torch.full_like(
296 | input_ids, -100)
297 | dict_batch[f"{txt_key}_ids_mlm"] = mlm_ids
298 | dict_batch[f"{txt_key}_labels_mlm"] = mlm_labels
299 | dict_batch[f"{txt_key}_labels_lm"] = lm_labels
300 | dict_batch[f"{txt_key}_masks"] = attention_mask
301 | dict_batch.update(self.get_flm_batch(
302 | attention_mask, input_ids, mlm_collator, txt_key))
303 |
304 | dict_batch_sep_mlm['batch'].append(dict_batch)
305 | if not is_sep_mlm:
306 | dict_batch['is_sep_mlm'] = False
307 | return dict_batch
308 | if is_sep_mlm:
309 | dict_batch_sep_mlm['is_sep_mlm'] = True
310 | return dict_batch_sep_mlm
311 | return raw_dict_batch
312 |
313 | def get_flm_batch(self, attention_mask, input_ids, mlm_collator, txt_key):
314 | dict_batch = {}
315 | all_mask_ids = attention_mask * \
316 | self.tokenizer.convert_tokens_to_ids('')
317 | text_len = attention_mask.sum(1)
318 | all_mask_ids[:, 0] = input_ids[:, 0]
319 | all_mask_ids[torch.arange(len(
320 | text_len)), text_len - 1] = input_ids[torch.arange(len(text_len)), text_len - 1]
321 | dict_batch[f"{txt_key}_all_masks_ids"] = all_mask_ids
322 | flm_random_ids, flm_masks, flm_label = mlm_collator['flm_collator'](
323 | attention_mask)
324 | dict_batch[f"{txt_key}_flm_mask_ids"] = flm_random_ids
325 | dict_batch[f"{txt_key}_flm_masks"] = flm_masks
326 | dict_batch[f"{txt_key}_flm_labels"] = flm_label
327 | return dict_batch
328 |
--------------------------------------------------------------------------------
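Below is a minimal, self-contained sketch (not part of the repository) of two mechanics from `BaseDataset` above: how `index_mapper` flattens (image, caption) pairs into a single index space, and how `collate` zero-pads variable-sized image tensors to the largest height/width in the batch. The toy caption table and tensor sizes are made up for illustration.

import torch

# toy caption table: image i has all_texts[i] captions
all_texts = [["a dog", "a brown dog"], ["a cat"]]
index_mapper, j = {}, 0
for i, texts in enumerate(all_texts):
    for _j in range(len(texts)):
        index_mapper[j] = (i, _j)   # flat index -> (image index, caption index)
        j += 1
print(index_mapper)                 # {0: (0, 0), 1: (0, 1), 2: (1, 0)}

# zero-pad images to (3, max_H, max_W), mirroring the padding loop in BaseDataset.collate
imgs = [torch.ones(3, 224, 200), torch.ones(3, 192, 224)]
max_h, max_w = max(t.shape[1] for t in imgs), max(t.shape[2] for t in imgs)
batch = torch.zeros(len(imgs), 3, max_h, max_w)
for bi, t in enumerate(imgs):
    batch[bi, :, : t.shape[1], : t.shape[2]] = t
print(batch.shape)                  # torch.Size([2, 3, 224, 224])
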
/flm/datasets/coco_caption_karpathy_dataset.py:
--------------------------------------------------------------------------------
1 | from .base_dataset import BaseDataset
2 | import io
3 | from PIL import Image
4 |
5 |
6 | # COCO Caption (with Karpathy split) Dataset
7 | class CocoCaptionKarpathyDataset(BaseDataset):
8 | def __init__(self, *args, split="", **kwargs):
9 | assert split in ["train", "val", "test"]
10 | self.split = split
11 |
12 | if split == "train":
13 | names = ["coco_caption_karpathy_train",
14 | "coco_caption_karpathy_restval"]
15 | elif split == "val":
16 | names = ["coco_caption_karpathy_test"]
17 | elif split == "test":
18 | names = ["coco_caption_karpathy_test"]
19 |
20 | super().__init__(*args, **kwargs, names=names, text_column_name="caption")
21 |
22 | def __getitem__(self, index):
23 | suite = self.get_suite(index)
24 |
25 | if "test" in self.split:
26 | _index, _question_index = self.index_mapper[index]
27 | iid = self.table["image_id"][_index].as_py()
28 | iid = int(iid.split(".")[0].split("_")[-1])
29 | suite.update({"iid": iid})
30 |
31 | return suite
32 |
--------------------------------------------------------------------------------
/flm/datasets/conceptual_caption12m_dataset.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from glob import glob
3 | from .base_dataset import BaseDataset
4 | from .conceptual_caption_dataset import ConceptualCaptionDataset
5 | import io
6 | from PIL import Image
7 |
8 |
9 | # Conceptual Caption 12M Dataset
10 | class ConceptualCaption12mDataset(ConceptualCaptionDataset):
11 | def __init__(self, *args, split="", **kwargs):
12 | assert split in ["train", "val", "test"]
13 | if split == "test":
14 | split = "val"
15 |
16 | if split == "train":
17 | names = [f"conceptual_caption12M_train_{i}" for i in range(96)]
18 | elif split == "val":
19 | # names = [f"conceptual_caption_val_{i}" for i in range(1)]
20 | names = []
21 |
22 | super().__init__(*args, **kwargs, names=names, text_column_name="caption")
23 |
--------------------------------------------------------------------------------
/flm/datasets/conceptual_caption_dataset.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from .base_dataset import BaseDataset
3 |
4 |
5 | # Conceptual Caption 3M Dataset
6 | class ConceptualCaptionDataset(BaseDataset):
7 | def __init__(self, *args, split="", **kwargs):
8 | assert split in ["train", "val", "test"]
9 | if split == "test":
10 | split = "val"
11 |
12 | if split == "train":
13 | names = [f"conceptual_caption_train_{i}" for i in range(29)]
14 | elif split == "val":
15 | names = []
16 |
17 | super().__init__(*args, **kwargs, names=names, text_column_name="caption")
18 |
19 | def __getitem__(self, index):
20 | return self.get_suite(index)
21 |
22 | def get_text(self, raw_index):
23 | index, caption_index = self.index_mapper[raw_index]
24 |
25 | text = self.all_texts[index][caption_index]
26 | encoding = self.tokenizer(
27 | text,
28 | padding="max_length",
29 | truncation=True,
30 | max_length=self.max_text_len,
31 | return_special_tokens_mask=True,
32 | )
33 | return {
34 | "text": (text, encoding),
35 | "img_index": index,
36 | "cap_index": caption_index,
37 | "raw_index": raw_index,
38 | }
39 |
--------------------------------------------------------------------------------
/flm/datasets/f30k_caption_karpathy_dataset.py:
--------------------------------------------------------------------------------
1 | from .base_dataset import BaseDataset
2 |
3 |
4 | # Flickr30K Dataset
5 | class F30KCaptionKarpathyDataset(BaseDataset):
6 | def __init__(self, *args, split="", **kwargs):
7 | assert split in ["train", "val", "test"]
8 |
9 | if split == "train":
10 | names = ["f30k_caption_karpathy_train",
11 | "f30k_caption_karpathy_val"]
12 | elif split == "val":
13 | names = ["f30k_caption_karpathy_test"]
14 | elif split == "test":
15 | names = ["f30k_caption_karpathy_test"]
16 |
17 | super().__init__(*args, **kwargs, names=names, text_column_name="caption")
18 |
19 | def __getitem__(self, index):
20 | return self.get_suite(index)
21 |
--------------------------------------------------------------------------------
/flm/datasets/laion100m_dataset.py:
--------------------------------------------------------------------------------
1 | from .base_webdataset import WebDataset
2 | import io
3 | from PIL import Image
4 |
5 |
6 | # A 100M subset of the LAION-400M dataset
7 | class Laion100mDataset(WebDataset):
8 | def __init__(self, *args, split="", **kwargs):
9 | assert split in ["train", "val", "test"]
10 | self.split = split
11 |
12 | if split == 'train':
13 | location = "/group/30042/public_datasets/LAION-400M/raw/data/{00001..10689}.tar"
14 | infinite_loader = False
15 | elif split == "val":
16 | location = '/group/30042/public_datasets/LAION-400M/raw/data/00000.tar'
17 | infinite_loader = False
18 | elif split == 'test':
19 | location = '/group/30042/public_datasets/LAION-400M/raw/data/00000.tar'
20 | infinite_loader = False
21 | super().__init__(*args, **kwargs, infinite_loader=infinite_loader,
22 | location=location, text_column_name="caption")
23 |
--------------------------------------------------------------------------------
/flm/datasets/laion_dataset.py:
--------------------------------------------------------------------------------
1 | from .base_webdataset import WebDataset
2 | import io
3 | from PIL import Image
4 |
5 |
6 | # Laion-400M Dataset
7 | class LaionDataset(WebDataset):
8 | def __init__(self, *args, split="", **kwargs):
9 | assert split in ["train", "val", "test"]
10 | self.split = split
11 |
12 | if split == 'train':
13 | # location = "/group/30042/public_datasets/LAION-400M/raw/data/38872.tar"
14 | # location = "/group/30042/public_datasets/LAION-400M/raw/data/{00000..42757}.tar"
15 | location = "/group/30042/public_datasets/LAION-400M/raw/data/{00001..42757}.tar"
16 | infinite_loader = True
17 | elif split == "val":
18 | location = '/group/30042/public_datasets/LAION-400M/raw/data/00000.tar'
19 | infinite_loader = False
20 | elif split == 'test':
21 | location = '/group/30042/public_datasets/LAION-400M/raw/data/00000.tar'
22 | infinite_loader = False
23 | super().__init__(*args, **kwargs, infinite_loader=infinite_loader,
24 | location=location, text_column_name="caption")
25 |
--------------------------------------------------------------------------------
/flm/datasets/nlvr2_dataset.py:
--------------------------------------------------------------------------------
1 | from .base_dataset import BaseDataset
2 | import sys
3 | import random
4 |
5 |
6 | # NLVR2 Dataset
7 | class NLVR2Dataset(BaseDataset):
8 | def __init__(self, *args, split="", **kwargs):
9 | assert split in ["train", "val", "test"]
10 | self.split = split
11 |
12 | if split == "train":
13 | names = ["nlvr2_train"]
14 | elif split == "val":
15 | names = ["nlvr2_dev", "nlvr2_test1"]
16 | elif split == "test":
17 | names = ["nlvr2_dev", "nlvr2_test1"]
18 |
19 | super().__init__(
20 | *args,
21 | **kwargs,
22 | names=names,
23 | text_column_name="questions",
24 | remove_duplicate=False,
25 | )
26 |
27 | def __getitem__(self, index):
28 | result = None
29 | while result is None:
30 | try:
31 | image_tensor_0 = self.get_image(
32 | index, image_key="image_0")["image"]
33 | image_tensor_1 = self.get_image(
34 | index, image_key="image_1")["image"]
35 | text = self.get_text(index)["text"]
36 | result = True
37 | except:
38 | print(
39 | f"error while read file idx {index} in {self.names[0]}",
40 | file=sys.stderr,
41 | )
42 | index = random.randint(0, len(self.index_mapper) - 1)
43 |
44 | index, question_index = self.index_mapper[index]
45 | answers = self.table["answers"][index][question_index].as_py()
46 | answers = answers == "True"
47 |
48 | return {
49 | "image_0": image_tensor_0,
50 | "image_1": image_tensor_1,
51 | "text": text,
52 | "answers": answers,
53 | "table_name": self.table_names[index],
54 | }
55 |
--------------------------------------------------------------------------------
/flm/datasets/sbu_caption_dataset.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from glob import glob
3 | from .base_dataset import BaseDataset
4 | import io
5 | from PIL import Image
6 |
7 |
8 | # SBU Caption Dataset
9 | class SBUCaptionDataset(BaseDataset):
10 | def __init__(self, *args, split="", **kwargs):
11 | assert split in ["train", "val", "test"]
12 | if split == "test":
13 | split = "val"
14 |
15 | if split == "train":
16 | names = [f"sbu_{i}" for i in range(9)]
17 | elif split == "val":
18 | names = []
19 |
20 | super().__init__(*args, **kwargs, names=names, text_column_name="caption")
21 |
22 | def __getitem__(self, index):
23 | return self.get_suite(index)
24 |
--------------------------------------------------------------------------------
/flm/datasets/snli_dataset.py:
--------------------------------------------------------------------------------
1 | from .base_dataset import BaseDataset
2 |
3 |
4 | # SNLI Dataset
5 | class SNLIDataset(BaseDataset):
6 | def __init__(self, *args, split="", **kwargs):
7 | assert split in ["train", "val", "test"]
8 | self.split = split
9 |
10 | if split == "train":
11 | names = ["snli_train"]
12 | elif split == "val":
13 | names = ["snli_dev", "snli_test"]
14 | elif split == "test":
15 | names = ["snli_dev", "snli_test"]
16 |
17 | super().__init__(
18 | *args,
19 | **kwargs,
20 | names=names,
21 | text_column_name="sentences",
22 | remove_duplicate=False,
23 | )
24 |
25 | def __getitem__(self, index):
26 | image_tensor = self.get_image(index)["image"]
27 | text = self.get_text(index)["text"]
28 |
29 | index, question_index = self.index_mapper[index]
30 |
31 | labels = self.table["labels"][index][question_index].as_py()
32 |
33 | return {
34 | "image": image_tensor,
35 | "text": text,
36 | "labels": labels,
37 | "table_name": self.table_names[index],
38 | }
39 |
--------------------------------------------------------------------------------
/flm/datasets/vg_caption_dataset.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from .base_dataset import BaseDataset
3 | import io
4 | from PIL import Image
5 |
6 |
7 | # Visual Genome Dataset
8 | class VisualGenomeCaptionDataset(BaseDataset):
9 | def __init__(self, *args, split="", **kwargs):
10 | assert split in ["train", "val", "test"]
11 | if split == "test":
12 | split = "val"
13 |
14 | if split == "train":
15 | names = ["vg"]
16 | elif split == "val":
17 | names = []
18 |
19 | super().__init__(*args, **kwargs, names=names, text_column_name="caption")
20 |
21 | def __getitem__(self, index):
22 | return self.get_suite(index)
23 |
--------------------------------------------------------------------------------
/flm/datasets/vqav2_dataset.py:
--------------------------------------------------------------------------------
1 | from .base_dataset import BaseDataset
2 |
3 |
4 | # VQAv2 Dataset
5 | class VQAv2Dataset(BaseDataset):
6 | def __init__(self, *args, split="", **kwargs):
7 | assert split in ["train", "val", "test"]
8 | self.split = split
9 |
10 | if split == "train":
11 | names = ["vqav2_train", "vqav2_val"]
12 | elif split == "val":
13 | names = ["vqav2_val"]
14 | elif split == "test":
15 | names = ["vqav2_test"]
16 |
17 | super().__init__(
18 | *args,
19 | **kwargs,
20 | names=names,
21 | text_column_name="questions",
22 | remove_duplicate=False,
23 | )
24 |
25 | def __getitem__(self, index):
26 | image_tensor = self.get_image(index)["image"]
27 | text = self.get_text(index)["text"]
28 |
29 | index, question_index = self.index_mapper[index]
30 | qid = self.table["question_id"][index][question_index].as_py()
31 |
32 | if self.split != "test":
33 | answers = self.table["answers"][index][question_index].as_py()
34 | labels = self.table["answer_labels"][index][question_index].as_py()
35 | scores = self.table["answer_scores"][index][question_index].as_py()
36 | else:
37 | answers = list()
38 | labels = list()
39 | scores = list()
40 |
41 | return {
42 | "image": image_tensor,
43 | "text": text,
44 | "vqa_answer": answers,
45 | "vqa_labels": labels,
46 | "vqa_scores": scores,
47 | "qid": qid,
48 | }
49 |
--------------------------------------------------------------------------------
/flm/gadgets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/FLM/bd8b19d9f3a00ac6d4e58c9766957032036bffe8/flm/gadgets/__init__.py
--------------------------------------------------------------------------------
/flm/gadgets/my_metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_lightning.metrics import Metric
3 |
4 |
5 | class Accuracy(Metric):
6 | """log the accuracy metric"""
7 |
8 | def __init__(self, dist_sync_on_step=False):
9 | super().__init__(dist_sync_on_step=dist_sync_on_step)
10 | self.add_state("correct", default=torch.tensor(
11 | 0.0), dist_reduce_fx="sum")
12 | self.add_state("total", default=torch.tensor(
13 | 0.0), dist_reduce_fx="sum")
14 |
15 | def update(self, logits, target, ignore_index=-100):
16 | logits, target = (
17 | logits.detach().to(self.correct.device),
18 | target.detach().to(self.correct.device),
19 | )
20 | preds = logits.argmax(dim=-1)
21 | preds = preds[target != ignore_index]
22 | target = target[target != ignore_index]
23 | if target.numel() == 0:
24 | return 1
25 |
26 | assert preds.shape == target.shape
27 |
28 | self.correct += torch.sum(preds == target)
29 | self.total += target.numel()
30 |
31 | def compute(self):
32 | return self.correct / self.total
33 |
34 |
35 | class Scalar(Metric):
36 | def __init__(self, dist_sync_on_step=False):
37 | super().__init__(dist_sync_on_step=dist_sync_on_step)
38 | self.add_state("scalar", default=torch.tensor(
39 | 0.0), dist_reduce_fx="sum")
40 | self.add_state("total", default=torch.tensor(
41 | 0.0), dist_reduce_fx="sum")
42 |
43 | def update(self, scalar):
44 | if isinstance(scalar, torch.Tensor):
45 | scalar = scalar.detach().to(self.scalar.device)
46 | else:
47 | scalar = torch.tensor(scalar).float().to(self.scalar.device)
48 | self.scalar += scalar
49 | self.total += 1
50 |
51 | def compute(self):
52 | return self.scalar / self.total
53 |
54 |
55 | class VQAScore(Metric):
56 | """calculate and log the VQA accuracy"""
57 |
58 | def __init__(self, dist_sync_on_step=False):
59 | super().__init__(dist_sync_on_step=dist_sync_on_step)
60 | self.add_state("score", default=torch.tensor(
61 | 0.0), dist_reduce_fx="sum")
62 | self.add_state("total", default=torch.tensor(
63 | 0.0), dist_reduce_fx="sum")
64 |
65 | def update(self, logits, target):
66 | logits, target = (
67 | logits.detach().float().to(self.score.device),
68 | target.detach().float().to(self.score.device),
69 | )
70 | logits = torch.max(logits, 1)[1]
71 | one_hots = torch.zeros(*target.size()).to(target)
72 | one_hots.scatter_(1, logits.view(-1, 1), 1)
73 | scores = one_hots * target
74 |
75 | self.score += scores.sum()
76 | self.total += len(logits)
77 |
78 | def compute(self):
79 | return self.score / self.total
80 |
--------------------------------------------------------------------------------
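A short sketch of the scoring arithmetic inside `VQAScore.update`, pulled out into plain tensor code (the logits and soft targets below are made-up numbers): the argmax answer is one-hot encoded and multiplied against the soft VQA target, so each example contributes the annotator-agreement score of the answer it picked.

import torch

# fake batch: 2 examples, 4 candidate answers, soft VQA targets in [0, 1]
logits = torch.tensor([[0.1, 2.0, 0.3, 0.0],
                       [1.5, 0.2, 0.1, 0.0]])
target = torch.tensor([[0.0, 0.6, 0.3, 0.0],
                       [0.0, 0.9, 0.0, 0.0]])

pred = torch.max(logits, 1)[1]                 # argmax answer per example
one_hots = torch.zeros(*target.size()).to(target)
one_hots.scatter_(1, pred.view(-1, 1), 1)      # one-hot of the prediction
scores = one_hots * target                     # agreement score of the chosen answer

print(scores.sum() / len(logits))              # tensor(0.3000) == (0.6 + 0.0) / 2
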
/flm/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .flm_module import FLMTransformerSS
2 |
--------------------------------------------------------------------------------
/flm/modules/clip_model.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # CLIP
3 | # Modified from https://github.com/openai/CLIP/blob/main/clip/model.py
4 | # Copyright (c) OpenAI
5 | # ------------------------------------------------------------------------
6 |
7 | import warnings
8 | from tqdm import tqdm
9 | import urllib
10 | import hashlib
11 | import os
12 | from collections import OrderedDict
13 | from typing import Tuple, Union
14 |
15 | import numpy as np
16 | import torch
17 | from torch import nn
18 |
19 |
20 | class LayerNorm(nn.LayerNorm):
21 | """Subclass torch's LayerNorm to handle fp16."""
22 |
23 | def forward(self, x: torch.Tensor):
24 | orig_type = x.dtype
25 | ret = super().forward(x.type(torch.float32))
26 | return ret.type(orig_type)
27 |
28 |
29 | class QuickGELU(nn.Module):
30 | def forward(self, x: torch.Tensor):
31 | return x * torch.sigmoid(1.702 * x)
32 |
33 |
34 | class ResidualAttentionBlock(nn.Module):
35 | def __init__(self, d_model: int,
36 | n_head: int,
37 | attn_mask: torch.Tensor = None):
38 | super().__init__()
39 |
40 | self.attn = nn.MultiheadAttention(d_model, n_head)
41 | self.ln_1 = LayerNorm(d_model)
42 | self.mlp = nn.Sequential(OrderedDict([
43 | ("c_fc", nn.Linear(d_model, d_model * 4)),
44 | ("gelu", QuickGELU()),
45 | ("c_proj", nn.Linear(d_model * 4, d_model))
46 | ]))
47 | self.ln_2 = LayerNorm(d_model)
48 | self.attn_mask = attn_mask
49 |
50 | def attention(self, x: torch.Tensor, x_mask: torch.Tensor):
51 | if x_mask is not None:
52 | x_mask = x_mask.to(dtype=torch.bool, device=x.device)
53 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
54 | if self.attn_mask is not None else None
55 | return self.attn(x, x, x,
56 | need_weights=False,
57 | attn_mask=self.attn_mask,
58 | key_padding_mask=x_mask)[0]
59 |
60 | def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None):
61 | x = x + self.attention(self.ln_1(x), x_mask)
62 | x = x + self.mlp(self.ln_2(x))
63 | return x
64 |
65 |
66 | class Transformer(nn.Module):
67 | def __init__(self, width: int, layers: int,
68 | heads: int, attn_mask: torch.Tensor = None):
69 | super().__init__()
70 | self.width = width
71 | self.layers = layers
72 | self.resblocks = nn.Sequential(
73 | *[ResidualAttentionBlock(width, heads, attn_mask)
74 | for _ in range(layers-1)])
75 |
76 | def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None):
77 | for block in self.resblocks:
78 | x = block(x, x_mask)
79 | return x
80 |
81 |
82 | class VisualTransformer(nn.Module):
83 | def __init__(self, input_resolution: int, patch_size: int, width: int,
84 | layers: int, heads: int, output_dim: int,
85 | resolution_after: int):
86 | super().__init__()
87 | self.input_resolution = input_resolution
88 | self.output_dim = output_dim
89 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width,
90 | kernel_size=patch_size, stride=patch_size,
91 | bias=False)
92 |
93 | scale = width ** -0.5
94 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
95 | self.positional_embedding = nn.Parameter(
96 | scale * torch.randn(
97 | (resolution_after // patch_size) ** 2 + 1, width))
98 | self.ln_pre = LayerNorm(width)
99 |
100 | self.transformer = Transformer(width, layers, heads)
101 | self.ln_post = LayerNorm(width)
102 |
103 | def forward(self, x: torch.Tensor, x_mask):
104 | x = self.conv1(x) # shape = [*, width, grid, grid]
105 | # shape = [*, width, grid ** 2]
106 | x = x.reshape(x.shape[0], x.shape[1], -1)
107 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
108 | t = self.class_embedding.to(
109 | x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype,
110 | device=x.device)
111 | x = torch.cat([t, x], dim=1) # shape = [*, grid ** 2 + 1, width]
112 | x = x + self.positional_embedding.to(x.dtype)
113 | x = self.ln_pre(x)
114 |
115 | x = x.permute(1, 0, 2) # NLD -> LND
116 | x = self.transformer(x, x_mask)
117 | x = x.permute(1, 0, 2) # LND -> NLD
118 |
119 | x = self.ln_post(x)
120 |
121 | return x
122 |
123 |
124 | class CLIP(nn.Module):
125 | def __init__(self,
126 | embed_dim: int,
127 | # vision
128 | image_resolution: int,
129 | vision_layers: Union[Tuple[int, int, int, int], int],
130 | vision_width: int,
131 | vision_patch_size: int,
132 | # text
133 | context_length: int,
134 | vocab_size: int,
135 | transformer_width: int,
136 | transformer_heads: int,
137 | transformer_layers: int,
138 | resolution_after=224,
139 | ):
140 | super().__init__()
141 |
142 | self.context_length = context_length
143 |
144 | vision_heads = vision_width // 64
145 | self.visual = VisualTransformer(
146 | input_resolution=image_resolution,
147 | patch_size=vision_patch_size,
148 | width=vision_width,
149 | layers=vision_layers,
150 | heads=vision_heads,
151 | output_dim=embed_dim,
152 | resolution_after=resolution_after,
153 | )
154 |
155 | self.vocab_size = vocab_size
156 | self.positional_embedding = nn.Parameter(
157 | torch.empty(self.context_length, transformer_width))
158 | self.ln_final = LayerNorm(transformer_width)
159 |
160 | self.initialize_parameters()
161 |
162 | def initialize_parameters(self):
163 | nn.init.normal_(self.positional_embedding, std=0.01)
164 |
165 | proj_std = (self.visual.transformer.width ** -0.5) * \
166 | ((2 * self.visual.transformer.layers) ** -0.5)
167 | attn_std = self.visual.transformer.width ** -0.5
168 | fc_std = (2 * self.visual.transformer.width) ** -0.5
169 | for block in self.visual.transformer.resblocks:
170 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
171 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
172 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
173 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
174 |
175 | @property
176 | def dtype(self):
177 | return self.visual.conv1.weight.dtype
178 |
179 | def forward(self, image, image_mask=None):
180 | return self.visual(image.type(self.dtype), image_mask)
181 |
182 |
183 | _MODELS = {
184 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
185 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
186 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
187 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
188 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
189 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
190 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
191 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
192 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
193 | }
194 |
195 |
196 | def _download(url: str, root: str = os.path.expanduser(".cache/clip")):
197 | os.makedirs(root, exist_ok=True)
198 | filename = os.path.basename(url)
199 |
200 | expected_sha256 = url.split("/")[-2]
201 | download_target = os.path.join(root, filename)
202 |
203 | if os.path.exists(download_target) and not os.path.isfile(download_target):
204 | raise RuntimeError(
205 | f"{download_target} exists and is not a regular file")
206 |
207 | if os.path.isfile(download_target):
208 | if hashlib.sha256(
209 | open(download_target, "rb").read()).hexdigest() \
210 | == expected_sha256:
211 | return download_target
212 | else:
213 | warnings.warn(
214 | f"{download_target} exists, but the SHA256 checksum does not \
215 | match; re-downloading the file")
216 |
217 | with urllib.request.urlopen(url) as source, \
218 | open(download_target, "wb") as output:
219 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80,
220 | unit='iB', unit_scale=True) as loop:
221 | while True:
222 | buffer = source.read(8192)
223 | if not buffer:
224 | break
225 |
226 | output.write(buffer)
227 | loop.update(len(buffer))
228 |
229 | if hashlib.sha256(
230 | open(download_target, "rb").read()).hexdigest() != expected_sha256:
231 | raise RuntimeError(
232 | "Model has been downloaded \
233 | but the SHA256 checksum does not not match")
234 |
235 | return download_target
236 |
237 |
238 | def adapt_position_encoding(model, patch_size=32, after=384,
239 | suffix='visual.positional_embedding'):
240 | keys = [k for k in model if k.endswith(suffix)]
241 | assert len(keys) == 1
242 | key = keys[0]
243 | origin_pos_embed = model[key]
244 | origin_dim2 = False
245 | if len(origin_pos_embed.shape) == 2:
246 | origin_dim2 = True
247 | origin_pos_embed = origin_pos_embed.unsqueeze(0)
248 | grid_before = int(np.sqrt(origin_pos_embed.shape[1] - 1))
249 | before = int(grid_before*patch_size)
250 | assert (before % patch_size) == 0
251 | grid_after = after // patch_size
252 | assert (after % patch_size) == 0
253 | embed_dim = origin_pos_embed.shape[-1]
254 |
255 | pos_embed = origin_pos_embed[0, 1:, :].reshape(
256 | (grid_before, grid_before, embed_dim))
257 | new_size = (grid_after, grid_after)
258 | pos_embed = torch.nn.functional.interpolate(pos_embed.permute(
259 | (2, 0, 1)).unsqueeze(0), size=new_size, mode='bicubic')
260 | pos_embed = pos_embed.squeeze(0).permute(
261 | (1, 2, 0)).reshape((-1, embed_dim))
262 | pos_embed = torch.cat(
263 | (origin_pos_embed[0, 0:1, :], pos_embed), dim=0).unsqueeze(0)
264 | assert pos_embed.shape == (1, grid_after * grid_after + 1, embed_dim)
265 | if origin_dim2:
266 | assert pos_embed.shape[0] == 1
267 | pos_embed = pos_embed.squeeze(0)
268 | model[key] = pos_embed
269 | return model
270 |
271 |
272 | def build_model(name, resolution_after=224):
273 | if name in _MODELS:
274 | model_path = _download(_MODELS[name])
275 | elif os.path.isfile(name):
276 | model_path = name
277 | else:
278 | raise RuntimeError(f"Model {name} not found; \
279 | available models = {available_models()}")
280 | try:
281 | model = torch.jit.load(model_path, map_location="cpu")
282 | state_dict = None
283 | except RuntimeError:
284 |         # not a JIT archive; fall back to loading a plain state dict
285 |         warnings.warn(
286 |             f"File {model_path} is not a JIT archive. \
287 |                 Loading as a state dict instead")
288 |
289 | state_dict = torch.load(model_path, map_location="cpu")
290 | state_dict = state_dict or model.state_dict()
291 |
292 | vision_width = state_dict["visual.conv1.weight"].shape[0]
293 | vision_layers = len([k for k in state_dict.keys() if k.startswith(
294 | "visual.") and k.endswith(".attn.in_proj_weight")])
295 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
296 | grid_size = round(
297 | (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
298 | image_resolution = vision_patch_size * grid_size
299 |
300 | embed_dim = state_dict["text_projection"].shape[1]
301 | context_length = state_dict["positional_embedding"].shape[0]
302 | vocab_size = state_dict["token_embedding.weight"].shape[0]
303 | transformer_width = state_dict["ln_final.weight"].shape[0]
304 | transformer_heads = transformer_width // 64
305 | transformer_layers = len(set(
306 | k.split(".")[2] for k in state_dict
307 | if k.startswith("transformer.resblocks")))
308 |
309 | model = CLIP(
310 | embed_dim,
311 | image_resolution, vision_layers, vision_width, vision_patch_size,
312 | context_length, vocab_size, transformer_width, transformer_heads,
313 | transformer_layers, resolution_after,
314 | )
315 |
316 | for key in ["input_resolution", "context_length", "vocab_size"]:
317 | if key in state_dict:
318 | del state_dict[key]
319 |
320 | model_dict = model.state_dict()
321 | pretrained_dict = state_dict
322 | if resolution_after != image_resolution:
323 | pretrained_dict = adapt_position_encoding(
324 | pretrained_dict,
325 | after=resolution_after,
326 | patch_size=vision_patch_size)
327 | # 1. filter out unnecessary keys
328 | pretrained_dict = {k: v for k,
329 | v in pretrained_dict.items() if k in model_dict}
330 | # 2. overwrite entries in the existing state dict
331 | model_dict.update(pretrained_dict)
332 | # 3. load the new state dict
333 | model.load_state_dict(model_dict)
334 | return model
335 |
--------------------------------------------------------------------------------
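The core of `adapt_position_encoding` above is a bicubic resize of the grid part of the positional embedding. A small standalone sketch of that idea, using assumed ViT-B/32 sizes (patch 32, 224 -> 384 resolution, 768-dim embeddings):

import torch

patch_size, before, after, dim = 32, 224, 384, 768
grid_before, grid_after = before // patch_size, after // patch_size    # 7 -> 12

pos_embed = torch.randn(grid_before * grid_before + 1, dim)            # [CLS] + 7x7 grid
cls_tok, grid = pos_embed[:1], pos_embed[1:]

grid = grid.reshape(grid_before, grid_before, dim).permute(2, 0, 1).unsqueeze(0)
grid = torch.nn.functional.interpolate(
    grid, size=(grid_after, grid_after), mode="bicubic")               # 1 x dim x 12 x 12
grid = grid.squeeze(0).permute(1, 2, 0).reshape(-1, dim)

new_pos_embed = torch.cat((cls_tok, grid), dim=0)
print(new_pos_embed.shape)                                             # torch.Size([145, 768])
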
/flm/modules/dist_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | This file contains primitives for multi-gpu communication.
4 | This is useful when doing distributed training.
5 | """
6 |
7 | import functools
8 | import logging
9 | import numpy as np
10 | import pickle
11 | import torch
12 | import torch.distributed as dist
13 |
14 | import torch
15 |
16 | _LOCAL_PROCESS_GROUP = None
17 | """
18 | A torch process group which only includes processes that are on the same machine as the current process.
19 | This variable is set when processes are spawned by `launch()` in "engine/launch.py".
20 | """
21 |
22 |
23 | def get_world_size() -> int:
24 | if not dist.is_available():
25 | return 1
26 | if not dist.is_initialized():
27 | return 1
28 | return dist.get_world_size()
29 |
30 |
31 | def get_rank() -> int:
32 | if not dist.is_available():
33 | return 0
34 | if not dist.is_initialized():
35 | return 0
36 | return dist.get_rank()
37 |
38 |
39 | def get_local_rank() -> int:
40 | """
41 | Returns:
42 | The rank of the current process within the local (per-machine) process group.
43 | """
44 | if not dist.is_available():
45 | return 0
46 | if not dist.is_initialized():
47 | return 0
48 | assert _LOCAL_PROCESS_GROUP is not None
49 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
50 |
51 |
52 | def get_local_size() -> int:
53 | """
54 | Returns:
55 | The size of the per-machine process group,
56 | i.e. the number of processes per machine.
57 | """
58 | if not dist.is_available():
59 | return 1
60 | if not dist.is_initialized():
61 | return 1
62 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
63 |
64 |
65 | def is_main_process() -> bool:
66 | return get_rank() == 0
67 |
68 |
69 | def synchronize():
70 | """
71 | Helper function to synchronize (barrier) among all processes when
72 | using distributed training
73 | """
74 | if not dist.is_available():
75 | return
76 | if not dist.is_initialized():
77 | return
78 | world_size = dist.get_world_size()
79 | if world_size == 1:
80 | return
81 | dist.barrier()
82 |
83 |
84 | @functools.lru_cache()
85 | def _get_global_gloo_group():
86 | """
87 | Return a process group based on gloo backend, containing all the ranks
88 | The result is cached.
89 | """
90 | if dist.get_backend() == "nccl":
91 | return dist.new_group(backend="gloo")
92 | else:
93 | return dist.group.WORLD
94 |
95 |
96 | def _serialize_to_tensor(data, group):
97 | backend = dist.get_backend(group)
98 | assert backend in ["gloo", "nccl"]
99 | device = torch.device("cpu" if backend == "gloo" else "cuda")
100 |
101 | buffer = pickle.dumps(data)
102 | if len(buffer) > 1024 ** 3:
103 | logger = logging.getLogger(__name__)
104 | logger.warning(
105 | "Rank {} trying to all-gather {:.2f} GB of data on device\
106 | {}".format(get_rank(), len(buffer) / (1024 ** 3), device)
107 | )
108 | storage = torch.ByteStorage.from_buffer(buffer)
109 | tensor = torch.ByteTensor(storage).to(device=device)
110 | return tensor
111 |
112 |
113 | def _pad_to_largest_tensor(tensor, group):
114 | """
115 | Returns:
116 | list[int]: size of the tensor, on each rank
117 | Tensor: padded tensor that has the max size
118 | """
119 | world_size = dist.get_world_size(group=group)
120 | assert (
121 | world_size >= 1
122 | ), "comm.gather/all_gather must be called from ranks within the given group!"
123 | local_size = torch.tensor(
124 | [tensor.numel()], dtype=torch.int64, device=tensor.device)
125 | size_list = [
126 | torch.zeros([1], dtype=torch.int64, device=tensor.device)
127 | for _ in range(world_size)
128 | ]
129 | dist.all_gather(size_list, local_size, group=group)
130 | size_list = [int(size.item()) for size in size_list]
131 |
132 | max_size = max(size_list)
133 |
134 | # we pad the tensor because torch all_gather does not support
135 | # gathering tensors of different shapes
136 | if local_size != max_size:
137 | padding = torch.zeros(
138 | (max_size - local_size,), dtype=torch.uint8, device=tensor.device
139 | )
140 | tensor = torch.cat((tensor, padding), dim=0)
141 | return size_list, tensor
142 |
143 |
144 | def all_gather(data, group=None):
145 | """
146 | Run all_gather on arbitrary picklable data (not necessarily tensors).
147 |
148 | Args:
149 | data: any picklable object
150 | group: a torch process group. By default, will use a group which
151 | contains all ranks on gloo backend.
152 |
153 | Returns:
154 | list[data]: list of data gathered from each rank
155 | """
156 | if get_world_size() == 1:
157 | return [data]
158 | if group is None:
159 | group = _get_global_gloo_group()
160 | if dist.get_world_size(group) == 1:
161 | return [data]
162 |
163 | tensor = _serialize_to_tensor(data, group)
164 |
165 | size_list, tensor = _pad_to_largest_tensor(tensor, group)
166 | max_size = max(size_list)
167 |
168 | # receiving Tensor from all ranks
169 | tensor_list = [
170 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
171 | for _ in size_list
172 | ]
173 | dist.all_gather(tensor_list, tensor, group=group)
174 |
175 | data_list = []
176 | for size, tensor in zip(size_list, tensor_list):
177 | buffer = tensor.cpu().numpy().tobytes()[:size]
178 | data_list.append(pickle.loads(buffer))
179 |
180 | return data_list
181 |
182 |
183 | def gather(data, dst=0, group=None):
184 | """
185 | Run gather on arbitrary picklable data (not necessarily tensors).
186 |
187 | Args:
188 | data: any picklable object
189 | dst (int): destination rank
190 | group: a torch process group. By default, will use a group which
191 | contains all ranks on gloo backend.
192 |
193 | Returns:
194 | list[data]: on dst, a list of data gathered from each rank. Otherwise,
195 | an empty list.
196 | """
197 | if get_world_size() == 1:
198 | return [data]
199 | if group is None:
200 | group = _get_global_gloo_group()
201 | if dist.get_world_size(group=group) == 1:
202 | return [data]
203 | rank = dist.get_rank(group=group)
204 |
205 | tensor = _serialize_to_tensor(data, group)
206 | size_list, tensor = _pad_to_largest_tensor(tensor, group)
207 |
208 | # receiving Tensor from all ranks
209 | if rank == dst:
210 | max_size = max(size_list)
211 | tensor_list = [
212 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
213 | for _ in size_list
214 | ]
215 | dist.gather(tensor, tensor_list, dst=dst, group=group)
216 |
217 | data_list = []
218 | for size, tensor in zip(size_list, tensor_list):
219 | buffer = tensor.cpu().numpy().tobytes()[:size]
220 | data_list.append(pickle.loads(buffer))
221 | return data_list
222 | else:
223 | dist.gather(tensor, [], dst=dst, group=group)
224 | return []
225 |
226 |
227 | def shared_random_seed():
228 | """
229 | Returns:
230 | int: a random number that is the same across all workers.
231 | If workers need a shared RNG, they can use this shared seed to
232 | create one.
233 |
234 | All workers must call this function, otherwise it will deadlock.
235 | """
236 | ints = np.random.randint(2 ** 31)
237 | all_ints = all_gather(ints)
238 | return all_ints[0]
239 |
240 |
241 | def reduce_dict(input_dict, average=True):
242 | """
243 | Reduce the values in the dictionary from all processes so that process with rank
244 | 0 has the reduced results.
245 |
246 | Args:
247 | input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
248 | average (bool): whether to do average or sum
249 |
250 | Returns:
251 | a dict with the same keys as input_dict, after reduction.
252 | """
253 | world_size = get_world_size()
254 | if world_size < 2:
255 | return input_dict
256 | with torch.no_grad():
257 | names = []
258 | values = []
259 | # sort the keys so that they are consistent across processes
260 | for k in sorted(input_dict.keys()):
261 | names.append(k)
262 | values.append(input_dict[k])
263 | values = torch.stack(values, dim=0)
264 | dist.reduce(values, dst=0)
265 | if dist.get_rank() == 0 and average:
266 | # only main process gets accumulated, so only divide by
267 | # world_size in this case
268 | values /= world_size
269 | reduced_dict = {k: v for k, v in zip(names, values)}
270 | return reduced_dict
271 |
--------------------------------------------------------------------------------
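A usage sketch for `all_gather` (an assumption-laden example, not from the repository): it expects `torch.distributed` to already be initialized by the launcher, e.g. PyTorch Lightning; with a single process it simply returns a one-element list, so the snippet also runs without distributed training.

from flm.modules.dist_utils import all_gather, get_rank, get_world_size

# each rank contributes an arbitrary picklable object, e.g. its local eval results
local_results = {"rank": get_rank(), "qids": [1, 2, 3]}

gathered = all_gather(local_results)          # one entry per rank
if get_rank() == 0:
    print(f"collected from {get_world_size()} rank(s):", gathered)
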
/flm/modules/flm_tools.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def get_corr_bi_attention_mask(mask, mask_r, span_corr_rate=0):
6 | """prepare the attention mask in reconstrctor"""
7 | bs, L, M, N = mask.shape
8 | org_bi_mask = torch.cat([mask, mask_r], dim=-1)
9 | bi_mask = org_bi_mask.detach().clone()
10 | bi_mask[:, :, torch.arange(1, N), torch.arange(1, N)] = -10000.
11 | bi_mask[:, :, torch.arange(
12 | 1, N), N + torch.arange(1, N)] = -10000. # [bs, L, L]
13 | text_len = (bi_mask != -10000.).sum(dim=3) + 1
14 | text_len[:, :, 0] = 1
15 |
16 | if span_corr_rate > 0:
17 | add_corr_rate = torch.maximum(torch.zeros_like(
18 | text_len), (text_len * span_corr_rate - 1.)/(text_len - 1 + 1e-5))
19 | mask_num = torch.distributions.Binomial(
20 | text_len.float() - 1, add_corr_rate).sample().int()
21 | start_bias = mask_num // 2 + torch.bernoulli(mask_num/2 - mask_num//2)
22 | angle = torch.arange(0, N, device=mask.device).long()
23 | start = torch.maximum(angle - start_bias.long(), 0*angle)
24 | end = torch.minimum(start + N + mask_num, start.new_tensor(2*N-1))
25 | start_step = angle[None, None].repeat(bs, L, 1) - start
26 | for i in range(torch.max(start_step[:, :, 1:])):
27 | bi_mask[torch.arange(bs).reshape(bs, 1, 1).repeat(1, L, N), torch.arange(L).reshape(1, L, 1).repeat(
28 | bs, 1, N), angle[None, None].repeat(bs, L, 1), torch.minimum(start+i, angle[None, None])] = -10000.
29 |
30 | end_step = end - angle[None, None].repeat(bs, L, 1) - N
31 | for i in range(torch.max(end_step[:, :, 1:])):
32 | bi_mask[torch.arange(bs).reshape(bs, 1, 1).repeat(1, L, N), torch.arange(L).reshape(1, L, 1).repeat(
33 | bs, 1, N), angle[None, None].repeat(bs, L, 1), torch.maximum(end-i, N + angle[None, None])] = -10000.
34 | return torch.cat([org_bi_mask[:, :, :1], bi_mask[:, :, 1:]], dim=2)
35 |
--------------------------------------------------------------------------------
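A shape-only sketch for `get_corr_bi_attention_mask` (all sizes are made up, and the dimension names are guesses from the code): the two input masks use 0. for visible and -10000. for blocked positions, and the returned mask concatenates both halves along the last dimension while additionally blocking every position except position 0 from attending to its own slot in either half.

import torch
from flm.modules.flm_tools import get_corr_bi_attention_mask

bs, L, N = 2, 3, 6                    # batch, number of mask "layers", text length (made-up)
mask = torch.zeros(bs, L, N, N)       # 0. = attend, -10000. = blocked
mask_r = torch.zeros(bs, L, N, N)

bi_mask = get_corr_bi_attention_mask(mask, mask_r, span_corr_rate=0)
print(bi_mask.shape)                  # torch.Size([2, 3, 6, 12])
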
/flm/modules/heads.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import pdb
5 | from transformers.models.bert.modeling_bert import BertPredictionHeadTransform
6 |
7 |
8 | class Pooler(nn.Module):
9 | def __init__(self, hidden_size):
10 | super().__init__()
11 | self.dense = nn.Linear(hidden_size, hidden_size)
12 | self.activation = nn.Tanh()
13 |
14 | def forward(self, hidden_states):
15 | first_token_tensor = hidden_states[:, 0]
16 | pooled_output = self.dense(first_token_tensor)
17 | pooled_output = self.activation(pooled_output)
18 | return pooled_output
19 |
20 |
21 | class ITMHead(nn.Module):
22 | def __init__(self, hidden_size):
23 | super().__init__()
24 | self.fc = nn.Linear(hidden_size, 2)
25 |
26 | def forward(self, x):
27 | x = self.fc(x)
28 | return x
29 |
30 |
31 | class MLMHead(nn.Module):
32 | def __init__(self, config, weight=None):
33 | super().__init__()
34 | self.transform = BertPredictionHeadTransform(config)
35 | self.decoder = nn.Linear(
36 | config.hidden_size, config.vocab_size, bias=False)
37 | self.bias = nn.Parameter(torch.zeros(config.vocab_size))
38 | if weight is not None:
39 | self.decoder.weight = weight
40 |
41 | def forward(self, x):
42 | x = self.transform(x)
43 | x = self.decoder(x) + self.bias
44 | return x
45 |
--------------------------------------------------------------------------------
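A quick sketch of wiring the heads above to a `transformers` BertConfig; the hidden size, sequence length, and vocab size are the standard bert-base values and are used here purely for illustration.

import torch
from transformers import BertConfig
from flm.modules.heads import Pooler, ITMHead, MLMHead

config = BertConfig(hidden_size=768, vocab_size=30522)

hidden = torch.randn(2, 40, 768)        # [batch, seq_len, hidden]
pooled = Pooler(768)(hidden)            # [2, 768] from the first token
itm_logits = ITMHead(768)(pooled)       # [2, 2] image-text matching logits
mlm_logits = MLMHead(config)(hidden)    # [2, 40, 30522] vocabulary logits

print(pooled.shape, itm_logits.shape, mlm_logits.shape)
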
/flm/modules/meter_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 |
4 | from transformers.optimization import AdamW
5 | from transformers import (
6 | get_polynomial_decay_schedule_with_warmup,
7 | get_cosine_schedule_with_warmup,
8 | )
9 | from .dist_utils import all_gather
10 | from .objectives import compute_irtr_recall, compute_caption
11 | from ..gadgets.my_metrics import Accuracy, VQAScore, Scalar
12 |
13 |
14 | def set_metrics(pl_module):
15 | for split in ["train", "val"]:
16 | for k, v in pl_module.hparams.config["loss_names"].items():
17 | if v <= 0:
18 | continue
19 | if k == "vqa":
20 | setattr(pl_module, f"{split}_vqa_score", VQAScore())
21 | setattr(pl_module, f"{split}_{k}_loss", Scalar())
22 | elif k == "nlvr2":
23 | if split == "train":
24 | setattr(pl_module, f"train_{k}_accuracy", Accuracy())
25 | setattr(pl_module, f"train_{k}_loss", Scalar())
26 | else:
27 | setattr(pl_module, f"dev_{k}_accuracy", Accuracy())
28 | setattr(pl_module, f"dev_{k}_loss", Scalar())
29 | setattr(pl_module, f"test_{k}_accuracy", Accuracy())
30 | setattr(pl_module, f"test_{k}_loss", Scalar())
31 | elif k == "irtr":
32 | setattr(pl_module, f"{split}_irtr_loss", Scalar())
33 | elif k == "mppd" or k == "mpfr":
34 | setattr(pl_module, f"{split}_{k}_loss", Scalar())
35 | elif k == "itm":
36 | setattr(pl_module, f"{split}_{k}_accuracy", Accuracy())
37 | setattr(pl_module, f"{split}_{k}_loss", Scalar())
38 | else:
39 | setattr(pl_module, f"{split}_{k}_accuracy", Accuracy())
40 | setattr(pl_module, f"{split}_{k}_loss", Scalar())
41 |
42 | if 'flm' in k and pl_module.hparams.config["enable_flm_aux_lm_loss"]:
43 | setattr(pl_module, f"{split}_flma1_accuracy", Accuracy())
44 | setattr(pl_module, f"{split}_flma2_accuracy", Accuracy())
45 | setattr(pl_module, f"{split}_flma1_loss", Scalar())
46 | setattr(pl_module, f"{split}_flma2_loss", Scalar())
47 |
48 |
49 | def epoch_wrapup(pl_module):
50 | phase = "train" if pl_module.training else "val"
51 | the_metric = 0
52 | if pl_module.hparams.config["get_caption_metric"] and not pl_module.training:
53 | b4, m, c, s = compute_caption(pl_module)
54 | pl_module.logger.experiment.add_scalar(
55 | "caption/b4", b4, pl_module.global_step
56 | )
57 | pl_module.logger.experiment.add_scalar(
58 | "caption/meter", m, pl_module.global_step
59 | )
60 | pl_module.logger.experiment.add_scalar(
61 | "caption/cider", c, pl_module.global_step
62 | )
63 | pl_module.logger.experiment.add_scalar(
64 | "caption/spice", s, pl_module.global_step
65 | )
66 | the_metric += c + m
67 |
68 | # if pl_module.hparams.config["get_mlm_caption_metric"] and not pl_module.training:
69 | # b4, m, c, s = compute_mlm_caption(pl_module)
70 | # pl_module.logger.experiment.add_scalar(
71 | # "caption/b4", b4, pl_module.global_step
72 | # )
73 | # pl_module.logger.experiment.add_scalar(
74 | # "caption/meter", m, pl_module.global_step
75 | # )
76 | # pl_module.logger.experiment.add_scalar(
77 | # "caption/cider", c, pl_module.global_step
78 | # )
79 | # pl_module.logger.experiment.add_scalar(
80 | # "caption/spice", s, pl_module.global_step
81 | # )
82 | # the_metric += c + m
83 |
84 | if pl_module.hparams.config["get_recall_metric"] and not pl_module.training:
85 | (ir_r1, ir_r5, ir_r10, tr_r1, tr_r5,
86 | tr_r10) = compute_irtr_recall(pl_module)
87 | print((ir_r1, ir_r5, ir_r10, tr_r1, tr_r5, tr_r10), pl_module.global_step)
88 | pl_module.logger.experiment.add_scalar(
89 | "recalls/ir_r1", ir_r1, pl_module.global_step
90 | )
91 | pl_module.logger.experiment.add_scalar(
92 | "recalls/ir_r5", ir_r5, pl_module.global_step
93 | )
94 | pl_module.logger.experiment.add_scalar(
95 | "recalls/ir_r10", ir_r10, pl_module.global_step
96 | )
97 | pl_module.logger.experiment.add_scalar(
98 | "recalls/tr_r1", tr_r1, pl_module.global_step
99 | )
100 | pl_module.logger.experiment.add_scalar(
101 | "recalls/tr_r5", tr_r5, pl_module.global_step
102 | )
103 | pl_module.logger.experiment.add_scalar(
104 | "recalls/tr_r10", tr_r10, pl_module.global_step
105 | )
106 | the_metric += ir_r1.item() + tr_r1.item()
107 |
108 | for loss_name, v in pl_module.hparams.config["loss_names"].items():
109 | if v <= 0:
110 | continue
111 |
112 | value = 0
113 |
114 | if loss_name == "vqa":
115 | value = getattr(pl_module, f"{phase}_{loss_name}_score").compute()
116 | pl_module.log(f"{loss_name}/{phase}/score_epoch", value)
117 | getattr(pl_module, f"{phase}_{loss_name}_score").reset()
118 | pl_module.log(
119 | f"{loss_name}/{phase}/loss_epoch",
120 | getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
121 | )
122 | getattr(pl_module, f"{phase}_{loss_name}_loss").reset()
123 | elif loss_name == "nlvr2" or loss_name == 'snli':
124 | if phase == "train":
125 | value = getattr(
126 | pl_module, f"train_{loss_name}_accuracy").compute()
127 | pl_module.log(f"{loss_name}/train/accuracy_epoch", value)
128 | getattr(pl_module, f"train_{loss_name}_accuracy").reset()
129 | pl_module.log(
130 | f"{loss_name}/train/loss_epoch",
131 | getattr(pl_module, f"train_{loss_name}_loss").compute(),
132 | )
133 | getattr(pl_module, f"train_{loss_name}_loss").reset()
134 | else:
135 | value = getattr(
136 | pl_module, f"test_{loss_name}_accuracy").compute()
137 | pl_module.log(f"{loss_name}/test/accuracy_epoch", value)
138 | getattr(pl_module, f"test_{loss_name}_accuracy").reset()
139 | pl_module.log(
140 | f"{loss_name}/test/loss_epoch",
141 | getattr(pl_module, f"test_{loss_name}_loss").compute(),
142 | )
143 | getattr(pl_module, f"test_{loss_name}_loss").reset()
144 |
145 | value = getattr(
146 | pl_module, f"dev_{loss_name}_accuracy").compute()
147 | pl_module.log(f"{loss_name}/dev/accuracy_epoch", value)
148 | getattr(pl_module, f"dev_{loss_name}_accuracy").reset()
149 | pl_module.log(
150 | f"{loss_name}/dev/loss_epoch",
151 | getattr(pl_module, f"dev_{loss_name}_loss").compute(),
152 | )
153 | getattr(pl_module, f"dev_{loss_name}_loss").reset()
154 | elif loss_name == 'wino':
155 | if phase == 'train':
156 | pass
157 | else:
158 | value = getattr(
159 | pl_module, f"test_{loss_name}_accuracy_img").compute()
160 | value_text = getattr(
161 | pl_module, f"test_{loss_name}_accuracy_text").compute()
162 | pl_module.log(f"{loss_name}/test/accuracy_img_epoch", value)
163 | pl_module.log(
164 | f"{loss_name}/test/accuracy_text_epoch", value_text)
165 | getattr(pl_module, f"test_{loss_name}_accuracy_img").reset()
166 | getattr(pl_module, f"test_{loss_name}_accuracy_text").reset()
167 |
168 | elif loss_name == "irtr":
169 | pl_module.log(
170 | f"{loss_name}/{phase}/irtr_loss_epoch",
171 | getattr(pl_module, f"{phase}_irtr_loss").compute(),
172 | )
173 | getattr(pl_module, f"{phase}_irtr_loss").reset()
174 | elif loss_name == "mppd" or loss_name == "mpfr":
175 | pl_module.log(
176 | f"{loss_name}/{phase}/loss_epoch",
177 | getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
178 | )
179 | getattr(pl_module, f"{phase}_{loss_name}_loss").reset()
180 | elif loss_name == "itm":
181 | value = getattr(
182 | pl_module, f"{phase}_{loss_name}_accuracy").compute()
183 | pl_module.log(f"{loss_name}/{phase}/accuracy_epoch", value)
184 | getattr(pl_module, f"{phase}_{loss_name}_accuracy").reset()
185 | pl_module.log(
186 | f"{loss_name}/{phase}/loss_epoch",
187 | getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
188 | )
189 | getattr(pl_module, f"{phase}_{loss_name}_loss").reset()
190 | else:
191 | value = getattr(
192 | pl_module, f"{phase}_{loss_name}_accuracy").compute()
193 | pl_module.log(f"{loss_name}/{phase}/accuracy_epoch", value)
194 | getattr(pl_module, f"{phase}_{loss_name}_accuracy").reset()
195 | pl_module.log(
196 | f"{loss_name}/{phase}/loss_epoch",
197 | getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
198 | )
199 | getattr(pl_module, f"{phase}_{loss_name}_loss").reset()
200 |
201 | the_metric += value
202 |
203 | pl_module.log(f"{phase}/the_metric", the_metric)
204 |
205 |
206 | def check_non_acc_grad(pl_module):
207 | if pl_module.token_type_embeddings.weight.grad is None:
208 | return True
209 | else:
210 | grad = pl_module.token_type_embeddings.weight.grad
211 | return (grad.sum() == 0).item()
212 |
213 |
214 | def set_task(pl_module):
215 | pl_module.current_tasks = [
216 | k for k, v in pl_module.hparams.config["loss_names"].items() if v > 0
217 | ]
218 | return
219 |
220 |
221 | def get_grouped_parameters(pl_module, no_decay, head_names, cross_modal_names,
222 | wd, lr, lr_mult_head, lr_mult_cross_modal):
223 | optimizer_grouped_parameters = [
224 | {
225 | "params": [
226 | p
227 | for n, p in pl_module.named_parameters()
228 | if not any(nd in n for nd in no_decay)
229 | and not any(bb in n for bb in head_names)
230 | and not any(ht in n for ht in cross_modal_names)
231 | ],
232 | "weight_decay": wd,
233 | "lr": lr,
234 | },
235 | {
236 | "params": [
237 | p
238 | for n, p in pl_module.named_parameters()
239 | if any(nd in n for nd in no_decay)
240 | and not any(bb in n for bb in head_names)
241 | and not any(ht in n for ht in cross_modal_names)
242 | ],
243 | "weight_decay": 0.0,
244 | "lr": lr,
245 | },
246 | {
247 | "params": [
248 | p
249 | for n, p in pl_module.named_parameters()
250 | if not any(nd in n for nd in no_decay)
251 | and any(bb in n for bb in head_names)
252 | and not any(ht in n for ht in cross_modal_names)
253 | ],
254 | "weight_decay": wd,
255 | "lr": lr * lr_mult_head,
256 | },
257 | {
258 | "params": [
259 | p
260 | for n, p in pl_module.named_parameters()
261 | if any(nd in n for nd in no_decay) and any(bb in n for bb in head_names)
262 | and not any(ht in n for ht in cross_modal_names)
263 | ],
264 | "weight_decay": 0.0,
265 | "lr": lr * lr_mult_head,
266 | },
267 | {
268 | "params": [
269 | p
270 | for n, p in pl_module.named_parameters()
271 | if not any(nd in n for nd in no_decay)
272 | and not any(bb in n for bb in head_names)
273 | and any(ht in n for ht in cross_modal_names)
274 | ],
275 | "weight_decay": wd,
276 | "lr": lr * lr_mult_cross_modal,
277 | },
278 | {
279 | "params": [
280 | p
281 | for n, p in pl_module.named_parameters()
282 | if any(nd in n for nd in no_decay)
283 | and not any(bb in n for bb in head_names)
284 | and any(ht in n for ht in cross_modal_names)
285 | ],
286 | "weight_decay": 0.0,
287 | "lr": lr * lr_mult_cross_modal,
288 | },
289 | ]
290 | return optimizer_grouped_parameters
291 |
292 |
293 | def set_schedule(pl_module):
294 | lr = pl_module.hparams.config["learning_rate"]
295 | wd = pl_module.hparams.config["weight_decay"]
296 |
297 | no_decay = [
298 | "bias",
299 | "LayerNorm.bias",
300 | "LayerNorm.weight",
301 | "norm.bias",
302 | "norm.weight",
303 | "norm1.bias",
304 | "norm1.weight",
305 | "norm2.bias",
306 | "norm2.weight",
307 | ]
308 | head_names = ["vqa_classifier", "nlvr2_classifier", "mlm_score", "itm_score",
309 | "snli_classifier", "lm_score", "flm_score", "cl_image", "cl_text"]
310 | cross_modal_names = ['cross_modal', 'fusion_layers']
311 | lr_mult_head = pl_module.hparams.config["lr_mult_head"]
312 | lr_mult_cross_modal = pl_module.hparams.config["lr_mult_cross_modal"]
313 | end_lr = pl_module.hparams.config["end_lr"]
314 | decay_power = pl_module.hparams.config["decay_power"]
315 | optim_type = pl_module.hparams.config["optim_type"]
316 |
317 | optimizer_grouped_parameters = get_grouped_parameters(
318 | pl_module, no_decay, head_names, cross_modal_names, wd, lr, lr_mult_head, lr_mult_cross_modal)
319 | if optim_type == "adamw":
320 | optimizer = AdamW(
321 | optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.98)
322 | )
323 | elif optim_type == "adam":
324 | optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=lr)
325 | elif optim_type == "sgd":
326 | optimizer = torch.optim.SGD(
327 | optimizer_grouped_parameters, lr=lr, momentum=0.9)
328 |
329 | if pl_module.trainer.max_steps is None:
330 | max_steps = (
331 | len(pl_module.trainer.datamodule.train_dataloader())
332 | * pl_module.trainer.max_epochs
333 | // pl_module.trainer.accumulate_grad_batches
334 | )
335 | else:
336 | max_steps = pl_module.trainer.max_steps
337 |
338 | warmup_steps = pl_module.hparams.config["warmup_steps"]
339 | if isinstance(pl_module.hparams.config["warmup_steps"], float):
340 | warmup_steps = int(max_steps * warmup_steps)
341 |
342 | if decay_power == "cosine":
343 | scheduler = get_cosine_schedule_with_warmup(
344 | optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps,
345 | )
346 | else:
347 | scheduler = get_polynomial_decay_schedule_with_warmup(
348 | optimizer,
349 | num_warmup_steps=warmup_steps,
350 | num_training_steps=max_steps,
351 | lr_end=end_lr,
352 | power=decay_power,
353 | )
354 |
355 | sched = {"scheduler": scheduler, "interval": "step"}
356 |
357 | return (
358 | [optimizer],
359 | [sched],
360 | )
361 |
--------------------------------------------------------------------------------
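The two helpers above split parameters into six AdamW groups (decay vs. no-decay, crossed with base, head, and cross-modal learning-rate multipliers) and attach a warmup schedule stepped per optimizer step. Below is a minimal sketch of the same grouping idea on a toy module; the `backbone`/`head` names and the hyperparameter values are illustrative, not taken from the repository:

import torch
from transformers import get_cosine_schedule_with_warmup

model = torch.nn.ModuleDict({
    "backbone": torch.nn.Linear(8, 8),
    "head": torch.nn.Linear(8, 2),
})
no_decay, lr, wd, lr_mult_head = ["bias"], 1e-4, 0.01, 5.0

groups = [
    # base weights: full weight decay, base learning rate
    {"params": [p for n, p in model.named_parameters()
                if "head" not in n and not any(k in n for k in no_decay)],
     "weight_decay": wd, "lr": lr},
    # base biases/norms: no weight decay
    {"params": [p for n, p in model.named_parameters()
                if "head" not in n and any(k in n for k in no_decay)],
     "weight_decay": 0.0, "lr": lr},
    # task-head parameters: larger learning rate via lr_mult_head
    {"params": [p for n, p in model.named_parameters() if "head" in n],
     "weight_decay": wd, "lr": lr * lr_mult_head},
]
optimizer = torch.optim.AdamW(groups, lr=lr, eps=1e-8, betas=(0.9, 0.98))
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=10_000)
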
/flm/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | from .transform import (
2 | pixelbert_transform,
3 | pixelbert_transform_randaug,
4 | vit_transform,
5 | vit_transform_randaug,
6 | imagenet_transform,
7 | imagenet_transform_randaug,
8 | clip_transform,
9 | clip_transform_randaug,
10 | mae_transform_randaug,
11 | mae_transform,
12 | )
13 |
14 | _transforms = {
15 | "pixelbert": pixelbert_transform,
16 | "pixelbert_randaug": pixelbert_transform_randaug,
17 | "vit": vit_transform,
18 | "vit_randaug": vit_transform_randaug,
19 | "imagenet": imagenet_transform,
20 | "imagenet_randaug": imagenet_transform_randaug,
21 | "clip": clip_transform,
22 | "clip_randaug": clip_transform_randaug,
23 | 'mae_randaug': mae_transform_randaug,
24 | 'mae': mae_transform,
25 | }
26 |
27 |
28 | def keys_to_transforms(keys: list, size=224):
29 | return [_transforms[key](size=size) for key in keys]
30 |
--------------------------------------------------------------------------------
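A minimal usage sketch for the transform registry above, assuming the `flm` package is importable; the blank PIL image is a stand-in for a real sample:

from PIL import Image
from flm.transforms import keys_to_transforms

# keys must match entries of `_transforms` above
(val_tf,) = keys_to_transforms(["clip"], size=224)
(train_tf,) = keys_to_transforms(["clip_randaug"], size=224)

img = Image.new("RGB", (640, 480))   # stand-in image
tensor = val_tf(img)                 # torch.Tensor of shape (3, 224, 224)
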
/flm/transforms/randaug.py:
--------------------------------------------------------------------------------
1 | # code in this file is adapted from rpmcruz/autoaugment
2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
3 | import random
4 |
5 | import PIL
6 | import PIL.ImageOps
7 | import PIL.ImageEnhance
8 | import PIL.ImageDraw
9 | import numpy as np
10 | import torch
11 | from PIL import Image
12 |
13 |
14 | def ShearX(img, v): # [-0.3, 0.3]
15 | assert -0.3 <= v <= 0.3
16 | if random.random() > 0.5:
17 | v = -v
18 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
19 |
20 |
21 | def ShearY(img, v): # [-0.3, 0.3]
22 | assert -0.3 <= v <= 0.3
23 | if random.random() > 0.5:
24 | v = -v
25 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
26 |
27 |
28 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
29 | assert -0.45 <= v <= 0.45
30 | if random.random() > 0.5:
31 | v = -v
32 | v = v * img.size[0]
33 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
34 |
35 |
36 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
37 | assert 0 <= v
38 | if random.random() > 0.5:
39 | v = -v
40 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
41 |
42 |
43 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
44 | assert -0.45 <= v <= 0.45
45 | if random.random() > 0.5:
46 | v = -v
47 | v = v * img.size[1]
48 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
49 |
50 |
51 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
52 | assert 0 <= v
53 | if random.random() > 0.5:
54 | v = -v
55 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
56 |
57 |
58 | def Rotate(img, v): # [-30, 30]
59 | assert -30 <= v <= 30
60 | if random.random() > 0.5:
61 | v = -v
62 | return img.rotate(v)
63 |
64 |
65 | def AutoContrast(img, _):
66 | return PIL.ImageOps.autocontrast(img)
67 |
68 |
69 | def Invert(img, _):
70 | return PIL.ImageOps.invert(img)
71 |
72 |
73 | def Equalize(img, _):
74 | return PIL.ImageOps.equalize(img)
75 |
76 |
77 | def Flip(img, _): # not from the paper
78 | return PIL.ImageOps.mirror(img)
79 |
80 |
81 | def Solarize(img, v): # [0, 256]
82 | assert 0 <= v <= 256
83 | return PIL.ImageOps.solarize(img, v)
84 |
85 |
86 | def SolarizeAdd(img, addition=0, threshold=128):
87 |     img_np = np.array(img).astype(np.int32)  # np.int was removed in NumPy >= 1.24
88 | img_np = img_np + addition
89 | img_np = np.clip(img_np, 0, 255)
90 | img_np = img_np.astype(np.uint8)
91 | img = Image.fromarray(img_np)
92 | return PIL.ImageOps.solarize(img, threshold)
93 |
94 |
95 | def Posterize(img, v): # [4, 8]
96 | v = int(v)
97 | v = max(1, v)
98 | return PIL.ImageOps.posterize(img, v)
99 |
100 |
101 | def Contrast(img, v): # [0.1,1.9]
102 | assert 0.1 <= v <= 1.9
103 | return PIL.ImageEnhance.Contrast(img).enhance(v)
104 |
105 |
106 | def Color(img, v): # [0.1,1.9]
107 | assert 0.1 <= v <= 1.9
108 | return PIL.ImageEnhance.Color(img).enhance(v)
109 |
110 |
111 | def Brightness(img, v): # [0.1,1.9]
112 | assert 0.1 <= v <= 1.9
113 | return PIL.ImageEnhance.Brightness(img).enhance(v)
114 |
115 |
116 | def Sharpness(img, v): # [0.1,1.9]
117 | assert 0.1 <= v <= 1.9
118 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
119 |
120 |
121 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
122 | assert 0.0 <= v <= 0.2
123 | if v <= 0.0:
124 | return img
125 |
126 | v = v * img.size[0]
127 | return CutoutAbs(img, v)
128 |
129 |
130 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
131 | # assert 0 <= v <= 20
132 | if v < 0:
133 | return img
134 | w, h = img.size
135 | x0 = np.random.uniform(w)
136 | y0 = np.random.uniform(h)
137 |
138 | x0 = int(max(0, x0 - v / 2.0))
139 | y0 = int(max(0, y0 - v / 2.0))
140 | x1 = min(w, x0 + v)
141 | y1 = min(h, y0 + v)
142 |
143 | xy = (x0, y0, x1, y1)
144 | color = (125, 123, 114)
145 | # color = (0, 0, 0)
146 | img = img.copy()
147 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
148 | return img
149 |
150 |
151 | def SamplePairing(imgs): # [0, 0.4]
152 | def f(img1, v):
153 | i = np.random.choice(len(imgs))
154 | img2 = PIL.Image.fromarray(imgs[i])
155 | return PIL.Image.blend(img1, img2, v)
156 |
157 | return f
158 |
159 |
160 | def Identity(img, v):
161 | return img
162 |
163 |
164 | def augment_list():  # 14 active operations and their ranges
165 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
166 | # l = [
167 | # (Identity, 0., 1.0),
168 | # (ShearX, 0., 0.3), # 0
169 | # (ShearY, 0., 0.3), # 1
170 | # (TranslateX, 0., 0.33), # 2
171 | # (TranslateY, 0., 0.33), # 3
172 | # (Rotate, 0, 30), # 4
173 | # (AutoContrast, 0, 1), # 5
174 | # (Invert, 0, 1), # 6
175 | # (Equalize, 0, 1), # 7
176 | # (Solarize, 0, 110), # 8
177 | # (Posterize, 4, 8), # 9
178 | # # (Contrast, 0.1, 1.9), # 10
179 | # (Color, 0.1, 1.9), # 11
180 | # (Brightness, 0.1, 1.9), # 12
181 | # (Sharpness, 0.1, 1.9), # 13
182 | # # (Cutout, 0, 0.2), # 14
183 | # # (SamplePairing(imgs), 0, 0.4), # 15
184 | # ]
185 |
186 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
187 | l = [
188 | (AutoContrast, 0, 1),
189 | (Equalize, 0, 1),
190 | # (Invert, 0, 1),
191 | (Rotate, 0, 30),
192 | (Posterize, 0, 4),
193 | (Solarize, 0, 256),
194 | (SolarizeAdd, 0, 110),
195 | (Color, 0.1, 1.9),
196 | (Contrast, 0.1, 1.9),
197 | (Brightness, 0.1, 1.9),
198 | (Sharpness, 0.1, 1.9),
199 | (ShearX, 0.0, 0.3),
200 | (ShearY, 0.0, 0.3),
201 | # (CutoutAbs, 0, 40),
202 | (TranslateXabs, 0.0, 100),
203 | (TranslateYabs, 0.0, 100),
204 | ]
205 |
206 | return l
207 |
208 |
209 | class Lighting(object):
210 | """Lighting noise(AlexNet - style PCA - based noise)"""
211 |
212 | def __init__(self, alphastd, eigval, eigvec):
213 | self.alphastd = alphastd
214 | self.eigval = torch.Tensor(eigval)
215 | self.eigvec = torch.Tensor(eigvec)
216 |
217 | def __call__(self, img):
218 | if self.alphastd == 0:
219 | return img
220 |
221 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
222 | rgb = (
223 | self.eigvec.type_as(img)
224 | .clone()
225 | .mul(alpha.view(1, 3).expand(3, 3))
226 | .mul(self.eigval.view(1, 3).expand(3, 3))
227 | .sum(1)
228 | .squeeze()
229 | )
230 |
231 | return img.add(rgb.view(3, 1, 1).expand_as(img))
232 |
233 |
234 | class CutoutDefault(object):
235 | """
236 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
237 | """
238 |
239 | def __init__(self, length):
240 | self.length = length
241 |
242 | def __call__(self, img):
243 | h, w = img.size(1), img.size(2)
244 | mask = np.ones((h, w), np.float32)
245 | y = np.random.randint(h)
246 | x = np.random.randint(w)
247 |
248 | y1 = np.clip(y - self.length // 2, 0, h)
249 | y2 = np.clip(y + self.length // 2, 0, h)
250 | x1 = np.clip(x - self.length // 2, 0, w)
251 | x2 = np.clip(x + self.length // 2, 0, w)
252 |
253 | mask[y1:y2, x1:x2] = 0.0
254 | mask = torch.from_numpy(mask)
255 | mask = mask.expand_as(img)
256 | img *= mask
257 | return img
258 |
259 |
260 | class RandAugment:
261 | def __init__(self, n, m):
262 | self.n = n
263 | self.m = m # [0, 30]
264 | self.augment_list = augment_list()
265 |
266 | def __call__(self, img):
267 | ops = random.choices(self.augment_list, k=self.n)
268 | for op, minval, maxval in ops:
269 | val = (float(self.m) / 30) * float(maxval - minval) + minval
270 | img = op(img, val)
271 |
272 | return img
273 |
--------------------------------------------------------------------------------
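A short sketch of how the `RandAugment` policy above is applied: `n` ops are sampled per image and the shared magnitude `m` in [0, 30] is rescaled into each op's own range. The toy image below is a stand-in:

from PIL import Image
from flm.transforms.randaug import RandAugment, augment_list

aug = RandAugment(n=2, m=9)

# magnitude mapping used in RandAugment.__call__, shown for a few ops
for op, minval, maxval in augment_list()[:3]:
    val = (float(aug.m) / 30) * float(maxval - minval) + minval
    print(f"{op.__name__}: magnitude 9 -> {val:.3f} in [{minval}, {maxval}]")

img = Image.new("RGB", (256, 256))
out = aug(img)   # two randomly chosen ops applied in sequence
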
/flm/transforms/transform.py:
--------------------------------------------------------------------------------
1 | from .utils import (
2 | inception_normalize,
3 | imagenet_normalize,
4 | MinMaxResize,
5 | )
6 | from PIL import Image
7 | from torchvision import transforms
8 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
9 | from .randaug import RandAugment
10 |
11 |
12 | def pixelbert_transform(size=800):
13 | longer = int((1333 / 800) * size)
14 | return transforms.Compose(
15 | [
16 | MinMaxResize(shorter=size, longer=longer),
17 | transforms.ToTensor(),
18 | inception_normalize,
19 | ]
20 | )
21 |
22 |
23 | def pixelbert_transform_randaug(size=800):
24 | longer = int((1333 / 800) * size)
25 | trs = transforms.Compose(
26 | [
27 | MinMaxResize(shorter=size, longer=longer),
28 | transforms.ToTensor(),
29 | inception_normalize,
30 | ]
31 | )
32 | trs.transforms.insert(0, RandAugment(2, 9))
33 | return trs
34 |
35 |
36 | def imagenet_transform(size=800):
37 | return transforms.Compose(
38 | [
39 | Resize(size, interpolation=Image.BICUBIC),
40 | CenterCrop(size),
41 | transforms.ToTensor(),
42 | imagenet_normalize,
43 | ]
44 | )
45 |
46 |
47 | def imagenet_transform_randaug(size=800):
48 | trs = transforms.Compose(
49 | [
50 | Resize(size, interpolation=Image.BICUBIC),
51 | CenterCrop(size),
52 | transforms.ToTensor(),
53 | imagenet_normalize,
54 | ]
55 | )
56 | trs.transforms.insert(0, RandAugment(2, 9))
57 | return trs
58 |
59 |
60 | def vit_transform(size=800):
61 | return transforms.Compose(
62 | [
63 | Resize(size, interpolation=Image.BICUBIC),
64 | CenterCrop(size),
65 | transforms.ToTensor(),
66 | inception_normalize,
67 | ]
68 | )
69 |
70 |
71 | def vit_transform_randaug(size=800):
72 | trs = transforms.Compose(
73 | [
74 | Resize(size, interpolation=Image.BICUBIC),
75 | CenterCrop(size),
76 | transforms.ToTensor(),
77 | inception_normalize,
78 | ]
79 | )
80 | trs.transforms.insert(0, RandAugment(2, 9))
81 | return trs
82 |
83 |
84 | def clip_transform(size):
85 | return Compose([
86 | Resize(size, interpolation=Image.BICUBIC),
87 | CenterCrop(size),
88 | lambda image: image.convert("RGB"),
89 | ToTensor(),
90 | Normalize((0.48145466, 0.4578275, 0.40821073),
91 | (0.26862954, 0.26130258, 0.27577711)),
92 | ])
93 |
94 |
95 | def clip_transform_randaug(size):
96 | trs = Compose([
97 | Resize(size, interpolation=Image.BICUBIC),
98 | CenterCrop(size),
99 | lambda image: image.convert("RGB"),
100 | ToTensor(),
101 | Normalize((0.48145466, 0.4578275, 0.40821073),
102 | (0.26862954, 0.26130258, 0.27577711)),
103 | ])
104 | trs.transforms.insert(0, lambda image: image.convert('RGBA'))
105 | trs.transforms.insert(0, RandAugment(2, 9))
106 | trs.transforms.insert(0, lambda image: image.convert('RGB'))
107 | return trs
108 |
109 |
110 | def mae_transform_randaug(size):
111 | trs = Compose([
112 | transforms.RandomResizedCrop(size, scale=(
113 | 0.2, 1.0), interpolation=3), # 3 is bicubic
114 | transforms.RandomHorizontalFlip(),
115 | lambda image: image.convert("RGB"),
116 | transforms.ToTensor(),
117 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
118 | 0.229, 0.224, 0.225])
119 | ])
120 | trs.transforms.insert(0, lambda image: image.convert('RGBA'))
121 | trs.transforms.insert(0, RandAugment(2, 9))
122 | trs.transforms.insert(0, lambda image: image.convert('RGB'))
123 | return trs
124 |
125 |
126 | def mae_transform(size):
127 | trs = Compose([
128 | Resize(size, interpolation=Image.BICUBIC),
129 | CenterCrop(size),
130 | lambda image: image.convert("RGB"),
131 | ToTensor(),
132 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
133 | 0.229, 0.224, 0.225])
134 | ])
135 | return trs
136 |
--------------------------------------------------------------------------------
/flm/transforms/utils.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from PIL import Image
3 |
4 |
5 | class MinMaxResize:
6 | def __init__(self, shorter=800, longer=1333):
7 | self.min = shorter
8 | self.max = longer
9 |
10 | def __call__(self, x):
11 | w, h = x.size
12 | scale = self.min / min(w, h)
13 | if h < w:
14 | newh, neww = self.min, scale * w
15 | else:
16 | newh, neww = scale * h, self.min
17 |
18 | if max(newh, neww) > self.max:
19 | scale = self.max / max(newh, neww)
20 | newh = newh * scale
21 | neww = neww * scale
22 |
23 | newh, neww = int(newh + 0.5), int(neww + 0.5)
24 | newh, neww = newh // 32 * 32, neww // 32 * 32
25 |
26 | return x.resize((neww, newh), resample=Image.BICUBIC)
27 |
28 |
29 | class UnNormalize(object):
30 | def __init__(self, mean, std):
31 | self.mean = mean
32 | self.std = std
33 |
34 | def __call__(self, tensor):
35 | """
36 | Args:
37 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
38 | Returns:
39 | Tensor: Normalized image.
40 | """
41 | for t, m, s in zip(tensor, self.mean, self.std):
42 | t.mul_(s).add_(m)
43 | # The normalize code -> t.sub_(m).div_(s)
44 | return tensor
45 |
46 |
47 | # This is simple maximum entropy normalization performed in Inception paper
48 | inception_normalize = transforms.Compose(
49 | [transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]
50 | )
51 |
52 | # ViT uses simple non-biased inception normalization
53 | # https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py#L132
54 | inception_unnormalize = transforms.Compose(
55 | [UnNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]
56 | )
57 |
58 | # ImageNet normalize
59 | imagenet_normalize = transforms.Compose(
60 | [transforms.Normalize(mean=[0.485, 0.456, 0.406],
61 | std=[0.229, 0.224, 0.225])]
62 | )
63 |
--------------------------------------------------------------------------------
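A small sketch of `MinMaxResize` behaviour: the shorter side is scaled towards `shorter`, the longer side is capped at `longer`, and both are snapped down to multiples of 32. The blank image is a stand-in:

from PIL import Image
from flm.transforms.utils import MinMaxResize

resize = MinMaxResize(shorter=800, longer=1333)
img = Image.new("RGB", (4000, 1000))   # very wide stand-in image
out = resize(img)
print(out.size)   # (1312, 320): longer side <= 1333, both divisible by 32
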
/flm/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/FLM/bd8b19d9f3a00ac6d4e58c9766957032036bffe8/flm/utils/__init__.py
--------------------------------------------------------------------------------
/flm/utils/find_newest_ckpt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import sys
4 |
5 |
6 | save_folder = sys.argv[1]
7 | exp_name = sys.argv[2]
8 | is_last = sys.argv[3] == 'choose_last'
9 |
10 | # exp_name = '37_cl_causalflm_scratch_lr5e5_nobias_t0002_NEW_GPU32'
11 | target = '{}/{}_seed*_from*/version_*/checkpoints/epoch*-step*.ckpt'.format(
12 | save_folder, exp_name)
13 | if is_last:
14 | target = '{}/{}_seed*_from*/version_*/checkpoints/last.ckpt'.format(
15 | save_folder, exp_name)
16 | out = glob.glob(target)
17 |
18 |
19 | def get_info(p):
20 |     p = p[:-len('.ckpt')] if p.endswith('.ckpt') else p  # remove the suffix; rstrip('.ckpt') strips a character set, not the suffix
21 |     version = float(p.split('/')[-3].split('_')[-1])
22 |     try:
23 |         epoch = float(p.split('/')[-1].split('-')[0].split('_')[1])
24 |     except (IndexError, ValueError):
25 |         epoch = None
26 |     try:
27 |         score = float(p.split('/')[-1].split('-')[-1].split('_')[-1])
28 |     except (IndexError, ValueError):
29 |         score = None
30 |
31 | if score is None:
32 | score = -10000.
33 |
34 | return score, epoch, version
35 |
36 |
37 | out = sorted(out, key=get_info, reverse=True)
38 |
39 | if len(out):
40 | print(out[0])
41 |
--------------------------------------------------------------------------------
/flm/utils/glossary.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | contractions = {
4 | "aint": "ain't",
5 | "arent": "aren't",
6 | "cant": "can't",
7 | "couldve": "could've",
8 | "couldnt": "couldn't",
9 | "couldn'tve": "couldn't've",
10 | "couldnt've": "couldn't've",
11 | "didnt": "didn't",
12 | "doesnt": "doesn't",
13 | "dont": "don't",
14 | "hadnt": "hadn't",
15 | "hadnt've": "hadn't've",
16 | "hadn'tve": "hadn't've",
17 | "hasnt": "hasn't",
18 | "havent": "haven't",
19 | "hed": "he'd",
20 | "hed've": "he'd've",
21 | "he'dve": "he'd've",
22 | "hes": "he's",
23 | "howd": "how'd",
24 | "howll": "how'll",
25 | "hows": "how's",
26 | "Id've": "I'd've",
27 | "I'dve": "I'd've",
28 | "Im": "I'm",
29 | "Ive": "I've",
30 | "isnt": "isn't",
31 | "itd": "it'd",
32 | "itd've": "it'd've",
33 | "it'dve": "it'd've",
34 | "itll": "it'll",
35 | "let's": "let's",
36 | "maam": "ma'am",
37 | "mightnt": "mightn't",
38 | "mightnt've": "mightn't've",
39 | "mightn'tve": "mightn't've",
40 | "mightve": "might've",
41 | "mustnt": "mustn't",
42 | "mustve": "must've",
43 | "neednt": "needn't",
44 | "notve": "not've",
45 | "oclock": "o'clock",
46 | "oughtnt": "oughtn't",
47 | "ow's'at": "'ow's'at",
48 | "'ows'at": "'ow's'at",
49 | "'ow'sat": "'ow's'at",
50 | "shant": "shan't",
51 | "shed've": "she'd've",
52 | "she'dve": "she'd've",
53 | "she's": "she's",
54 | "shouldve": "should've",
55 | "shouldnt": "shouldn't",
56 | "shouldnt've": "shouldn't've",
57 | "shouldn'tve": "shouldn't've",
58 | "somebody'd": "somebodyd",
59 | "somebodyd've": "somebody'd've",
60 | "somebody'dve": "somebody'd've",
61 | "somebodyll": "somebody'll",
62 | "somebodys": "somebody's",
63 | "someoned": "someone'd",
64 | "someoned've": "someone'd've",
65 | "someone'dve": "someone'd've",
66 | "someonell": "someone'll",
67 | "someones": "someone's",
68 | "somethingd": "something'd",
69 | "somethingd've": "something'd've",
70 | "something'dve": "something'd've",
71 | "somethingll": "something'll",
72 | "thats": "that's",
73 | "thered": "there'd",
74 | "thered've": "there'd've",
75 | "there'dve": "there'd've",
76 | "therere": "there're",
77 | "theres": "there's",
78 | "theyd": "they'd",
79 | "theyd've": "they'd've",
80 | "they'dve": "they'd've",
81 | "theyll": "they'll",
82 | "theyre": "they're",
83 | "theyve": "they've",
84 | "twas": "'twas",
85 | "wasnt": "wasn't",
86 | "wed've": "we'd've",
87 | "we'dve": "we'd've",
88 | "weve": "we've",
89 | "werent": "weren't",
90 | "whatll": "what'll",
91 | "whatre": "what're",
92 | "whats": "what's",
93 | "whatve": "what've",
94 | "whens": "when's",
95 | "whered": "where'd",
96 | "wheres": "where's",
97 | "whereve": "where've",
98 | "whod": "who'd",
99 | "whod've": "who'd've",
100 | "who'dve": "who'd've",
101 | "wholl": "who'll",
102 | "whos": "who's",
103 | "whove": "who've",
104 | "whyll": "why'll",
105 | "whyre": "why're",
106 | "whys": "why's",
107 | "wont": "won't",
108 | "wouldve": "would've",
109 | "wouldnt": "wouldn't",
110 | "wouldnt've": "wouldn't've",
111 | "wouldn'tve": "wouldn't've",
112 | "yall": "y'all",
113 | "yall'll": "y'all'll",
114 | "y'allll": "y'all'll",
115 | "yall'd've": "y'all'd've",
116 | "y'alld've": "y'all'd've",
117 | "y'all'dve": "y'all'd've",
118 | "youd": "you'd",
119 | "youd've": "you'd've",
120 | "you'dve": "you'd've",
121 | "youll": "you'll",
122 | "youre": "you're",
123 | "youve": "you've",
124 | }
125 |
126 | manual_map = {
127 | "none": "0",
128 | "zero": "0",
129 | "one": "1",
130 | "two": "2",
131 | "three": "3",
132 | "four": "4",
133 | "five": "5",
134 | "six": "6",
135 | "seven": "7",
136 | "eight": "8",
137 | "nine": "9",
138 | "ten": "10",
139 | }
140 | articles = ["a", "an", "the"]
141 | period_strip = re.compile(r"(?!<=\d)(\.)(?!\d)")
142 | comma_strip = re.compile(r"(\d)(\,)(\d)")
143 | punct = [
144 | ";",
145 | r"/",
146 | "[",
147 | "]",
148 | '"',
149 | "{",
150 | "}",
151 | "(",
152 | ")",
153 | "=",
154 | "+",
155 | "\\",
156 | "_",
157 | "-",
158 | ">",
159 | "<",
160 | "@",
161 | "`",
162 | ",",
163 | "?",
164 | "!",
165 | ]
166 |
167 |
168 | def normalize_word(token):
169 | _token = token
170 | for p in punct:
171 | if (p + " " in token or " " + p in token) or (
172 |             re.search(comma_strip, token) is not None
173 | ):
174 | _token = _token.replace(p, "")
175 | else:
176 | _token = _token.replace(p, " ")
177 | token = period_strip.sub("", _token, re.UNICODE)
178 |
179 | _token = []
180 | temp = token.lower().split()
181 | for word in temp:
182 | word = manual_map.setdefault(word, word)
183 | if word not in articles:
184 | _token.append(word)
185 | for i, word in enumerate(_token):
186 | if word in contractions:
187 | _token[i] = contractions[word]
188 | token = " ".join(_token)
189 | token = token.replace(",", "")
190 | return token
191 |
--------------------------------------------------------------------------------
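A minimal usage sketch for the VQA-style answer normalization above, assuming the `flm` package is importable:

from flm.utils.glossary import normalize_word

print(normalize_word("Two dogs!"))   # -> "2 dogs"   (number words mapped, punctuation dropped)
print(normalize_word("isnt it"))     # -> "isn't it" (contractions restored)
print(normalize_word("the cat"))     # -> "cat"      (articles removed)
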
/flm/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from flm.modules import heads, objectives, meter_utils
4 |
5 |
6 | @torch.no_grad()
7 | def adapt_vocab_size(state_dict, new_vocab_size):
8 |
9 | for name in state_dict.keys():
10 | if 'embeddings.word_embeddings.weight' in name or 'fusion_token_embedding.word_embeddings.weight' in name:
11 | expand_vocab(name, state_dict, new_vocab_size)
12 |
13 | # value = state_dict[name]
14 | # old_vocab_size, old_embed_dim = value.shape
15 | # if old_vocab_size != new_vocab_size:
16 | # assert new_vocab_size > old_vocab_size
17 | # new_embeddings = nn.Embedding(new_vocab_size, old_embed_dim)
18 | # new_embeddings.apply(objectives.init_weights)
19 | # new_embeddings.weight[:old_vocab_size] = value
20 | # print(' replace vocab size of {} from {} to {}'.format(name ,old_vocab_size, new_vocab_size))
21 | # state_dict[name] = new_embeddings.weight
22 |
23 | output_params = ['mlm_score', 'lm_score', 'lm_score_r', 'lm_score_f']
24 |
25 |     for p in output_params:
26 |         weight_name, bias_name = p + '.decoder.weight', p + '.bias'
27 |         for name in state_dict.keys():
28 |             if weight_name in name or bias_name in name:
29 |                 expand_vocab(name, state_dict, new_vocab_size)
30 |
31 | return state_dict
32 |
33 |
34 | def expand_vocab(name, state_dict, new_vocab_size):
35 | value = state_dict[name]
36 | if value.shape[0] != new_vocab_size:
37 | state_dict[name] = expand_tensor(value, new_vocab_size)
38 | print(' replace vocab size of {} from {} to {}'.format(
39 | name, value.shape[0], new_vocab_size))
40 |
41 |
42 | def expand_tensor(value, new_vocab_size):
43 | if value.ndim == 1:
44 | old_vocab_size = value.shape[0]
45 | new_embeddings = torch.zeros(new_vocab_size)
46 | else:
47 | old_vocab_size, old_embed_dim = value.shape
48 | new_embeddings = torch.zeros(new_vocab_size, old_embed_dim)
49 | assert new_vocab_size > old_vocab_size
50 |
51 | new_embeddings.data.normal_(mean=0.0, std=0.02)
52 |
53 | new_embeddings[:old_vocab_size] = value
54 | return new_embeddings
55 |
--------------------------------------------------------------------------------
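A short sketch of the vocabulary-expansion helper above: `expand_tensor` pads an embedding (or bias) to a larger vocabulary, initializing the new rows with N(0, 0.02) while keeping the original rows, which is what `adapt_vocab_size` applies to word embeddings and LM output heads of a loaded checkpoint. The sizes below are illustrative, and the import assumes the repository's dependencies are installed:

import torch
from flm.utils.utils import expand_tensor

old = torch.randn(30522, 768)            # e.g. a BERT-sized word-embedding matrix
new = expand_tensor(old, 30700)

assert new.shape == (30700, 768)
assert torch.equal(new[:30522], old)     # original rows are preserved
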
/flm/utils/whole_word_masking.py:
--------------------------------------------------------------------------------
1 | import random
2 | import warnings
3 | from dataclasses import dataclass
4 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
5 |
6 | import torch
7 | from torch.nn.utils.rnn import pad_sequence
8 |
9 | # from ..file_utils import PaddingStrategy
10 | # from ..modeling_utils import PreTrainedModel
11 | from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
12 |
13 | from transformers import (
14 | DataCollatorForLanguageModeling)
15 |
16 |
17 | class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
18 | """
19 | Data collator used for language modeling.
20 |
21 | - collates batches of tensors, honoring their tokenizer's pad_token
22 | - preprocesses batches for masked language modeling
23 | """
24 |
25 | def __call__(
26 | self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
27 | ) -> Dict[str, torch.Tensor]:
28 | if isinstance(examples[0], (dict, BatchEncoding)):
29 | input_ids = [e["input_ids"] for e in examples]
30 | else:
31 | input_ids = examples
32 | examples = [{"input_ids": e} for e in examples]
33 |
34 | batch_input = _collate_batch(input_ids, self.tokenizer)
35 |
36 | mask_labels = []
37 | for e in examples:
38 | ref_tokens = []
39 | for id in tolist(e["input_ids"]):
40 | token = self.tokenizer._convert_id_to_token(id)
41 |                 if id == self.tokenizer.convert_tokens_to_ids('<s>'):  # keep special tokens intact
42 |                     token = '<s>'
43 |                 if id == self.tokenizer.convert_tokens_to_ids('</s>'):
44 |                     token = '</s>'
45 | ref_tokens.append(token)
46 |
47 |             # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
48 | if "chinese_ref" in e:
49 | ref_pos = tolist(e["chinese_ref"])
50 | len_seq = len(e["input_ids"])
51 | for i in range(len_seq):
52 | if i in ref_pos:
53 | ref_tokens[i] = "##" + ref_tokens[i]
54 | mask_labels.append(self._whole_word_mask(ref_tokens))
55 | batch_mask = _collate_batch(mask_labels, self.tokenizer)
56 | inputs, labels = self.mask_tokens(batch_input, batch_mask)
57 | return {"input_ids": inputs, "labels": labels}
58 |
59 | def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
60 | """
61 | Get 0/1 labels for masked tokens with whole word mask proxy
62 | """
63 |
64 | cand_indexes = []
65 |
66 | for (i, token) in enumerate(input_tokens):
67 | if token == "[CLS]" or token == "[SEP]":
68 | continue
69 |
70 | if len(cand_indexes) >= 1 and token.startswith("##"):
71 | cand_indexes[-1].append(i)
72 | else:
73 | cand_indexes.append([i])
74 |
75 | random.shuffle(cand_indexes)
76 | num_to_predict = min(max_predictions, max(
77 | 1, int(round(len(input_tokens) * self.mlm_probability))))
78 | masked_lms = []
79 | covered_indexes = set()
80 | for index_set in cand_indexes:
81 | if len(masked_lms) >= num_to_predict:
82 | break
83 | # If adding a whole-word mask would exceed the maximum number of
84 | # predictions, then just skip this candidate.
85 | if len(masked_lms) + len(index_set) > num_to_predict:
86 | continue
87 | is_any_index_covered = False
88 | for index in index_set:
89 | if index in covered_indexes:
90 | is_any_index_covered = True
91 | break
92 | if is_any_index_covered:
93 | continue
94 | for index in index_set:
95 | covered_indexes.add(index)
96 | masked_lms.append(index)
97 |
98 | assert len(covered_indexes) == len(masked_lms)
99 | mask_labels = [
100 | 1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
101 | return mask_labels
102 |
103 | def mask_tokens(self, inputs: torch.Tensor, mask_labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
104 | """
105 |         Prepare masked token inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
106 |         `mask_labels` comes from whole-word masking (wwm): indices are masked directly according to its reference tokens.
107 | """
108 |
109 | if self.tokenizer.mask_token is None:
110 | raise ValueError(
111 | "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
112 | )
113 | labels = inputs.clone()
114 | # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
115 |
116 | probability_matrix = mask_labels
117 |
118 | special_tokens_mask = [
119 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
120 | ]
121 | probability_matrix.masked_fill_(torch.tensor(
122 | special_tokens_mask, dtype=torch.bool), value=0.0)
123 | if self.tokenizer._pad_token is not None:
124 | padding_mask = labels.eq(self.tokenizer.pad_token_id)
125 | probability_matrix.masked_fill_(padding_mask, value=0.0)
126 |
127 | masked_indices = probability_matrix.bool()
128 | labels[~masked_indices] = -100 # We only compute loss on masked tokens
129 |
130 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
131 | indices_replaced = torch.bernoulli(torch.full(
132 | labels.shape, 0.8)).bool() & masked_indices
133 | inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
134 | self.tokenizer.mask_token)
135 |
136 | # 10% of the time, we replace masked input tokens with random word
137 | indices_random = torch.bernoulli(torch.full(
138 | labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
139 | random_words = torch.randint(
140 | len(self.tokenizer), labels.shape, dtype=torch.long)
141 | inputs[indices_random] = random_words[indices_random]
142 |
143 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged
144 | return inputs, labels
145 |
146 |
147 | def _collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
148 | """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
149 | # Tensorize if necessary.
150 | if isinstance(examples[0], (list, tuple)):
151 | examples = [torch.tensor(e, dtype=torch.long) for e in examples]
152 |
153 | # Check if padding is necessary.
154 | length_of_first = examples[0].size(0)
155 | are_tensors_same_length = all(
156 | x.size(0) == length_of_first for x in examples)
157 | if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
158 | return torch.stack(examples, dim=0)
159 |
160 | # If yes, check if we have a `pad_token`.
161 | if tokenizer._pad_token is None:
162 | raise ValueError(
163 | "You are attempting to pad samples but the tokenizer you are using"
164 | f" ({tokenizer.__class__.__name__}) does not have a pad token."
165 | )
166 |
167 | # Creating the full tensor and filling it with our data.
168 | max_length = max(x.size(0) for x in examples)
169 | if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
170 | max_length = ((max_length // pad_to_multiple_of) + 1) * \
171 | pad_to_multiple_of
172 | result = examples[0].new_full(
173 | [len(examples), max_length], tokenizer.pad_token_id)
174 | for i, example in enumerate(examples):
175 | if tokenizer.padding_side == "right":
176 | result[i, : example.shape[0]] = example
177 | else:
178 | result[i, -example.shape[0]:] = example
179 | return result
180 |
181 |
182 | def tolist(x: Union[List[Any], torch.Tensor]):
183 | return x.tolist() if isinstance(x, torch.Tensor) else x
184 |
--------------------------------------------------------------------------------
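A minimal usage sketch for the collator above, assuming `transformers` is installed and the `bert-base-uncased` tokenizer can be downloaded. Whole-word masking keeps a word and its "##" sub-word pieces masked together, and labels are -100 everywhere except at masked positions:

from transformers import BertTokenizer
from flm.utils.whole_word_masking import DataCollatorForWholeWordMask

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForWholeWordMask(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

examples = [{"input_ids": tokenizer("a photo of a skateboarding dog")["input_ids"]}]
batch = collator(examples)
print(batch["input_ids"].shape)                 # (1, sequence_length)
print((batch["labels"] != -100).sum().item())   # number of positions contributing to the MLM loss
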
/flm/utils/write_coco_karpathy.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import pandas as pd
4 | import pyarrow as pa
5 | import random
6 |
7 | from tqdm import tqdm
8 | from glob import glob
9 | from collections import defaultdict
10 |
11 |
12 | def path2rest(path, iid2captions, iid2split):
13 | name = path.split("/")[-1]
14 | with open(path, "rb") as fp:
15 | binary = fp.read()
16 | captions = iid2captions[name]
17 | split = iid2split[name]
18 | return [binary, captions, name, split]
19 |
20 |
21 | def make_arrow(root, dataset_root):
22 | with open(f"{root}/karpathy/dataset_coco.json", "r") as fp:
23 | captions = json.load(fp)
24 |
25 | captions = captions["images"]
26 |
27 | iid2captions = defaultdict(list)
28 | iid2split = dict()
29 |
30 | for cap in tqdm(captions):
31 | filename = cap["filename"]
32 | iid2split[filename] = cap["split"]
33 | for c in cap["sentences"]:
34 | iid2captions[filename].append(c["raw"])
35 |
36 | paths = list(glob(f"{root}/train2014/*.jpg")) + \
37 | list(glob(f"{root}/val2014/*.jpg"))
38 | random.shuffle(paths)
39 | caption_paths = [path for path in paths if path.split(
40 | "/")[-1] in iid2captions]
41 |
42 | if len(paths) == len(caption_paths):
43 | print("all images have caption annotations")
44 | else:
45 | print("not all images have caption annotations")
46 | print(
47 | len(paths), len(caption_paths), len(iid2captions),
48 | )
49 |
50 | bs = [path2rest(path, iid2captions, iid2split)
51 | for path in tqdm(caption_paths)]
52 |
53 | for split in ["train", "val", "restval", "test"]:
54 | batches = [b for b in bs if b[-1] == split]
55 |
56 | dataframe = pd.DataFrame(
57 | batches, columns=["image", "caption", "image_id", "split"],
58 | )
59 |
60 | table = pa.Table.from_pandas(dataframe)
61 | os.makedirs(dataset_root, exist_ok=True)
62 | with pa.OSFile(
63 | f"{dataset_root}/coco_caption_karpathy_{split}.arrow", "wb"
64 | ) as sink:
65 | with pa.RecordBatchFileWriter(sink, table.schema) as writer:
66 | writer.write_table(table)
67 |
--------------------------------------------------------------------------------
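The writer above serializes each split into a single Arrow IPC file with columns `image`, `caption`, `image_id`, `split`. A short read-back sketch; the filename comes from `make_arrow`, while the directory below is illustrative:

import pyarrow as pa

path = "datasets/coco_caption_karpathy_val.arrow"   # illustrative location
table = pa.ipc.RecordBatchFileReader(pa.memory_map(path, "r")).read_all()

print(table.column_names)                # ['image', 'caption', 'image_id', 'split']
print(table.num_rows)
captions = table["caption"][0].as_py()   # list of caption strings for the first image
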
/flm/utils/write_conceptual_caption.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 | import pyarrow as pa
4 | import gc
5 | import random
6 | import os
7 |
8 | from tqdm import tqdm
9 | from glob import glob
10 |
11 |
12 | def path2rest(path, iid2captions):
13 | split, _, name = path.split("/")[-3:]
14 | split = split.split("_")[-1]
15 | iid = name
16 |
17 | with open(path, "rb") as fp:
18 | binary = fp.read()
19 |
20 | captions = iid2captions[iid]
21 |
22 | return [
23 | binary,
24 | captions,
25 | iid,
26 | split,
27 | ]
28 |
29 |
30 | def make_arrow(root, dataset_root):
31 | for split in ["val", "train"]:
32 | with open(f"{root}/{split}_annot.json", "r") as fp:
33 | captions = json.load(fp)
34 |
35 | iid2captions = dict()
36 | for cap in tqdm(captions):
37 | iid = cap[0].split("/")[-1]
38 | iid2captions[iid] = [cap[1]]
39 |
40 | paths = list(glob(f"{root}/images_{split}/*/*"))
41 | random.shuffle(paths)
42 | caption_paths = [path for path in paths if path.split(
43 | "/")[-1] in iid2captions]
44 | if len(paths) == len(caption_paths):
45 | print("all images have caption annotations")
46 | else:
47 | print("not all images have caption annotations")
48 | print(
49 | len(paths), len(caption_paths), len(iid2captions),
50 | )
51 | arrow_path = "{dataset_root}/conceptual_caption_{split}_{sub}.arrow"
52 | write_split(caption_paths, iid2captions,
53 | dataset_root, arrow_path, split)
54 |
55 |
56 | def write_split(caption_paths, iid2captions, dataset_root, arrow_path, split):
57 | sub_len = int(len(caption_paths) // 100000)
58 | subs = list(range(sub_len + 1))
59 | for sub in subs:
60 | sub_paths = caption_paths[sub * 100000: (sub + 1) * 100000]
61 | bs = [path2rest(path, iid2captions) for path in tqdm(sub_paths)]
62 | dataframe = pd.DataFrame(
63 | bs, columns=["image", "caption", "image_id", "split"],
64 | )
65 |
66 | table = pa.Table.from_pandas(dataframe)
67 |
68 | with pa.OSFile(
69 | arrow_path.format(**{'dataset_root': dataset_root,
70 | 'split': split,
71 | 'sub': sub}), "wb") as sink:
72 | with pa.RecordBatchFileWriter(sink, table.schema) as writer:
73 | writer.write_table(table)
74 | del dataframe
75 | del table
76 | del bs
77 | gc.collect()
78 |
--------------------------------------------------------------------------------
/flm/utils/write_conceptual_caption12M_cloud.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 | import pyarrow as pa
4 | import gc
5 | import random
6 | import os
7 |
8 | from tqdm import tqdm
9 | from glob import glob
10 |
11 |
12 | def path2rest(path, iid2captions, data_dir, split):
13 | # split, _, name = path.split("/")[-3:]
14 | # split = split.split("_")[-1]
15 | # iid = name
16 | iid = path
17 |
18 | with open(_get_video_path(path, data_dir, split)[0], "rb") as fp:
19 | binary = fp.read()
20 |
21 | captions = iid2captions[iid]
22 |
23 | return [
24 | binary,
25 | captions,
26 | iid,
27 | split,
28 | ]
29 |
30 |
31 | def _get_caption(sample):
32 | return sample[0]
33 |
34 |
35 | def _get_video_path(file_name, data_dir, split):
36 | # conceptual captions uses this hashing to create the filename
37 | rel_dir = '.'
38 | # if split != 'train':
39 | # rel_dir = 'validation'
40 | rel_fp = os.path.join(rel_dir, file_name)
41 | return os.path.join(data_dir, rel_fp), rel_fp
42 |
43 |
44 | def make_arrow(dataset_root, save_folder, split='train', chunk_id=0, chunk_num=1):
45 |
46 | metadata_dir = os.path.join(dataset_root, 'metadata')
47 | split_files = {
48 | 'train': 'train.tsv',
49 | 'val': 'val.tsv', # there is no test
50 | }
51 | split_folders = {'train': 'training',
52 |                      'val': 'validation',  # there is no test split
53 | }
54 |
55 | # for split in ["val", "train"]:
56 | if True:
57 | target_split_fp = split_files[split]
58 | metadata = pd.read_csv(os.path.join(
59 | metadata_dir, target_split_fp), sep='\t')
60 |
61 | # meta_data_path = f"{root}/metadata/cc3m_{split_files}_success_full.tsv"
62 | # metadata = pd.read_csv(os.path.join(metadata_dir, target_split_fp), sep='\t')
63 |
64 | # with open(, "r") as fp:
65 | # captions = json.load(fp)
66 |
67 | # iid2captions = dict()
68 | # for cap in tqdm(captions):
69 | # iid = cap[0].split("/")[-1]
70 | # iid2captions[iid] = [cap[1]]
71 |
72 | if True:
73 | chunk_size = metadata.shape[0] // chunk_num + 1
74 | start, end = chunk_id * chunk_size, (chunk_id + 1) * chunk_size
75 | print('chunk number: {}, current chunk_id: {}, chunk_size: {}'.format(
76 | chunk_num, chunk_id, chunk_size))
77 |
78 | iid2captions = dict()
79 | for item in tqdm(range(metadata.shape[0])):
80 | if item not in range(start, end):
81 | continue
82 | sample = metadata.iloc[item]
83 | caption = _get_caption(sample)
84 | iid = sample[1]
85 | iid2captions[iid] = caption
86 |
87 | # paths = list(glob(f"{dataset_root}/{split_folders[split]}/*"))
88 | # random.shuffle(paths)
89 |
90 | # caption_paths = [path for path in paths if path.split("/")[-1] in iid2captions]
91 | caption_paths = list(iid2captions.keys())
92 | # random.shuffle(caption_paths)
93 |
94 | # if len(paths) == len(caption_paths):
95 | # print("all images have caption annotations")
96 | # else:
97 | # print("not all images have caption annotations")
98 | # print(
99 | # len(paths), len(caption_paths), len(iid2captions),
100 | # )
101 |
102 | sub_len = int(len(caption_paths) // 100000)
103 | subs = list(range(sub_len + 1))
104 | print('split number: {}, split_len: {}'.format(sub_len, 100000))
105 | for sub in tqdm(subs):
106 | if sub > 0:
107 | continue
108 | print('current split id: {}'.format(sub))
109 | sub_paths = caption_paths[sub * 100000: (sub + 1) * 100000]
110 | bs = [path2rest(path, iid2captions, dataset_root, split)
111 | for path in tqdm(sub_paths)]
112 |
113 | dataframe = pd.DataFrame(
114 | bs, columns=["image", "caption", "image_id", "split"],
115 | )
116 |
117 | table = pa.Table.from_pandas(dataframe)
118 |
119 | os.makedirs(save_folder, exist_ok=True)
120 | dst_arrow_file = f"{save_folder}/conceptual_caption12M_{split}_{chunk_id}_{sub}.arrow"
121 | with pa.OSFile(
122 | dst_arrow_file, "wb"
123 | ) as sink:
124 | with pa.RecordBatchFileWriter(sink, table.schema) as writer:
125 | writer.write_table(table)
126 | del dataframe
127 | del table
128 | del bs
129 | gc.collect()
130 |
--------------------------------------------------------------------------------
/flm/utils/write_conceptual_caption_cloud.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 | import pyarrow as pa
4 | import gc
5 | import random
6 | import os
7 |
8 | from tqdm import tqdm
9 | from glob import glob
10 |
11 |
12 | def path2rest(path, iid2captions, data_dir, split):
13 | # split, _, name = path.split("/")[-3:]
14 | # split = split.split("_")[-1]
15 | # iid = name
16 | iid = path
17 |
18 | with open(_get_video_path(path, data_dir, split)[0], "rb") as fp:
19 | binary = fp.read()
20 |
21 | captions = iid2captions[iid]
22 |
23 | return [
24 | binary,
25 | captions,
26 | iid,
27 | split,
28 | ]
29 |
30 |
31 | def _get_caption(sample):
32 | return sample[0]
33 |
34 |
35 | def _get_video_path(file_name, data_dir, split):
36 | # conceptual captions uses this hashing to create the filename
37 | rel_dir = 'training'
38 | if split != 'train':
39 | rel_dir = 'validation'
40 | rel_fp = os.path.join(rel_dir, file_name)
41 | return os.path.join(data_dir, rel_fp), rel_fp
42 |
43 |
44 | def make_arrow(dataset_root, save_folder, split='train'):
45 | metadata_dir = os.path.join(dataset_root, 'metadata')
46 | split_files = {
47 | 'train': 'cc3m_training_success_full.tsv',
48 | 'val': 'cc3m_validation_success_full.tsv', # there is no test
49 | }
50 | split_folders = {'train': 'training',
51 |                      'val': 'validation',  # there is no test split
52 | }
53 |
54 | # for split in ["val", "train"]:
55 | if True:
56 | target_split_fp = split_files[split]
57 | metadata = pd.read_csv(os.path.join(
58 | metadata_dir, target_split_fp), sep='\t')
59 |
60 | # meta_data_path = f"{root}/metadata/cc3m_{split_files}_success_full.tsv"
61 | # metadata = pd.read_csv(os.path.join(metadata_dir, target_split_fp), sep='\t')
62 |
63 | # with open(, "r") as fp:
64 | # captions = json.load(fp)
65 |
66 | # iid2captions = dict()
67 | # for cap in tqdm(captions):
68 | # iid = cap[0].split("/")[-1]
69 | # iid2captions[iid] = [cap[1]]
70 |
71 | iid2captions = dict()
72 | for item in range(metadata.shape[0]):
73 | sample = metadata.iloc[item]
74 | caption = _get_caption(sample)
75 | iid = sample[1]
76 | iid2captions[iid] = caption
77 |
78 | # paths = list(glob(f"{dataset_root}/{split_folders[split]}/*"))
79 | # random.shuffle(paths)
80 |
81 | # caption_paths = [path for path in paths if path.split("/")[-1] in iid2captions]
82 | caption_paths = list(iid2captions.keys())
83 | random.shuffle(caption_paths)
84 |
85 | # if len(paths) == len(caption_paths):
86 | # print("all images have caption annotations")
87 | # else:
88 | # print("not all images have caption annotations")
89 | # print(
90 | # len(paths), len(caption_paths), len(iid2captions),
91 | # )
92 |
93 | sub_len = int(len(caption_paths) // 100000)
94 | subs = list(range(sub_len + 1))
95 | for sub in subs:
96 | sub_paths = caption_paths[sub * 100000: (sub + 1) * 100000]
97 | bs = [path2rest(path, iid2captions, dataset_root, split)
98 | for path in tqdm(sub_paths)]
99 |
100 | dataframe = pd.DataFrame(
101 | bs, columns=["image", "caption", "image_id", "split"],
102 | )
103 |
104 | table = pa.Table.from_pandas(dataframe)
105 |
106 | os.makedirs(save_folder, exist_ok=True)
107 | with pa.OSFile(
108 | f"{save_folder}/conceptual_caption_{split}_{sub}.arrow", "wb"
109 | ) as sink:
110 | with pa.RecordBatchFileWriter(sink, table.schema) as writer:
111 | writer.write_table(table)
112 | del dataframe
113 | del table
114 | del bs
115 | gc.collect()
116 |
--------------------------------------------------------------------------------
/flm/utils/write_f30k_karpathy.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 | import pyarrow as pa
4 | import random
5 | import os
6 |
7 | from tqdm import tqdm
8 | from glob import glob
9 | from collections import defaultdict
10 |
11 |
12 | def path2rest(path, iid2captions, iid2split):
13 | name = path.split("/")[-1]
14 |
15 | with open(path, "rb") as fp:
16 | binary = fp.read()
17 |
18 | captions = iid2captions[name]
19 | split = iid2split[name]
20 |
21 | return [binary, captions, name, split]
22 |
23 |
24 | def make_arrow(root, dataset_root):
25 | with open(f"{root}/karpathy/dataset_flickr30k.json", "r") as fp:
26 | captions = json.load(fp)
27 |
28 | captions = captions["images"]
29 |
30 | iid2captions = defaultdict(list)
31 | iid2split = dict()
32 |
33 | for cap in tqdm(captions):
34 | filename = cap["filename"]
35 | iid2split[filename] = cap["split"]
36 | for c in cap["sentences"]:
37 | iid2captions[filename].append(c["raw"])
38 |
39 | paths = list(glob(f"{root}/flickr30k-images/*.jpg"))
40 | random.shuffle(paths)
41 | caption_paths = [path for path in paths if path.split(
42 | "/")[-1] in iid2captions]
43 |
44 | if len(paths) == len(caption_paths):
45 | print("all images have caption annotations")
46 | else:
47 | print("not all images have caption annotations")
48 | print(
49 | len(paths), len(caption_paths), len(iid2captions),
50 | )
51 |
52 | bs = [path2rest(path, iid2captions, iid2split)
53 | for path in tqdm(caption_paths)]
54 |
55 | for split in ["train", "val", "test"]:
56 | batches = [b for b in bs if b[-1] == split]
57 |
58 | dataframe = pd.DataFrame(
59 | batches, columns=["image", "caption", "image_id", "split"],
60 | )
61 |
62 | table = pa.Table.from_pandas(dataframe)
63 |
64 | os.makedirs(dataset_root, exist_ok=True)
65 | with pa.OSFile(
66 | f"{dataset_root}/f30k_caption_karpathy_{split}.arrow", "wb"
67 | ) as sink:
68 | with pa.RecordBatchFileWriter(sink, table.schema) as writer:
69 | writer.write_table(table)
70 |
--------------------------------------------------------------------------------
/flm/utils/write_nlvr2.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 | import pyarrow as pa
4 | import os
5 |
6 | from tqdm import tqdm
7 | from collections import defaultdict
8 |
9 |
10 | def process(root, iden, row):
11 | texts = [r["sentence"] for r in row]
12 | labels = [r["label"] for r in row]
13 |
14 | split = iden.split("-")[0]
15 |
16 | if iden.startswith("train"):
17 | directory = row[0]["directory"]
18 | path = f"{root}/images/train/{directory}/{iden}"
19 | else:
20 | path = f"{root}/{split}/{iden}"
21 |
22 | with open(f"{path}-img0.png", "rb") as fp:
23 | img0 = fp.read()
24 | with open(f"{path}-img1.png", "rb") as fp:
25 | img1 = fp.read()
26 |
27 | return [img0, img1, texts, labels, iden]
28 |
29 |
30 | def make_arrow(root, dataset_root):
31 | train_data = list(
32 | map(json.loads, open(f"{root}/nlvr2/data/train.json").readlines())
33 | )
34 | test1_data = list(
35 | map(json.loads, open(f"{root}/nlvr2/data/test1.json").readlines())
36 | )
37 | dev_data = list(map(json.loads, open(
38 | f"{root}/nlvr2/data/dev.json").readlines()))
39 |
40 | balanced_test1_data = list(
41 | map(
42 | json.loads,
43 | open(f"{root}/nlvr2/data/balanced/balanced_test1.json").readlines(),
44 | )
45 | )
46 | balanced_dev_data = list(
47 | map(
48 | json.loads,
49 | open(f"{root}/nlvr2/data/balanced/balanced_dev.json").readlines(),
50 | )
51 | )
52 |
53 | unbalanced_test1_data = list(
54 | map(
55 | json.loads,
56 | open(f"{root}/nlvr2/data/unbalanced/unbalanced_test1.json").readlines(),
57 | )
58 | )
59 | unbalanced_dev_data = list(
60 | map(
61 | json.loads,
62 | open(f"{root}/nlvr2/data/unbalanced/unbalanced_dev.json").readlines(),
63 | )
64 | )
65 |
66 | splits = [
67 | "train",
68 | "dev",
69 | "test1",
70 | "balanced_dev",
71 | "balanced_test1",
72 | "unbalanced_dev",
73 | "unbalanced_test1",
74 | ]
75 |
76 | datas = [
77 | train_data,
78 | dev_data,
79 | test1_data,
80 | balanced_dev_data,
81 | balanced_test1_data,
82 | unbalanced_dev_data,
83 | unbalanced_test1_data,
84 | ]
85 |
86 | annotations = dict()
87 |
88 | for split, data in zip(splits, datas):
89 | _annot = defaultdict(list)
90 | for row in tqdm(data):
91 | _annot["-".join(row["identifier"].split("-")[:-1])].append(row)
92 | annotations[split] = _annot
93 |
94 | for split in splits:
95 | bs = [
96 | process(root, iden, row) for iden, row in tqdm(annotations[split].items())
97 | ]
98 |
99 | dataframe = pd.DataFrame(
100 | bs, columns=["image_0", "image_1",
101 | "questions", "answers", "identifier"],
102 | )
103 |
104 | table = pa.Table.from_pandas(dataframe)
105 |
106 | os.makedirs(dataset_root, exist_ok=True)
107 | with pa.OSFile(f"{dataset_root}/nlvr2_{split}.arrow", "wb") as sink:
108 | with pa.RecordBatchFileWriter(sink, table.schema) as writer:
109 | writer.write_table(table)
110 |
--------------------------------------------------------------------------------
/flm/utils/write_sbu.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 | import pyarrow as pa
4 | import gc
5 | import random
6 | import os
7 |
8 | from tqdm import tqdm
9 | from glob import glob
10 | from .write_conceptual_caption import write_split
11 |
12 |
13 | def path2rest(path, iid2captions):
14 | split, _, name = path.split("/")[-3:]
15 | split = split.split("_")[-1]
16 | iid = name
17 |
18 | with open(path, "rb") as fp:
19 | binary = fp.read()
20 |
21 | captions = iid2captions[iid]
22 |
23 | return [
24 | binary,
25 | captions,
26 | iid,
27 | split,
28 | ]
29 |
30 |
31 | def make_arrow(root, dataset_root):
32 | with open(f"{root}/annot.json", "r") as fp:
33 | captions = json.load(fp)
34 |
35 | iid2captions = dict()
36 | for cap in tqdm(captions):
37 | iid = cap[0].split("/")[-1]
38 | iid2captions[iid] = [cap[1]]
39 |
40 | paths = list(glob(f"{root}/images_train/*/*"))
41 | random.shuffle(paths)
42 | caption_paths = [path for path in paths if path.split(
43 | "/")[-1] in iid2captions]
44 | if len(paths) == len(caption_paths):
45 | print("all images have caption annotations")
46 | else:
47 | print("not all images have caption annotations")
48 | print(
49 | len(paths), len(caption_paths), len(iid2captions),
50 | )
51 |
52 | arrow_path = "{dataset_root}/sbu_{sub}.arrow"
53 | write_split(caption_paths, iid2captions,
54 | dataset_root, arrow_path, split=None)
55 |
--------------------------------------------------------------------------------
/flm/utils/write_snli.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 | import pyarrow as pa
4 | import os
5 |
6 | from tqdm import tqdm
7 | from collections import defaultdict
8 |
9 |
10 | label2id = {'contradiction': 0, 'neutral': 1, 'entailment': 2}
11 |
12 |
13 | def process(root, imgid, ann):
14 | with open(f"{root}/Flickr30K/images/{imgid}.jpg", "rb") as fp:
15 | img = fp.read()
16 |
17 | sentences = ann['sentences']
18 |
19 | labels = ann['labels']
20 |
21 | return [img, sentences, labels]
22 |
23 |
24 | def make_arrow(root, dataset_root):
25 | train_data = list(
26 | map(json.loads, open(f"{root}/snli_ve_train.jsonl").readlines())
27 | )
28 | test_data = list(
29 | map(json.loads, open(f"{root}/snli_ve_test.jsonl").readlines())
30 | )
31 | dev_data = list(
32 | map(json.loads, open(f"{root}/snli_ve_dev.jsonl").readlines())
33 | )
34 |
35 | splits = [
36 | "train",
37 | "dev",
38 | "test",
39 | ]
40 |
41 | annotations = dict()
42 | annotations['train'] = train_data
43 | annotations['dev'] = dev_data
44 | annotations['test'] = test_data
45 | annots = dict()
46 | for split in splits:
47 | annots[split] = {}
48 | for line in annotations[split]:
49 | imgid = line['Flickr30K_ID']
50 |             if imgid not in annots[split]:
51 | annots[split][imgid] = {}
52 | annots[split][imgid]['sentences'] = []
53 | annots[split][imgid]['labels'] = []
54 | annots[split][imgid]['sentences'].append(
55 | [line['sentence1'], line['sentence2']])
56 | annots[split][imgid]['labels'].append(label2id[line['gold_label']])
57 |
58 | for split in splits:
59 | bs = [process(root, imgid, annots[split][imgid])
60 | for imgid in tqdm(annots[split])]
61 |
62 | dataframe = pd.DataFrame(
63 | bs, columns=["image", "sentences", "labels"]
64 | )
65 |
66 | table = pa.Table.from_pandas(dataframe)
67 |
68 | os.makedirs(dataset_root, exist_ok=True)
69 | with pa.OSFile(f"{dataset_root}/snli_{split}.arrow", "wb") as sink:
70 | with pa.RecordBatchFileWriter(sink, table.schema) as writer:
71 | writer.write_table(table)
72 |
--------------------------------------------------------------------------------
/flm/utils/write_vg.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 | import pyarrow as pa
4 | import random
5 | import os
6 |
7 | from tqdm import tqdm
8 | from glob import glob
9 | from collections import defaultdict
10 |
11 |
12 | def path2rest(path, iid2captions):
13 | name = path.split("/")[-1]
14 | iid = int(name[:-4])
15 |
16 | with open(path, "rb") as fp:
17 | binary = fp.read()
18 |
19 | cdicts = iid2captions[iid]
20 | captions = [c["phrase"] for c in cdicts]
21 | widths = [c["width"] for c in cdicts]
22 | heights = [c["height"] for c in cdicts]
23 | xs = [c["x"] for c in cdicts]
24 | ys = [c["y"] for c in cdicts]
25 |
26 | return [
27 | binary,
28 | captions,
29 | widths,
30 | heights,
31 | xs,
32 | ys,
33 | str(iid),
34 | ]
35 |
36 |
37 | def make_arrow(root, dataset_root):
38 | with open(f"{root}/annotations/region_descriptions.json", "r") as fp:
39 | captions = json.load(fp)
40 |
41 | iid2captions = defaultdict(list)
42 | for cap in tqdm(captions):
43 | cap = cap["regions"]
44 | for c in cap:
45 | iid2captions[c["image_id"]].append(c)
46 |
47 | paths = list(glob(f"{root}/images/VG_100K/*.jpg")) + list(
48 | glob(f"{root}/images/VG_100K_2/*.jpg")
49 | )
50 | random.shuffle(paths)
51 | caption_paths = [
52 | path for path in paths if int(path.split("/")[-1][:-4]) in iid2captions
53 | ]
54 |
55 | if len(paths) == len(caption_paths):
56 | print("all images have caption annotations")
57 | else:
58 | print("not all images have caption annotations")
59 | print(
60 | len(paths), len(caption_paths), len(iid2captions),
61 | )
62 |
63 | bs = [path2rest(path, iid2captions) for path in tqdm(caption_paths)]
64 | dataframe = pd.DataFrame(
65 | bs, columns=["image", "caption", "width",
66 | "height", "x", "y", "image_id"],
67 | )
68 | table = pa.Table.from_pandas(dataframe)
69 |
70 | os.makedirs(dataset_root, exist_ok=True)
71 | with pa.OSFile(f"{dataset_root}/vg.arrow", "wb") as sink:
72 | with pa.RecordBatchFileWriter(sink, table.schema) as writer:
73 | writer.write_table(table)
74 |
--------------------------------------------------------------------------------
/flm/utils/write_vqa.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 | import pyarrow as pa
4 | import random
5 | import os
6 |
7 | from tqdm import tqdm
8 | from glob import glob
9 | from collections import defaultdict, Counter
10 | from .glossary import normalize_word
11 |
12 |
13 | def get_score(occurrences):
14 |     if occurrences == 0:
15 |         return 0.0
16 |     elif occurrences == 1:
17 |         return 0.3
18 |     elif occurrences == 2:
19 |         return 0.6
20 |     elif occurrences == 3:
21 | return 0.9
22 | else:
23 | return 1.0
24 |
25 |
26 | def path2rest(path, split, annotations, label2ans):
27 | iid = int(path.split("/")[-1].split("_")[-1][:-4])
28 |
29 | with open(path, "rb") as fp:
30 | binary = fp.read()
31 |
32 | _annot = annotations[split][iid]
33 | _annot = list(_annot.items())
34 | qids, qas = [a[0] for a in _annot], [a[1] for a in _annot]
35 | questions = [qa[0] for qa in qas]
36 |     answers = [qa[1] for qa in qas] if "test" not in split else []
37 |     answer_labels = (
38 |         [a["labels"] for a in answers] if "test" not in split else []
39 |     )
40 |     answer_scores = (
41 |         [a["scores"] for a in answers] if "test" not in split else []
42 |     )
43 |     answers = (
44 |         [[label2ans[l] for l in al] for al in answer_labels]
45 |         if "test" not in split
46 |         else []
47 |     )
48 |
49 | return [binary, questions, answers, answer_labels, answer_scores, iid, qids, split]
50 |
51 |
52 | def make_arrow(root, dataset_root):
53 | with open(f"{root}/v2_OpenEnded_mscoco_train2014_questions.json", "r") as fp:
54 | questions_train2014 = json.load(fp)["questions"]
55 | with open(f"{root}/v2_OpenEnded_mscoco_val2014_questions.json", "r") as fp:
56 | questions_val2014 = json.load(fp)["questions"]
57 | with open(f"{root}/v2_OpenEnded_mscoco_test2015_questions.json", "r") as fp:
58 | questions_test2015 = json.load(fp)["questions"]
59 | with open(f"{root}/v2_OpenEnded_mscoco_test-dev2015_questions.json", "r") as fp:
60 | questions_test_dev2015 = json.load(fp)["questions"]
61 |
62 | with open(f"{root}/v2_mscoco_train2014_annotations.json", "r") as fp:
63 | annotations_train2014 = json.load(fp)["annotations"]
64 | with open(f"{root}/v2_mscoco_val2014_annotations.json", "r") as fp:
65 | annotations_val2014 = json.load(fp)["annotations"]
66 |
67 | annotations = dict()
68 |
69 | for split, questions in zip(
70 | ["train", "val", "test", "test-dev"],
71 | [
72 | questions_train2014,
73 | questions_val2014,
74 | questions_test2015,
75 | questions_test_dev2015,
76 | ],
77 | ):
78 | _annot = defaultdict(dict)
79 | for q in tqdm(questions):
80 | _annot[q["image_id"]][q["question_id"]] = [q["question"]]
81 |
82 | annotations[split] = _annot
83 |
84 | all_major_answers = list()
85 |
86 | for split, annots in zip(
87 | ["train", "val"], [annotations_train2014, annotations_val2014],
88 | ):
89 | _annot = annotations[split]
90 | for q in tqdm(annots):
91 | all_major_answers.append(q["multiple_choice_answer"])
92 |
93 | all_major_answers = [normalize_word(word)
94 | for word in tqdm(all_major_answers)]
95 | counter = {k: v for k, v in Counter(all_major_answers).items() if v >= 9}
96 | ans2label = {k: i for i, k in enumerate(counter.keys())}
97 | label2ans = list(counter.keys())
98 |
99 | for split, annots in zip(
100 | ["train", "val"], [annotations_train2014, annotations_val2014],
101 | ):
102 | _annot = annotations[split]
103 | for q in tqdm(annots):
104 | answers = q["answers"]
105 | answer_count = {}
106 | for answer in answers:
107 | answer_ = answer["answer"]
108 | answer_count[answer_] = answer_count.get(answer_, 0) + 1
109 |
110 | labels = []
111 | scores = []
112 | for answer in answer_count:
113 | if answer not in ans2label:
114 | continue
115 | labels.append(ans2label[answer])
116 | score = get_score(answer_count[answer])
117 | scores.append(score)
118 |
119 | _annot[q["image_id"]][q["question_id"]].append(
120 | {"labels": labels, "scores": scores, }
121 | )
122 |
123 | for split in ["train", "val"]:
124 | filtered_annot = dict()
125 | for ik, iv in annotations[split].items():
126 | new_q = dict()
127 | for qk, qv in iv.items():
128 | if len(qv[1]["labels"]) != 0:
129 | new_q[qk] = qv
130 | if len(new_q) != 0:
131 | filtered_annot[ik] = new_q
132 | annotations[split] = filtered_annot
133 |
134 | for split in [
135 | "train",
136 | "val",
137 | "test",
138 | "test-dev",
139 | ]:
140 | annot = annotations[split]
141 | split_name = {
142 | "train": "train2014",
143 | "val": "val2014",
144 | "test": "test2015",
145 | "test-dev": "test2015",
146 | }[split]
147 | paths = list(glob(f"{root}/{split_name}/*.jpg"))
148 | random.shuffle(paths)
149 | annot_paths = [
150 | path
151 | for path in paths
152 | if int(path.split("/")[-1].split("_")[-1][:-4]) in annot
153 | ]
154 |
155 |         if len(paths) == len(annot_paths):
156 |             print("all images have question annotations")
157 |         else:
158 |             print("not all images have question annotations")
159 | print(
160 | len(paths), len(annot_paths), len(annot),
161 | )
162 |
163 | bs = [
164 | path2rest(path, split, annotations, label2ans) for path in tqdm(annot_paths)
165 | ]
166 |
167 | dataframe = pd.DataFrame(
168 | bs,
169 | columns=[
170 | "image",
171 | "questions",
172 | "answers",
173 | "answer_labels",
174 | "answer_scores",
175 | "image_id",
176 | "question_id",
177 | "split",
178 | ],
179 | )
180 |
181 | table = pa.Table.from_pandas(dataframe)
182 |
183 | os.makedirs(dataset_root, exist_ok=True)
184 | with pa.OSFile(f"{dataset_root}/vqav2_{split}.arrow", "wb") as sink:
185 | with pa.RecordBatchFileWriter(sink, table.schema) as writer:
186 | writer.write_table(table)
187 |
188 | table = pa.ipc.RecordBatchFileReader(
189 | pa.memory_map(f"{dataset_root}/vqav2_val.arrow", "r")
190 | ).read_all()
191 |
192 | pdtable = table.to_pandas()
193 |
194 | df1 = pdtable[:-1000]
195 | df2 = pdtable[-1000:]
196 |
197 | df1 = pa.Table.from_pandas(df1)
198 | df2 = pa.Table.from_pandas(df2)
199 |
200 | with pa.OSFile(f"{dataset_root}/vqav2_trainable_val.arrow", "wb") as sink:
201 | with pa.RecordBatchFileWriter(sink, df1.schema) as writer:
202 | writer.write_table(df1)
203 |
204 | with pa.OSFile(f"{dataset_root}/vqav2_rest_val.arrow", "wb") as sink:
205 | with pa.RecordBatchFileWriter(sink, df2.schema) as writer:
206 | writer.write_table(df2)
207 |
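208 | # A minimal entry-point sketch (hypothetical paths, not part of the original script):
209 | # `root` is assumed to contain the VQAv2 question/annotation JSONs and the
210 | # train2014/val2014/test2015 image folders. Besides the per-split arrow files, the
211 | # script re-splits val: the last 1000 rows become vqav2_rest_val.arrow and the
212 | # remainder vqav2_trainable_val.arrow.
213 | if __name__ == "__main__":
214 |     make_arrow("/path/to/vqav2", "/path/to/arrow_output")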
--------------------------------------------------------------------------------
/flm/utils/write_winoground.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 | import pyarrow as pa
4 | import os
5 |
6 | from tqdm import tqdm
7 | from collections import defaultdict
8 |
9 |
10 | def process(root, iden, row):
11 | text0 = row[0]["caption_0"]
12 | text1 = row[0]["caption_1"]
13 | img0_name = row[0]["image_0"]
14 | img1_name = row[0]["image_1"]
15 | img0_path = f"{root}/data/images/{img0_name}.png"
16 | img1_path = f"{root}/data/images/{img1_name}.png"
17 | # collapsed_tag = row[0]["collapsed_tag"]
18 | with open(img0_path, "rb") as fp:
19 | img0 = fp.read()
20 | with open(img1_path, "rb") as fp:
21 | img1 = fp.read()
22 |
23 | # texts = [r["sentence"] for r in row]
24 | # labels = [r["label"] for r in row]
25 |
26 | # split = iden.split("-")[0]
27 |
28 | # if iden.startswith("train"):
29 | # directory = row[0]["directory"]
30 | # path = f"{root}/images/train/{directory}/{iden}"
31 | # else:
32 | # path = f"{root}/{split}/{iden}"
33 |
34 | # with open(f"{path}-img0.png", "rb") as fp:
35 | # img0 = fp.read()
36 | # with open(f"{path}-img1.png", "rb") as fp:
37 | # img1 = fp.read()
38 |
39 | return [img0, img1, text0, text1, iden]
40 |
41 |
42 | def make_arrow(root, dataset_root):
43 | # train_data = list(
44 | # map(json.loads, open(f"{root}/data/examples.jsonl").readlines())
45 | # )
46 | test1_data = list(
47 | map(json.loads, open(f"{root}/data/examples.jsonl").readlines())
48 | )
49 | # dev_data = list(map(json.loads, open(f"{root}/nlvr2/data/dev.json").readlines()))
50 |
51 | # balanced_test1_data = list(
52 | # map(
53 | # json.loads,
54 | # open(f"{root}/nlvr2/data/balanced/balanced_test1.json").readlines(),
55 | # )
56 | # )
57 | # balanced_dev_data = list(
58 | # map(
59 | # json.loads,
60 | # open(f"{root}/nlvr2/data/balanced/balanced_dev.json").readlines(),
61 | # )
62 | # )
63 |
64 | # unbalanced_test1_data = list(
65 | # map(
66 | # json.loads,
67 | # open(f"{root}/nlvr2/data/unbalanced/unbalanced_test1.json").readlines(),
68 | # )
69 | # )
70 | # unbalanced_dev_data = list(
71 | # map(
72 | # json.loads,
73 | # open(f"{root}/nlvr2/data/unbalanced/unbalanced_dev.json").readlines(),
74 | # )
75 | # )
76 | splits = ['test']
77 | datas = [test1_data]
78 |
79 | # splits = [
80 | # "train",
81 | # "dev",
82 | # "test1",
83 | # "balanced_dev",
84 | # "balanced_test1",
85 | # "unbalanced_dev",
86 | # "unbalanced_test1",
87 | # ]
88 |
89 | # datas = [
90 | # train_data,
91 | # dev_data,
92 | # test1_data,
93 | # balanced_dev_data,
94 | # balanced_test1_data,
95 | # unbalanced_dev_data,
96 | # unbalanced_test1_data,
97 | # ]
98 |
99 | annotations = dict()
100 |
101 | for split, data in zip(splits, datas):
102 | _annot = defaultdict(list)
103 | for row in tqdm(data):
104 | _annot[row["id"]].append(row)
105 | annotations[split] = _annot
106 |
107 | for split in splits:
108 | bs = [
109 | process(root, iden, row) for iden, row in tqdm(annotations[split].items())
110 | ]
111 |
112 | dataframe = pd.DataFrame(
113 | bs, columns=["image_0", "image_1", "text0", "text1", "identifier"],
114 | )
115 |
116 | table = pa.Table.from_pandas(dataframe)
117 |
118 | os.makedirs(dataset_root, exist_ok=True)
119 | with pa.OSFile(f"{dataset_root}/winoground_{split}.arrow", "wb") as sink:
120 | with pa.RecordBatchFileWriter(sink, table.schema) as writer:
121 | writer.write_table(table)
122 |
123 |
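124 | # NOTE: the paths below are specific to the original authors' storage; the
125 | # __main__ guard keeps a plain import of this module from rebuilding the arrow file.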
126 | if __name__ == "__main__":
127 |     make_arrow('/group/30042/wybertwang/dataset/winoground',
128 |                '/group/30042/wybertwang/dataset/METER_task_arrow')
129 |
--------------------------------------------------------------------------------
/imgs/LMs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/FLM/bd8b19d9f3a00ac6d4e58c9766957032036bffe8/imgs/LMs.png
--------------------------------------------------------------------------------
/imgs/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/FLM/bd8b19d9f3a00ac6d4e58c9766957032036bffe8/imgs/pipeline.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch_lightning==1.4.0
2 | torch==1.7.1
3 | torchvision==0.8.2
4 | transformers==4.6.0
5 | Pillow==8.1.0
6 | tqdm==4.56.0
7 | ipdb==0.13.4
8 | numpy==1.19.5
9 | einops==0.3.0
10 | pyarrow
11 | sacred==0.8.2
12 | pandas==1.1.5
13 | # timm==0.4.12
14 | timm==0.3.2
15 | ftfy
16 | pycocoevalcap
17 | pycocotools
18 | webdataset
19 | nltk
20 | huggingface_hub
21 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import json
4 | import torch
5 | import pytorch_lightning as pl
6 | from flm.modules import FLMTransformerSS
7 | from flm.datamodules.multitask_datamodule import MTDataModule
8 | from flm.config import ex
9 |
10 |
11 | def args_checker(config):
12 | if config['enable_flm_aux_lm_loss']:
13 | assert config['loss_names']['flm'] > 0
14 | assert config['flm_backbone']
15 | assert config['is_causal_mask']
16 | assert config["hidden_size"] == config["hidden_size_for_fusion"], \
17 |         "only hidden_size_for_fusion == hidden_size is supported"
18 |
19 |
20 | @ex.automain
21 | def run(_config):
22 | config = copy.deepcopy(_config)
23 | args_checker(config)
24 | # print(os.environ)
25 | world_size = int(os.environ.get('WORLD_SIZE', 1))
26 | rank = int(os.environ.get('RANK', 0))
27 | local_rank = int(os.environ.get('LOCAL_RANK', 0))
28 | nnodes = int(os.environ.get('NNODES', 1))
29 | config["world_size"] = world_size
30 | config["rank"] = rank
31 | config["nnodes"] = nnodes
32 | config["num_nodes"] = nnodes
33 | config["local_rank"] = local_rank
34 |
35 | device = torch.device(f'cuda:{local_rank}')
36 | torch.cuda.set_device(device)
37 |
38 | pl.seed_everything(config["seed"])
39 | dm = MTDataModule(config, dist=True)
40 | exp_name = f'{config["exp_name"]}'
41 |
42 | os.makedirs(config["log_dir"], exist_ok=True)
43 | checkpoint_callback = pl.callbacks.ModelCheckpoint(
44 | dirpath=None, # use logger's path
45 | save_top_k=config["ckpt_save_top_k"],
46 | verbose=True,
47 | monitor="val/the_metric",
48 | mode="max",
49 | save_last=True,
50 | filename='epoch_{epoch:0>3d}-step_{step:0>6d}-val_score_{val/the_metric:.3f}',
51 | auto_insert_metric_name=False,
52 | )
53 |
54 | version = 0 if config['fix_exp_version'] else None
55 |
56 | logger = pl.loggers.TensorBoardLogger(
57 | config["log_dir"],
58 | name=f'{exp_name}_seed{config["seed"]}_from_{config["load_path"].split("/")[-1][:-5]}',
59 | version=version,
60 | )
61 | config['exp_path'] = logger.root_dir
62 |
63 | lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
64 | callbacks = [checkpoint_callback, lr_callback]
65 |
66 | num_gpus = (
67 | config["num_gpus"]
68 | if isinstance(config["num_gpus"], int)
69 | else len(config["num_gpus"])
70 | )
71 |
72 | print(config)
73 | available_batch_size = config["per_gpu_batchsize"] * \
74 | num_gpus * config["num_nodes"]
75 | grad_steps = max(config["batch_size"] // (available_batch_size), 1)
76 |
77 |     max_steps = config["max_steps"]  # may be None; max_epoch then bounds training instead
78 |
79 | if local_rank == 0:
80 | # print(os.environ)
81 | print(
82 |             f' GPUs per Node: {num_gpus}, Total GPUs: {num_gpus * config["num_nodes"]}')
83 | print(
84 | f' Total Batch Size: {config["batch_size"]}, \
85 | Available Batch Size: {available_batch_size}, \
86 | Per GPU Batch Size: {config["per_gpu_batchsize"]},\
87 | Grad Steps: {grad_steps}')
88 | print(f' Resume_from: {config["resume_from"]}')
89 | print(f' Load_path: {config["load_path"]}')
90 | print(' All configs: \n', json.dumps(
91 | _config, sort_keys=True, indent=4, separators=(',', ':')))
92 |
93 | model = FLMTransformerSS(config)
94 |
95 | trainer = pl.Trainer(
96 | gpus=config["num_gpus"],
97 | num_nodes=config["num_nodes"],
98 | precision=config["precision"],
99 | accelerator="ddp",
100 | benchmark=True,
101 | deterministic=True,
102 | max_epochs=config["max_epoch"] if max_steps is None else 1000,
103 | max_steps=max_steps,
104 | callbacks=callbacks,
105 | logger=logger,
106 | prepare_data_per_node=config["prepare_data_per_node"],
107 | replace_sampler_ddp=False,
108 | accumulate_grad_batches=grad_steps,
109 | log_every_n_steps=100,
110 | flush_logs_every_n_steps=100,
111 | resume_from_checkpoint=config["resume_from"],
112 | weights_summary="top",
113 | fast_dev_run=config["fast_dev_run"],
114 | val_check_interval=config["val_check_interval"],
115 | # progress_bar_refresh_rate= 5 if config['debug'] else 200,
116 | num_sanity_val_steps=config['num_sanity_val_steps'],
117 | )
118 |
119 | if not config["test_only"]:
120 | trainer.fit(model, datamodule=dm)
121 | else:
122 | trainer.test(model, datamodule=dm)
123 |
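124 | # Invocation sketch (hypothetical values; the full set of config keys and named configs
125 | # lives in flm/config.py and is exposed through sacred's `with` syntax):
126 | #   python run.py with exp_name=flm_pretrain num_gpus=8 per_gpu_batchsize=32 \
127 | #       batch_size=1024 log_dir=result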
--------------------------------------------------------------------------------