├── LICENSE ├── README.md ├── data ├── README.md ├── llm_judge_result.zip ├── multimodal_debate.zip └── preprocessed_data │ ├── README.md │ └── test.jsonl ├── requirements.txt ├── run.py └── src ├── __init__.py ├── config.py ├── datamodules ├── __init__.py ├── datamodule_base.py ├── meme_datamodule.py └── multitask_datamodule.py ├── datasets ├── __init__.py ├── base_dataset.py └── meme.py ├── gadgets ├── __init__.py └── my_metrics.py ├── modules ├── __init__.py ├── dist_utils.py ├── mm_module.py ├── mm_utils.py ├── objectives.py └── t5_model.py ├── transforms ├── __init__.py ├── randaug.py ├── transform.py └── utils.py └── utils └── __init__.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ExplainHM 2 | Official PyTorch implementation for the paper - **Towards Explainable Harmful Meme Detection through Multimodal Debate between Large Language Models**. 3 | 4 | (**WWW 2024**: The ACM Web Conference 2024, May 2024, Singapore.) 
[[`paper`](https://arxiv.org/pdf/2401.13298.pdf)] 5 | 6 | 7 | ## Install 8 | 9 | ```bash 10 | conda create -n meme python=3.8 11 | conda activate meme 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ## Data 16 | 17 | Please refer to [data](https://github.com/HKBUNLP/ExplainHM-WWW2024/tree/main/data). 18 | 19 | ## Training 20 | ```bash 21 | export DATA="/path/to/data/folder" 22 | export LOG="/path/to/save/ckpts/name" 23 | 24 | rm -rf $LOG 25 | mkdir $LOG 26 | 27 | CUDA_VISIBLE_DEVICES=0,1 python run.py with data_root=$DATA \ 28 | num_gpus=2 num_nodes=1 task_train per_gpu_batchsize=8 batch_size=32 \ 29 | clip32_base224 text_t5_base image_size=224 vit_randaug max_text_len=512 \ 30 | log_dir=$LOG precision=32 max_epoch=10 learning_rate=5e-4 31 | ``` 32 | 33 | ## Inference 34 | 35 | ```bash 36 | export DATA="/path/to/data/folder" 37 | export LOG="/path/to/log/folder" 38 | 39 | CUDA_VISIBLE_DEVICES=0 python run.py with data_root=$DATA \ 40 | num_gpus=1 num_nodes=1 task_train per_gpu_batchsize=32 batch_size=32 test_only=True \ 41 | clip32_base224 text_t5_base image_size=224 vit_randaug \ 42 | log_dir=$LOG precision=32 \ 43 | max_text_len=512 load_path="/path/to/label_learn.ckpt" 44 | ``` 45 | 46 | ## Citation 47 | 48 | ``` 49 | @inproceedings{lin2024explainable, 50 | title={Towards Explainable Harmful Meme Detection through Multimodal Debate between Large Language Models}, 51 | author={Hongzhan Lin and Ziyang Luo and Wei Gao and Jing Ma and Bo Wang and Ruichao Yang}, 52 | booktitle={The ACM Web Conference 2024}, 53 | year={2024}, 54 | address={Singapore}, 55 | } 56 | ``` 57 | 58 | ## Acknowledgements 59 | 60 | The code is based on [ViLT](https://github.com/dandelin/ViLT) and [METER](https://github.com/zdou0830/METER/tree/main). 61 | 62 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ## Multimodal Debate 2 | 3 | file: multimodal_debate.zip 4 | 5 | The format of the pickle file: {'id': {'harmful': "", 'harmless': ""}, ..., 'id': {'harmful': "", 'harmless': ""}}; where 'harmful' denotes the explanation from the harmful argument and 'harmless' denotes the explanation from the harmless argument. 6 | 7 | ## LLM Judge 8 | 9 | file: llm_judge_result.zip 10 | 11 | The format of the pickle file: {'id': '1 or 0', ..., 'id': '1 or 0'}; where 1 means harmful and 0 means harmless. 12 | 13 | ## Data Preprocess 14 | 15 | To generate the multimodal debate for each meme, we employed LLaVA-1.5, a widely used open-source vision LLM, specifically the “llava-13b-v1-1” version (*note that this data is provided only as a reference example and is outdated given newly emerging LLMs; it is straightforward to regenerate the data with more advanced LLMs*). The input order of the multimodal debate can be adjusted according to the results of the LLM judge. Following the practice of previous work such as MaskPrompt, Pro-Cap, and Mr.Harm on FHM data preprocessing, the input text is augmented with image information during preprocessing for a fair comparison.
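As a reference, below is a minimal sketch of how the two dictionaries described above might be loaded and combined once the zip files are extracted; the pickle file names and the judge-first ordering rule are illustrative assumptions, not part of the released code:

```python
import pickle

# Load the debate explanations: {'id': {'harmful': "...", 'harmless': "..."}, ...}
with open("multimodal_debate.pkl", "rb") as f:   # assumed file name after unzipping
    debate = pickle.load(f)

# Load the LLM judge results: {'id': '1' or '0', ...}; 1 = harmful, 0 = harmless
with open("llm_judge_result.pkl", "rb") as f:    # assumed file name after unzipping
    judge = pickle.load(f)

def ordered_arguments(meme_id):
    """Return (first, second) explanations, placing the judge-preferred side first (illustrative ordering)."""
    args = debate[meme_id]
    if judge[meme_id] == "1":                    # judged harmful
        return args["harmful"], args["harmless"]
    return args["harmless"], args["harmful"]
```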
16 | -------------------------------------------------------------------------------- /data/llm_judge_result.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKBUNLP/ExplainHM-WWW2024/5c8742286a416b687fe8ce8acba32fc026e91c3e/data/llm_judge_result.zip -------------------------------------------------------------------------------- /data/multimodal_debate.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKBUNLP/ExplainHM-WWW2024/5c8742286a416b687fe8ce8acba32fc026e91c3e/data/multimodal_debate.zip -------------------------------------------------------------------------------- /data/preprocessed_data/README.md: -------------------------------------------------------------------------------- 1 | The test jsonl file is an example of preprocessed data for reference. 2 | 3 | The image files can be found at [Harm-C](https://drive.google.com/file/d/1dxMrnyXcED-85HCcQiA_d5rr8acwl6lp/view?usp=sharing), [Harm-P](https://drive.google.com/file/d/1fw850yxKNqzpRpQKH88D13yfrwX1MLde/view?usp=sharing) and [FHM](https://hatefulmemeschallenge.com/#download). 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch_lightning==1.3.2 2 | transformers==4.24.0 3 | Pillow==8.1.0 4 | tqdm==4.56.0 5 | ipdb==0.13.4 6 | numpy==1.19.5 7 | einops==0.3.0 8 | pyarrow==2.0.0 9 | sacred==0.8.2 10 | pandas==1.1.5 11 | torchmetrics==0.6.0 12 | ftfy 13 | torchvision==0.13.0 14 | jsonlines 15 | chardet 16 | torch==1.12.0 17 | sentencepiece 18 | scikit-learn -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import pytorch_lightning as pl 4 | import os 5 | os.environ["NCCL_DEBUG"] = "INFO" 6 | 7 | from src.config import ex 8 | from src.modules import MMTransformerSS 9 | from src.datamodules.multitask_datamodule import MTDataModule 10 | 11 | import resource 12 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 13 | # resource.setrlimit(resource.RLIMIT_NOFILE, (20480, rlimit[1])) 14 | 15 | @ex.automain 16 | def main(_config): 17 | _config = copy.deepcopy(_config) 18 | pl.seed_everything(_config["seed"]) 19 | 20 | dm = MTDataModule(_config, dist=True) 21 | 22 | model = MMTransformerSS(_config) 23 | exp_name = f'{_config["exp_name"]}' 24 | 25 | os.makedirs(_config["log_dir"], exist_ok=True) 26 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 27 | save_top_k=1, 28 | verbose=True, 29 | monitor="val/the_metric", 30 | mode="max", 31 | save_last=True, 32 | ) 33 | logger = pl.loggers.TensorBoardLogger( 34 | _config["log_dir"], 35 | name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}', 36 | ) 37 | 38 | lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step") 39 | callbacks = [checkpoint_callback, lr_callback] 40 | 41 | num_gpus = ( 42 | _config["num_gpus"] 43 | if isinstance(_config["num_gpus"], int) 44 | else len(_config["num_gpus"]) 45 | ) 46 | 47 | grad_steps = max(_config["batch_size"] // ( 48 | _config["per_gpu_batchsize"] * num_gpus * _config["num_nodes"] 49 | ), 1) 50 | 51 | max_steps = _config["max_steps"] if _config["max_steps"] is not None else None 52 | 53 | trainer = pl.Trainer( 54 | gpus=_config["num_gpus"], 55 | num_nodes=_config["num_nodes"], 
56 | precision=_config["precision"], 57 | accelerator="ddp", 58 | benchmark=True, 59 | deterministic=True, 60 | max_epochs=_config["max_epoch"] if max_steps is None else 1000, 61 | max_steps=max_steps, 62 | callbacks=callbacks, 63 | logger=logger, 64 | prepare_data_per_node=False, 65 | replace_sampler_ddp=False, 66 | accumulate_grad_batches=grad_steps, 67 | log_every_n_steps=10, 68 | flush_logs_every_n_steps=10, 69 | resume_from_checkpoint=_config["resume_from"], 70 | weights_summary="top", 71 | fast_dev_run=_config["fast_dev_run"], 72 | val_check_interval=_config["val_check_interval"], 73 | ) 74 | 75 | if not _config["test_only"]: 76 | trainer.fit(model, datamodule=dm) 77 | else: 78 | trainer.test(model, datamodule=dm) 79 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKBUNLP/ExplainHM-WWW2024/5c8742286a416b687fe8ce8acba32fc026e91c3e/src/__init__.py -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | 3 | ex = Experiment("Meme", save_git_info=False) 4 | 5 | 6 | def _loss_names(d): 7 | ret = { 8 | "clf": 0, 9 | } 10 | ret.update(d) 11 | return ret 12 | 13 | 14 | @ex.config 15 | def config(): 16 | exp_name = "Meme" 17 | seed = 0 18 | datasets = ["meme"] 19 | loss_names = _loss_names({"clf": 1}) 20 | batch_size = 4096 # this is a desired batch size; pl trainer will accumulate gradients when per step batch is smaller. 21 | temperature = 0.05 22 | 23 | # Image setting 24 | train_transform_keys = ["vit"] 25 | val_transform_keys = ["vit"] 26 | image_size = 224 27 | patch_size = 16 28 | 29 | # Text Setting 30 | max_text_len = 40 31 | tokenizer = "t5-small" 32 | vocab_size = 32128 33 | whole_word_masking = False # note that whole_word_masking does not work for RoBERTa 34 | mlm_prob = 0.3 35 | 36 | # Transformer Setting 37 | input_image_embed_size = 768 38 | input_text_embed_size = 768 39 | vit = 'google/vit-base-patch32-224-in21k' 40 | hidden_size = 768 41 | num_heads = 12 42 | mlp_ratio = 4 43 | drop_rate = 0.1 44 | 45 | # Optimizer Setting 46 | optim_type = "adamw" 47 | learning_rate = 1e-5 48 | weight_decay = 0.01 49 | decay_power = 1 50 | max_epoch = 100 51 | max_steps = 100000 52 | warmup_steps = 10000 53 | end_lr = 0 54 | 55 | # PL Trainer Setting 56 | resume_from = None 57 | fast_dev_run = False 58 | val_check_interval = 1.0 59 | test_only = False 60 | get_recall_metric = False 61 | 62 | # below params varies with the environment 63 | data_root = "" 64 | log_dir = "result" 65 | per_gpu_batchsize = 0 # you should define this manually with per_gpu_batch_size=# 66 | num_gpus = 8 67 | num_nodes = 1 68 | load_path = "" 69 | num_workers = 8 70 | precision = 32 71 | # resume_from = "" 72 | 73 | @ex.named_config 74 | def task_train(): 75 | exp_name = "MEME" 76 | datasets = ["meme"] 77 | loss_names = _loss_names({ 78 | "clf": 1, 79 | }) 80 | batch_size = 256 81 | temperature = 0.05 82 | max_epoch = 30 83 | max_steps = None 84 | warmup_steps = 0.1 85 | whole_word_masking = False 86 | 87 | vocab_size = 32128 88 | max_text_len = 40 89 | image_size = 224 90 | tokenizer = "bert-base-uncased" 91 | train_transform_keys = ["vit"] 92 | val_transform_keys = ["vit"] 93 | learning_rate = 5e-5 94 | val_check_interval = 1.0 95 | hidden_size = 768 96 | num_heads = 12 97 | 98 | 99 | # 
visual encoder 100 | @ex.named_config 101 | def vit32_base224(): 102 | vit = "google/vit-base-patch32-224-in21k" 103 | patch_size = 32 104 | image_size = 224 105 | train_transform_keys = ["vit"] 106 | val_transform_keys = ["vit"] 107 | input_image_embed_size = 768 108 | 109 | @ex.named_config 110 | def vit16_base224(): 111 | vit = "google/vit-base-patch16-224-in21k" 112 | patch_size = 16 113 | image_size = 224 114 | train_transform_keys = ["vit"] 115 | val_transform_keys = ["vit"] 116 | input_image_embed_size = 768 117 | 118 | @ex.named_config 119 | def vit16_base384(): 120 | vit = "google/vit-base-patch16-384" 121 | patch_size = 16 122 | image_size = 384 123 | train_transform_keys = ["vit"] 124 | val_transform_keys = ["vit"] 125 | input_image_embed_size = 768 126 | 127 | @ex.named_config 128 | def clip32_base224(): 129 | vit = "openai/clip-vit-base-patch32" 130 | patch_size = 32 131 | image_size = 224 132 | train_transform_keys = ["vit"] 133 | val_transform_keys = ["vit"] 134 | input_image_embed_size = 768 135 | 136 | @ex.named_config 137 | def clip16_base224(): 138 | vit = "openai/clip-vit-base-patch16" 139 | patch_size = 16 140 | image_size = 224 141 | train_transform_keys = ["vit"] 142 | val_transform_keys = ["vit"] 143 | input_image_embed_size = 768 144 | 145 | # text encoder 146 | @ex.named_config 147 | def text_bert(): 148 | tokenizer = "bert-base-uncased" 149 | vocab_size = 30522 150 | input_text_embed_size = 768 151 | 152 | @ex.named_config 153 | def text_roberta_large(): 154 | tokenizer = "roberta-large" 155 | vocab_size = 50265 156 | input_text_embed_size = 1024 157 | 158 | # text encoder 159 | @ex.named_config 160 | def text_t5_small(): 161 | tokenizer = "google/flan-t5-small" 162 | vocab_size = 32128 163 | input_text_embed_size = 512 164 | 165 | @ex.named_config 166 | def text_t5_base(): 167 | tokenizer = "google/flan-t5-base" 168 | vocab_size = 32128 169 | input_text_embed_size = 768 170 | 171 | @ex.named_config 172 | def text_t5_large(): 173 | tokenizer = "google/flan-t5-large" 174 | vocab_size = 32128 175 | input_text_embed_size = 1024 176 | 177 | # random augmentation 178 | @ex.named_config 179 | def vit_randaug(): 180 | train_transform_keys = ["vit_randaug"] 181 | -------------------------------------------------------------------------------- /src/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | from .meme_datamodule import MemeDataModule 2 | 3 | _datamodules = { 4 | "meme": MemeDataModule, 5 | } 6 | -------------------------------------------------------------------------------- /src/datamodules/datamodule_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import DataLoader 5 | from transformers import ( 6 | DataCollatorForLanguageModeling, 7 | DataCollatorForWholeWordMask, 8 | AutoTokenizer, 9 | ) 10 | 11 | 12 | def get_pretrained_tokenizer(from_pretrained): 13 | if torch.distributed.is_initialized(): 14 | if torch.distributed.get_rank() == 0: 15 | AutoTokenizer.from_pretrained(from_pretrained) 16 | torch.distributed.barrier() 17 | 18 | return AutoTokenizer.from_pretrained(from_pretrained) 19 | 20 | 21 | class BaseDataModule(LightningDataModule): 22 | def __init__(self, _config): 23 | super().__init__() 24 | 25 | self.data_dir = _config["data_root"] 26 | 27 | self.num_workers = _config["num_workers"] 28 | self.batch_size = _config["per_gpu_batchsize"] 29 | 
self.eval_batch_size = self.batch_size 30 | 31 | self.image_size = _config["image_size"] 32 | self.patch_size = _config["patch_size"] 33 | self.max_text_len = _config["max_text_len"] 34 | 35 | self.train_transform_keys = ( 36 | ["default_train"] 37 | if len(_config["train_transform_keys"]) == 0 38 | else _config["train_transform_keys"] 39 | ) 40 | 41 | self.val_transform_keys = ( 42 | ["default_val"] 43 | if len(_config["val_transform_keys"]) == 0 44 | else _config["val_transform_keys"] 45 | ) 46 | 47 | tokenizer = _config["tokenizer"] 48 | self.tokenizer = get_pretrained_tokenizer(tokenizer) 49 | self.vocab_size = self.tokenizer.vocab_size 50 | 51 | collator = ( 52 | DataCollatorForWholeWordMask 53 | if _config["whole_word_masking"] 54 | else DataCollatorForLanguageModeling 55 | ) 56 | 57 | self.mlm_collator = collator( 58 | tokenizer=self.tokenizer, mlm=False, mlm_probability=_config["mlm_prob"] 59 | ) 60 | self.setup_flag = False 61 | 62 | @property 63 | def dataset_cls(self): 64 | raise NotImplementedError("return tuple of dataset class") 65 | 66 | @property 67 | def dataset_name(self): 68 | raise NotImplementedError("return name of dataset") 69 | 70 | def set_train_dataset(self): 71 | self.train_dataset = self.dataset_cls( 72 | data_dir=self.data_dir, 73 | transform_keys=self.train_transform_keys, 74 | split="train", 75 | image_size=self.image_size, 76 | patch_size=self.patch_size, 77 | max_text_len=self.max_text_len, 78 | tokenizer=self.tokenizer, 79 | ) 80 | 81 | def set_val_dataset(self): 82 | self.val_dataset = self.dataset_cls( 83 | data_dir=self.data_dir, 84 | transform_keys=self.val_transform_keys, 85 | split="val", 86 | image_size=self.image_size, 87 | patch_size=self.patch_size, 88 | max_text_len=self.max_text_len, 89 | tokenizer=self.tokenizer, 90 | ) 91 | 92 | def set_test_dataset(self): 93 | self.test_dataset = self.dataset_cls( 94 | data_dir=self.data_dir, 95 | transform_keys=self.val_transform_keys, 96 | split="test", 97 | image_size=self.image_size, 98 | patch_size=self.patch_size, 99 | max_text_len=self.max_text_len, 100 | tokenizer=self.tokenizer, 101 | ) 102 | 103 | def make_val_dset(self, image_only=False): 104 | return self.dataset_cls( 105 | data_dir=self.data_dir, 106 | transform_keys=self.val_transform_keys, 107 | split="test", 108 | image_size=self.image_size, 109 | patch_size=self.patch_size, 110 | max_text_len=self.max_text_len, 111 | image_only=image_only, 112 | tokenizer=self.tokenizer, 113 | ) 114 | 115 | def setup(self, stage): 116 | if not self.setup_flag: 117 | self.set_train_dataset() 118 | self.set_val_dataset() 119 | self.set_test_dataset() 120 | 121 | self.train_dataset.tokenizer = self.tokenizer 122 | self.val_dataset.tokenizer = self.tokenizer 123 | self.test_dataset.tokenizer = self.tokenizer 124 | 125 | self.setup_flag = True 126 | 127 | def train_dataloader(self): 128 | loader = DataLoader( 129 | self.train_dataset, 130 | batch_size=self.batch_size, 131 | shuffle=True, 132 | num_workers=self.num_workers, 133 | pin_memory=True, 134 | collate_fn=self.train_dataset.collate, 135 | drop_last=False, 136 | ) 137 | return loader 138 | 139 | def val_dataloader(self): 140 | loader = DataLoader( 141 | self.val_dataset, 142 | batch_size=self.eval_batch_size, 143 | shuffle=False, 144 | num_workers=self.num_workers, 145 | pin_memory=True, 146 | collate_fn=self.val_dataset.collate, 147 | drop_last=False 148 | ) 149 | return loader 150 | 151 | def test_dataloader(self): 152 | loader = DataLoader( 153 | self.test_dataset, 154 | batch_size=self.eval_batch_size, 
155 | shuffle=False, 156 | num_workers=self.num_workers, 157 | pin_memory=True, 158 | collate_fn=self.test_dataset.collate, 159 | drop_last=False 160 | ) 161 | return loader 162 | -------------------------------------------------------------------------------- /src/datamodules/meme_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import MemeDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class MemeDataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return MemeDataset 12 | 13 | @property 14 | def dataset_cls_no_false(self): 15 | return MemeDataset 16 | 17 | @property 18 | def dataset_name(self): 19 | return "meme" 20 | -------------------------------------------------------------------------------- /src/datamodules/multitask_datamodule.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import DataLoader 5 | from torch.utils.data.dataset import ConcatDataset 6 | from torch.utils.data.distributed import DistributedSampler 7 | 8 | from . import _datamodules 9 | 10 | 11 | class MTDataModule(LightningDataModule): 12 | def __init__(self, _config, dist=False): 13 | datamodule_keys = _config["datasets"] 14 | assert len(datamodule_keys) > 0 15 | 16 | super().__init__() 17 | 18 | self.dm_keys = datamodule_keys 19 | self.dm_dicts = {key: _datamodules[key](_config) for key in datamodule_keys} 20 | self.dms = [v for k, v in self.dm_dicts.items()] 21 | 22 | self.batch_size = self.dms[0].batch_size 23 | self.vocab_size = self.dms[0].vocab_size 24 | self.num_workers = self.dms[0].num_workers 25 | 26 | self.dist = dist 27 | 28 | def prepare_data(self): 29 | for dm in self.dms: 30 | dm.prepare_data() 31 | 32 | def setup(self, stage): 33 | for dm in self.dms: 34 | dm.setup(stage) 35 | 36 | self.train_dataset = ConcatDataset([dm.train_dataset for dm in self.dms]) 37 | self.val_dataset = ConcatDataset([dm.val_dataset for dm in self.dms]) 38 | self.test_dataset = ConcatDataset([dm.test_dataset for dm in self.dms]) 39 | self.tokenizer = self.dms[0].tokenizer 40 | 41 | self.collate = functools.partial( 42 | self.dms[0].train_dataset.collate, mlm_collator=self.dms[0].mlm_collator, 43 | ) 44 | 45 | if self.dist: 46 | self.train_sampler = DistributedSampler(self.train_dataset, shuffle=True) 47 | self.val_sampler = DistributedSampler(self.val_dataset, shuffle=True) 48 | self.test_sampler = DistributedSampler(self.test_dataset, shuffle=False) 49 | else: 50 | self.train_sampler = None 51 | self.val_sampler = None 52 | self.test_sampler = None 53 | 54 | def train_dataloader(self): 55 | loader = DataLoader( 56 | self.train_dataset, 57 | batch_size=self.batch_size, 58 | sampler=self.train_sampler, 59 | num_workers=self.num_workers, 60 | collate_fn=self.collate, 61 | ) 62 | return loader 63 | 64 | def val_dataloader(self, batch_size=None): 65 | loader = DataLoader( 66 | self.val_dataset, 67 | batch_size=batch_size if batch_size is not None else self.batch_size, 68 | sampler=self.val_sampler, 69 | num_workers=self.num_workers, 70 | collate_fn=self.collate, 71 | ) 72 | return loader 73 | 74 | def test_dataloader(self): 75 | loader = DataLoader( 76 | self.test_dataset, 77 | batch_size=self.batch_size, 78 | sampler=self.test_sampler, 79 | num_workers=self.num_workers, 80 | collate_fn=self.collate, 81 | 
) 82 | return loader 83 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .meme import MemeDataset 2 | -------------------------------------------------------------------------------- /src/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import io 4 | import pandas as pd 5 | import os 6 | import jsonlines 7 | 8 | from PIL import Image 9 | from ..transforms import keys_to_transforms 10 | 11 | def jsonl_reader(path): 12 | data = [] 13 | with jsonlines.open(path) as reader: 14 | for obj in reader: 15 | data.append(obj) 16 | return data 17 | 18 | class JsonDataset(torch.utils.data.Dataset): 19 | def __init__( 20 | self, 21 | data_dir, 22 | input_filename, 23 | transform_keys, 24 | image_size, 25 | patch_size, 26 | img_key, 27 | text_key, 28 | label_key, 29 | tokenizer=None, 30 | max_text_len=50, 31 | image_only=False, 32 | ): 33 | """ 34 | data_dir : where dataset file *.arrow lives; existence should be guaranteed via DataModule.prepare_data 35 | transform_keys : keys for generating augmented views of images 36 | """ 37 | assert len(transform_keys) >= 1 38 | super().__init__() 39 | self.data_dir = data_dir 40 | self.image_only = image_only 41 | self.data = jsonl_reader(f"{data_dir}/{input_filename}") 42 | self.img_key = img_key 43 | self.text_key = text_key 44 | self.label_key = label_key 45 | self.transforms = keys_to_transforms(transform_keys, size=image_size) 46 | self.max_text_len = max_text_len 47 | self.image_size = image_size 48 | self.patch_size = patch_size 49 | self.tokenizer = tokenizer 50 | 51 | def __len__(self): 52 | return len(self.data) if self.data else 0 53 | 54 | def get_image(self, idx): 55 | image_features = self.transforms[0](Image.open(f"{self.data_dir}/images/{str(self.data[idx][self.img_key])}")).unsqueeze(0) 56 | return { 57 | "image_features": image_features, # [1, 3, H, W] 58 | "raw_index": idx, 59 | "img_path": f"{self.data_dir}/images/{str(self.data[idx][self.img_key])}", 60 | "img_index": self.data[idx]["id"], 61 | } 62 | 63 | def get_text(self, idx): 64 | text = str(self.data[idx][self.text_key]).lower() 65 | encoding = self.tokenizer( 66 | text, 67 | padding="max_length", 68 | truncation=True, 69 | max_length=self.max_text_len, 70 | return_special_tokens_mask=True, 71 | ) 72 | return { 73 | "text": (text, encoding), 74 | "raw_index": idx, 75 | } 76 | 77 | def get_label(self, idx): 78 | label = self.data[idx][self.label_key][0] 79 | return { 80 | "label": label, 81 | "raw_index": idx, 82 | } 83 | 84 | def get_suite(self, idx): 85 | result = None 86 | while result is None: 87 | try: 88 | ret = dict() 89 | ret.update(self.get_image(idx)) 90 | if not self.image_only: 91 | ret.update(self.get_text(idx)) 92 | ret.update(self.get_label(idx)) 93 | result = True 94 | except Exception as e: 95 | print(f"Error while read file idx {idx} in {self.data_dir}/{str(self.data[idx][self.img_key])} -> {e}") 96 | idx = random.randint(0, len(self.data) - 1) 97 | 98 | return ret 99 | 100 | def collate(self, batch, mlm_collator): 101 | batch_size = len(batch) 102 | keys = set([key for b in batch for key in b.keys()]) 103 | dict_batch = {k: [dic[k] if k in dic else None for dic in batch] for k in keys} 104 | 105 | batch_image_features = torch.cat(dict_batch["image_features"], dim=0) # [bs, 3, H, W] 106 | dict_batch["image_features"] = 
batch_image_features 107 | batch_labels = torch.LongTensor(dict_batch["label"]) 108 | dict_batch["labels"] = batch_labels 109 | 110 | txt_keys = [k for k in list(dict_batch.keys()) if "text" in k] 111 | 112 | if len(txt_keys) != 0: 113 | texts = [[d[0] for d in dict_batch[txt_key]] for txt_key in txt_keys] 114 | encodings = [[d[1] for d in dict_batch[txt_key]] for txt_key in txt_keys] 115 | flatten_encodings = [e for encoding in encodings for e in encoding] 116 | 117 | # Prepare for text encoder 118 | flatten_mlms = mlm_collator(flatten_encodings) 119 | 120 | for i, txt_key in enumerate(txt_keys): 121 | texts, encodings = ( 122 | [d[0] for d in dict_batch[txt_key]], 123 | [d[1] for d in dict_batch[txt_key]], 124 | ) 125 | 126 | mlm_ids, mlm_labels = ( 127 | flatten_mlms["input_ids"][batch_size * (i) : batch_size * (i + 1)], 128 | flatten_mlms["labels"][batch_size * (i) : batch_size * (i + 1)], 129 | ) 130 | 131 | input_ids = torch.zeros_like(mlm_ids) 132 | attention_mask = torch.zeros_like(mlm_ids) 133 | for _i, encoding in enumerate(encodings): 134 | _input_ids, _attention_mask = ( 135 | torch.tensor(encoding["input_ids"]), 136 | torch.tensor(encoding["attention_mask"]), 137 | ) 138 | input_ids[_i, : len(_input_ids)] = _input_ids 139 | attention_mask[_i, : len(_attention_mask)] = _attention_mask 140 | 141 | dict_batch[txt_key] = texts 142 | dict_batch[f"{txt_key}_ids"] = input_ids 143 | dict_batch[f"{txt_key}_masks"] = attention_mask 144 | return dict_batch 145 | -------------------------------------------------------------------------------- /src/datasets/meme.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import JsonDataset 2 | import io 3 | from PIL import Image 4 | 5 | class MemeDataset(JsonDataset): 6 | def __init__(self, *args, split="", **kwargs): 7 | assert split in ["train", "val", "test"] 8 | self.split = split 9 | 10 | if split == "train": 11 | input_filename = "train.jsonl" 12 | elif split == "val": 13 | input_filename = "val.jsonl" 14 | elif split == "test": 15 | input_filename = "test.jsonl" 16 | 17 | img_key = "image" 18 | text_key = "text" 19 | label_key = "labels" 20 | 21 | super().__init__( 22 | *args, 23 | **kwargs, 24 | input_filename=input_filename, 25 | img_key=img_key, 26 | text_key=text_key, 27 | label_key=label_key, 28 | ) 29 | 30 | 31 | def __getitem__(self, index): 32 | suite = self.get_suite(index) 33 | return suite 34 | -------------------------------------------------------------------------------- /src/gadgets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKBUNLP/ExplainHM-WWW2024/5c8742286a416b687fe8ce8acba32fc026e91c3e/src/gadgets/__init__.py -------------------------------------------------------------------------------- /src/gadgets/my_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_lightning.metrics import Metric 3 | 4 | 5 | class Accuracy(Metric): 6 | def __init__(self, dist_sync_on_step=False): 7 | super().__init__(dist_sync_on_step=dist_sync_on_step) 8 | self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum") 9 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 10 | 11 | def update(self, logits, target): 12 | logits, target = ( 13 | logits.detach().to(self.correct.device), 14 | target.detach().to(self.correct.device), 15 | ) 16 | preds = logits.argmax(dim=-1) 17 | preds = preds[target != 
-100] 18 | target = target[target != -100] 19 | if target.numel() == 0: 20 | return 1 21 | 22 | assert preds.shape == target.shape 23 | 24 | self.correct += torch.sum(preds == target) 25 | self.total += target.numel() 26 | 27 | def compute(self): 28 | return self.correct / self.total 29 | 30 | 31 | class F1(Metric): 32 | def __init__(self, dist_sync_on_step=False): 33 | super().__init__(dist_sync_on_step=dist_sync_on_step) 34 | self.add_state("tp", default=torch.tensor(0.0), dist_reduce_fx="sum") 35 | self.add_state("fp", default=torch.tensor(0.0), dist_reduce_fx="sum") 36 | self.add_state("fn", default=torch.tensor(0.0), dist_reduce_fx="sum") 37 | self.add_state("tn", default=torch.tensor(0.0), dist_reduce_fx="sum") 38 | 39 | def update(self, logits, target): 40 | logits, target = ( 41 | logits.detach().to(self.tp.device), 42 | target.detach().to(self.tp.device), 43 | ) 44 | preds = logits.argmax(dim=-1) 45 | preds = preds[target != -100] 46 | target = target[target != -100] 47 | if target.numel() == 0: 48 | return 1 49 | 50 | assert preds.shape == target.shape 51 | 52 | for pred, tar in zip(preds.view(-1), target.view(-1)): 53 | if pred == 1 and tar == 1: 54 | self.tp += 1 55 | elif pred == 1 and tar == 0: 56 | self.fp += 1 57 | elif pred == 0 and tar == 1: 58 | self.fn += 1 59 | elif pred == 0 and tar == 0: 60 | self.tn += 1 61 | 62 | def compute(self): 63 | P = self.tp / (self.tp + self.fp) 64 | R = self.tp / (self.tp + self.fn) 65 | f1 = 2 * P * R / (P + R) 66 | 67 | P = self.tn / (self.tn + self.fn) 68 | R = self.tn / (self.tn + self.fp) 69 | f2 = 2 * P * R / (P + R) 70 | return (f1 + f2) / 2 71 | 72 | 73 | class Scalar(Metric): 74 | def __init__(self, dist_sync_on_step=False): 75 | super().__init__(dist_sync_on_step=dist_sync_on_step) 76 | self.add_state("scalar", default=torch.tensor(0.0), dist_reduce_fx="sum") 77 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 78 | 79 | def update(self, scalar): 80 | if isinstance(scalar, torch.Tensor): 81 | scalar = scalar.detach().to(self.scalar.device) 82 | else: 83 | scalar = torch.tensor(scalar).float().to(self.scalar.device) 84 | self.scalar += scalar 85 | self.total += 1 86 | 87 | def compute(self): 88 | return self.scalar / self.total 89 | 90 | 91 | class VQAScore(Metric): 92 | def __init__(self, dist_sync_on_step=False): 93 | super().__init__(dist_sync_on_step=dist_sync_on_step) 94 | self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum") 95 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 96 | 97 | def update(self, logits, target): 98 | logits, target = ( 99 | logits.detach().float().to(self.score.device), 100 | target.detach().float().to(self.score.device), 101 | ) 102 | logits = torch.max(logits, 1)[1] 103 | one_hots = torch.zeros(*target.size()).to(target) 104 | one_hots.scatter_(1, logits.view(-1, 1), 1) 105 | scores = one_hots * target 106 | 107 | self.score += scores.sum() 108 | self.total += len(logits) 109 | 110 | def compute(self): 111 | return self.score / self.total 112 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .mm_module import MMTransformerSS -------------------------------------------------------------------------------- /src/modules/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | """ 3 | This file contains primitives for multi-gpu communication. 4 | This is useful when doing distributed training. 5 | """ 6 | 7 | import functools 8 | import logging 9 | import numpy as np 10 | import pickle 11 | import torch 12 | import torch.distributed as dist 13 | 14 | import torch 15 | 16 | _LOCAL_PROCESS_GROUP = None 17 | """ 18 | A torch process group which only includes processes that on the same machine as the current process. 19 | This variable is set when processes are spawned by `launch()` in "engine/launch.py". 20 | """ 21 | 22 | 23 | def get_world_size() -> int: 24 | if not dist.is_available(): 25 | return 1 26 | if not dist.is_initialized(): 27 | return 1 28 | return dist.get_world_size() 29 | 30 | 31 | def get_rank() -> int: 32 | if not dist.is_available(): 33 | return 0 34 | if not dist.is_initialized(): 35 | return 0 36 | return dist.get_rank() 37 | 38 | 39 | def get_local_rank() -> int: 40 | """ 41 | Returns: 42 | The rank of the current process within the local (per-machine) process group. 43 | """ 44 | if not dist.is_available(): 45 | return 0 46 | if not dist.is_initialized(): 47 | return 0 48 | assert _LOCAL_PROCESS_GROUP is not None 49 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 50 | 51 | 52 | def get_local_size() -> int: 53 | """ 54 | Returns: 55 | The size of the per-machine process group, 56 | i.e. the number of processes per machine. 57 | """ 58 | if not dist.is_available(): 59 | return 1 60 | if not dist.is_initialized(): 61 | return 1 62 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 63 | 64 | 65 | def is_main_process() -> bool: 66 | return get_rank() == 0 67 | 68 | 69 | def synchronize(): 70 | """ 71 | Helper function to synchronize (barrier) among all processes when 72 | using distributed training 73 | """ 74 | if not dist.is_available(): 75 | return 76 | if not dist.is_initialized(): 77 | return 78 | world_size = dist.get_world_size() 79 | if world_size == 1: 80 | return 81 | dist.barrier() 82 | 83 | 84 | @functools.lru_cache() 85 | def _get_global_gloo_group(): 86 | """ 87 | Return a process group based on gloo backend, containing all the ranks 88 | The result is cached. 89 | """ 90 | if dist.get_backend() == "nccl": 91 | return dist.new_group(backend="gloo") 92 | else: 93 | return dist.group.WORLD 94 | 95 | 96 | def _serialize_to_tensor(data, group): 97 | backend = dist.get_backend(group) 98 | assert backend in ["gloo", "nccl"] 99 | device = torch.device("cpu" if backend == "gloo" else "cuda") 100 | 101 | buffer = pickle.dumps(data) 102 | if len(buffer) > 1024 ** 3: 103 | logger = logging.getLogger(__name__) 104 | logger.warning( 105 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 106 | get_rank(), len(buffer) / (1024 ** 3), device 107 | ) 108 | ) 109 | storage = torch.ByteStorage.from_buffer(buffer) 110 | tensor = torch.ByteTensor(storage).to(device=device) 111 | return tensor 112 | 113 | 114 | def _pad_to_largest_tensor(tensor, group): 115 | """ 116 | Returns: 117 | list[int]: size of the tensor, on each rank 118 | Tensor: padded tensor that has the max size 119 | """ 120 | world_size = dist.get_world_size(group=group) 121 | assert ( 122 | world_size >= 1 123 | ), "comm.gather/all_gather must be called from ranks within the given group!" 
124 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 125 | size_list = [ 126 | torch.zeros([1], dtype=torch.int64, device=tensor.device) 127 | for _ in range(world_size) 128 | ] 129 | dist.all_gather(size_list, local_size, group=group) 130 | size_list = [int(size.item()) for size in size_list] 131 | 132 | max_size = max(size_list) 133 | 134 | # we pad the tensor because torch all_gather does not support 135 | # gathering tensors of different shapes 136 | if local_size != max_size: 137 | padding = torch.zeros( 138 | (max_size - local_size,), dtype=torch.uint8, device=tensor.device 139 | ) 140 | tensor = torch.cat((tensor, padding), dim=0) 141 | return size_list, tensor 142 | 143 | 144 | def all_gather(data, group=None): 145 | """ 146 | Run all_gather on arbitrary picklable data (not necessarily tensors). 147 | 148 | Args: 149 | data: any picklable object 150 | group: a torch process group. By default, will use a group which 151 | contains all ranks on gloo backend. 152 | 153 | Returns: 154 | list[data]: list of data gathered from each rank 155 | """ 156 | if get_world_size() == 1: 157 | return [data] 158 | if group is None: 159 | group = _get_global_gloo_group() 160 | if dist.get_world_size(group) == 1: 161 | return [data] 162 | 163 | tensor = _serialize_to_tensor(data, group) 164 | 165 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 166 | max_size = max(size_list) 167 | 168 | # receiving Tensor from all ranks 169 | tensor_list = [ 170 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 171 | for _ in size_list 172 | ] 173 | dist.all_gather(tensor_list, tensor, group=group) 174 | 175 | data_list = [] 176 | for size, tensor in zip(size_list, tensor_list): 177 | buffer = tensor.cpu().numpy().tobytes()[:size] 178 | data_list.append(pickle.loads(buffer)) 179 | 180 | return data_list 181 | 182 | 183 | def gather(data, dst=0, group=None): 184 | """ 185 | Run gather on arbitrary picklable data (not necessarily tensors). 186 | 187 | Args: 188 | data: any picklable object 189 | dst (int): destination rank 190 | group: a torch process group. By default, will use a group which 191 | contains all ranks on gloo backend. 192 | 193 | Returns: 194 | list[data]: on dst, a list of data gathered from each rank. Otherwise, 195 | an empty list. 196 | """ 197 | if get_world_size() == 1: 198 | return [data] 199 | if group is None: 200 | group = _get_global_gloo_group() 201 | if dist.get_world_size(group=group) == 1: 202 | return [data] 203 | rank = dist.get_rank(group=group) 204 | 205 | tensor = _serialize_to_tensor(data, group) 206 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 207 | 208 | # receiving Tensor from all ranks 209 | if rank == dst: 210 | max_size = max(size_list) 211 | tensor_list = [ 212 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 213 | for _ in size_list 214 | ] 215 | dist.gather(tensor, tensor_list, dst=dst, group=group) 216 | 217 | data_list = [] 218 | for size, tensor in zip(size_list, tensor_list): 219 | buffer = tensor.cpu().numpy().tobytes()[:size] 220 | data_list.append(pickle.loads(buffer)) 221 | return data_list 222 | else: 223 | dist.gather(tensor, [], dst=dst, group=group) 224 | return [] 225 | 226 | 227 | def shared_random_seed(): 228 | """ 229 | Returns: 230 | int: a random number that is the same across all workers. 231 | If workers need a shared RNG, they can use this shared seed to 232 | create one. 233 | 234 | All workers must call this function, otherwise it will deadlock. 
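Example (illustrative): rng = np.random.RandomState(shared_random_seed()) builds an identically seeded RNG on every worker.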
235 | """ 236 | ints = np.random.randint(2 ** 31) 237 | all_ints = all_gather(ints) 238 | return all_ints[0] 239 | 240 | 241 | def reduce_dict(input_dict, average=True): 242 | """ 243 | Reduce the values in the dictionary from all processes so that process with rank 244 | 0 has the reduced results. 245 | 246 | Args: 247 | input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. 248 | average (bool): whether to do average or sum 249 | 250 | Returns: 251 | a dict with the same keys as input_dict, after reduction. 252 | """ 253 | world_size = get_world_size() 254 | if world_size < 2: 255 | return input_dict 256 | with torch.no_grad(): 257 | names = [] 258 | values = [] 259 | # sort the keys so that they are consistent across processes 260 | for k in sorted(input_dict.keys()): 261 | names.append(k) 262 | values.append(input_dict[k]) 263 | values = torch.stack(values, dim=0) 264 | dist.reduce(values, dst=0) 265 | if dist.get_rank() == 0 and average: 266 | # only main process gets accumulated, so only divide by 267 | # world_size in this case 268 | values /= world_size 269 | reduced_dict = {k: v for k, v in zip(names, values)} 270 | return reduced_dict 271 | -------------------------------------------------------------------------------- /src/modules/mm_module.py: -------------------------------------------------------------------------------- 1 | from email.errors import NonPrintableDefect 2 | import torch 3 | import torch.nn as nn 4 | import pytorch_lightning as pl 5 | import numpy as np 6 | import random 7 | import json 8 | import jsonlines 9 | 10 | from torch import distributed as dist 11 | from transformers import CLIPVisionModel, T5Tokenizer 12 | 13 | from . import mm_utils 14 | from . import objectives 15 | from . import dist_utils 16 | from .t5_model import T5ForMultimodalGeneration 17 | 18 | torch.backends.cudnn.enabled = False 19 | 20 | class MMTransformerSS(pl.LightningModule): 21 | def __init__(self, config): 22 | super().__init__() 23 | self.save_hyperparameters() 24 | #self.mode = self.hparams.config["mode"] 25 | 26 | if torch.distributed.is_initialized(): 27 | if torch.distributed.get_rank() == 0: 28 | CLIPVisionModel.from_pretrained(config["vit"]) 29 | T5ForMultimodalGeneration.from_pretrained(config['tokenizer']) 30 | torch.distributed.barrier() 31 | 32 | ##################################################################################### 33 | self.image_transformer = CLIPVisionModel.from_pretrained(config["vit"]) 34 | self.text_transformer = T5ForMultimodalGeneration.from_pretrained( 35 | config['tokenizer'], 36 | img_hsz=config["input_image_embed_size"], 37 | ) 38 | self.clf = nn.Linear(config["input_text_embed_size"], 2) 39 | ##################################################################################### 40 | for param in self.image_transformer.parameters(): 41 | param.requires_grad = False 42 | 43 | mm_utils.set_metrics(self) 44 | self.current_tasks = list() 45 | 46 | # ===================== load model ====================== 47 | if self.hparams.config["load_path"] != "": 48 | ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu") 49 | state_dict = ckpt["state_dict"] 50 | self.load_state_dict(state_dict, strict=False) 51 | 52 | self.pred_result = {} 53 | 54 | def encode_image( 55 | self, 56 | image_features, 57 | ): 58 | last_hidden_state = self.image_transformer( 59 | pixel_values=image_features, 60 | ).last_hidden_state 61 | return last_hidden_state 62 | 63 | def infer( 64 | self, 65 | batch, 66 | ): 67 | text_ids = 
batch[f"text_ids"] 68 | label_ids = batch[f"labels"] 69 | text_masks = batch[f"text_masks"] 70 | image_features = batch[f"image_features"] 71 | 72 | image_features = self.encode_image(image_features) 73 | text_outputs = self.text_transformer( 74 | input_ids=text_ids, 75 | attention_mask=text_masks, 76 | image_ids=image_features, 77 | ) 78 | text_outputs = torch.sum(text_outputs * text_masks.unsqueeze(-1), dim=1) / torch.sum(text_masks, dim=1).unsqueeze(-1) 79 | logits = self.clf(text_outputs) 80 | 81 | ret = { 82 | "logits": logits, 83 | "label_ids": label_ids, 84 | } 85 | 86 | return ret 87 | 88 | def forward(self, batch): 89 | ret = dict() 90 | ret.update(self.infer(batch)) 91 | ret.update(objectives.compute_clf(self, ret)) 92 | return ret 93 | 94 | def training_step(self, batch, batch_idx): 95 | mm_utils.set_task(self) 96 | output = self(batch) 97 | total_loss = sum([v for k, v in output.items() if "loss" in k]) 98 | 99 | return total_loss 100 | 101 | def training_epoch_end(self, outs): 102 | mm_utils.epoch_wrapup(self) 103 | 104 | def validation_step(self, batch, batch_idx): 105 | mm_utils.set_task(self) 106 | output = self(batch) 107 | 108 | def validation_epoch_end(self, outs): 109 | mm_utils.epoch_wrapup(self) 110 | 111 | def test_step(self, batch, batch_idx): 112 | mm_utils.set_task(self) 113 | output = self(batch) 114 | 115 | def test_epoch_end(self, outs): 116 | mm_utils.epoch_wrapup(self) 117 | 118 | def configure_optimizers(self): 119 | return mm_utils.set_schedule(self) 120 | -------------------------------------------------------------------------------- /src/modules/mm_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | from transformers.optimization import AdamW 5 | from transformers import ( 6 | get_polynomial_decay_schedule_with_warmup, 7 | get_cosine_schedule_with_warmup, 8 | ) 9 | from .dist_utils import all_gather 10 | from ..gadgets.my_metrics import Accuracy, VQAScore, Scalar, F1 11 | 12 | 13 | def set_metrics(pl_module): 14 | for split in ["train", "val", "test"]: 15 | for k, v in pl_module.hparams.config["loss_names"].items(): 16 | if v < 1: 17 | continue 18 | setattr(pl_module, f"{split}_{k}_loss", Scalar()) 19 | setattr(pl_module, f"{split}_{k}_accuracy", Accuracy()) 20 | setattr(pl_module, f"{split}_{k}_f1", F1()) 21 | 22 | 23 | def epoch_wrapup(pl_module): 24 | phase = "train" if pl_module.training else "val" 25 | the_metric = 0 26 | 27 | for loss_name, v in pl_module.hparams.config["loss_names"].items(): 28 | if v < 1: 29 | continue 30 | 31 | if phase == "train": 32 | value = getattr(pl_module, f"train_{loss_name}_accuracy").compute() 33 | pl_module.log(f"{loss_name}/train/accuracy_epoch", value) 34 | getattr(pl_module, f"train_{loss_name}_accuracy").reset() 35 | pl_module.log( 36 | f"{loss_name}/train/f1_epoch", 37 | getattr(pl_module, f"train_{loss_name}_f1").compute() 38 | ) 39 | getattr(pl_module, f"train_{loss_name}_f1").reset() 40 | pl_module.log( 41 | f"{loss_name}/train/loss_epoch", 42 | getattr(pl_module, f"train_{loss_name}_loss").compute(), 43 | ) 44 | getattr(pl_module, f"train_{loss_name}_loss").reset() 45 | else: 46 | value = getattr(pl_module, f"test_{loss_name}_accuracy").compute() 47 | pl_module.log(f"{loss_name}/test/accuracy_epoch", value) 48 | getattr(pl_module, f"test_{loss_name}_accuracy").reset() 49 | pl_module.log( 50 | f"{loss_name}/test/f1_epoch", 51 | getattr(pl_module, f"test_{loss_name}_f1").compute() 52 | ) 53 | getattr(pl_module, 
f"test_{loss_name}_f1").reset() 54 | pl_module.log( 55 | f"{loss_name}/test/loss_epoch", 56 | getattr(pl_module, f"test_{loss_name}_loss").compute(), 57 | ) 58 | getattr(pl_module, f"test_{loss_name}_loss").reset() 59 | 60 | value = getattr(pl_module, f"val_{loss_name}_accuracy").compute() 61 | pl_module.log(f"{loss_name}/val/accuracy_epoch", value) 62 | getattr(pl_module, f"val_{loss_name}_accuracy").reset() 63 | pl_module.log( 64 | f"{loss_name}/val/f1_epoch", 65 | getattr(pl_module, f"val_{loss_name}_f1").compute() 66 | ) 67 | getattr(pl_module, f"val_{loss_name}_f1").reset() 68 | pl_module.log( 69 | f"{loss_name}/val/loss_epoch", 70 | getattr(pl_module, f"val_{loss_name}_loss").compute(), 71 | ) 72 | getattr(pl_module, f"val_{loss_name}_loss").reset() 73 | the_metric = the_metric + value 74 | 75 | pl_module.log(f"{phase}/the_metric", the_metric) 76 | 77 | 78 | def check_non_acc_grad(pl_module): 79 | if pl_module.token_type_embeddings.weight.grad is None: 80 | return True 81 | else: 82 | grad = pl_module.token_type_embeddings.weight.grad 83 | return (grad.sum() == 0).item() 84 | 85 | 86 | def set_task(pl_module): 87 | pl_module.current_tasks = [ 88 | k for k, v in pl_module.hparams.config["loss_names"].items() if v >= 1 89 | ] 90 | return 91 | 92 | def set_schedule(pl_module): 93 | lr = pl_module.hparams.config["learning_rate"] 94 | wd = pl_module.hparams.config["weight_decay"] 95 | 96 | no_decay = [ 97 | "bias", 98 | "LayerNorm.bias", 99 | "LayerNorm.weight", 100 | "norm.bias", 101 | "norm.weight", 102 | "norm1.bias", 103 | "norm1.weight", 104 | "norm2.bias", 105 | "norm2.weight", 106 | ] 107 | end_lr = pl_module.hparams.config["end_lr"] 108 | decay_power = pl_module.hparams.config["decay_power"] 109 | optim_type = pl_module.hparams.config["optim_type"] 110 | optimizer_grouped_parameters = [ 111 | { 112 | "params": [ 113 | p 114 | for n, p in pl_module.named_parameters() 115 | if not any(nd in n for nd in no_decay) 116 | ], 117 | "weight_decay": wd, 118 | "lr": lr, 119 | }, 120 | { 121 | "params": [ 122 | p 123 | for n, p in pl_module.named_parameters() 124 | if any(nd in n for nd in no_decay) 125 | ], 126 | "weight_decay": 0.0, 127 | "lr": lr, 128 | }, 129 | ] 130 | 131 | if optim_type == "adamw": 132 | optimizer = AdamW( 133 | optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.999) 134 | ) 135 | elif optim_type == "adam": 136 | optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=lr) 137 | elif optim_type == "sgd": 138 | optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=lr, momentum=0.9) 139 | 140 | if pl_module.trainer.max_steps is None: 141 | max_steps = ( 142 | len(pl_module.trainer.datamodule.train_dataloader()) 143 | * pl_module.trainer.max_epochs 144 | // pl_module.trainer.accumulate_grad_batches 145 | ) 146 | else: 147 | max_steps = pl_module.trainer.max_steps 148 | 149 | warmup_steps = pl_module.hparams.config["warmup_steps"] 150 | if isinstance(pl_module.hparams.config["warmup_steps"], float): 151 | warmup_steps = int(max_steps * warmup_steps) 152 | 153 | if decay_power == "cosine": 154 | scheduler = get_cosine_schedule_with_warmup( 155 | optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps, 156 | ) 157 | else: 158 | scheduler = get_polynomial_decay_schedule_with_warmup( 159 | optimizer, 160 | num_warmup_steps=warmup_steps, 161 | num_training_steps=max_steps, 162 | lr_end=end_lr, 163 | power=decay_power, 164 | ) 165 | 166 | sched = {"scheduler": scheduler, "interval": "step"} 167 | 168 | return ( 169 | [optimizer], 170 | 
[sched], 171 | ) 172 | -------------------------------------------------------------------------------- /src/modules/objectives.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import os 5 | import glob 6 | import json 7 | import tqdm 8 | import functools 9 | import numpy as np 10 | 11 | from torch.utils.data.distributed import DistributedSampler 12 | from einops import rearrange 13 | from .dist_utils import all_gather 14 | 15 | SMALL_NUM = np.log(1e-45) 16 | 17 | def compute_clf(pl_module, ret): 18 | logits = ret["logits"] 19 | label_ids = ret["label_ids"] 20 | clf_loss = F.cross_entropy(logits, label_ids.view(-1)) 21 | 22 | new_ret = { 23 | f"clf_loss": clf_loss 24 | } 25 | 26 | phase = "train" if pl_module.training else "val" 27 | loss = getattr(pl_module, f"{phase}_clf_loss")(new_ret["clf_loss"]) 28 | acc = getattr(pl_module, f"{phase}_clf_accuracy")( 29 | ret["logits"], ret["label_ids"] 30 | ) 31 | f1 = getattr(pl_module, f"{phase}_clf_f1")( 32 | ret["logits"], ret["label_ids"] 33 | ) 34 | pl_module.log(f"clf/{phase}/loss", loss) 35 | pl_module.log(f"clf/{phase}/accuracy", acc) 36 | pl_module.log(f"clf/{phase}/f1", f1) 37 | return new_ret 38 | 39 | 40 | def init_weights(module): 41 | if isinstance(module, (nn.Linear, nn.Embedding)): 42 | module.weight.data.normal_(mean=0.0, std=0.02) 43 | elif isinstance(module, nn.LayerNorm): 44 | module.bias.data.zero_() 45 | module.weight.data.fill_(1.0) 46 | 47 | if isinstance(module, nn.Linear) and module.bias is not None: 48 | module.bias.data.zero_() -------------------------------------------------------------------------------- /src/modules/t5_model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from https://github.com/huggingface/transformers 3 | ''' 4 | 5 | from transformers import T5Config, T5ForConditionalGeneration 6 | from transformers.models.t5.modeling_t5 import __HEAD_MASK_WARNING_MSG, T5EncoderModel, T5PreTrainedModel, T5Block, T5LayerNorm 7 | import copy 8 | import math 9 | import os 10 | import warnings 11 | from typing import Optional, Tuple, Union 12 | import torch 13 | from torch import nn 14 | from torch.nn import CrossEntropyLoss 15 | from transformers.modeling_outputs import ( 16 | BaseModelOutput, 17 | Seq2SeqLMOutput, 18 | BaseModelOutputWithPastAndCrossAttentions 19 | ) 20 | from transformers.utils import ( 21 | logging, 22 | ) 23 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 24 | from torch.utils.checkpoint import checkpoint 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | class SimpleCrossAttention(nn.Module): 30 | def __init__(self, img_hsz, txt_hsz): 31 | super().__init__() 32 | self.proj_q = nn.Linear(txt_hsz, txt_hsz // 2) 33 | self.proj_k = nn.Linear(img_hsz, txt_hsz // 2) 34 | self.proj_v = copy.deepcopy(self.proj_k) 35 | self.proj_o = nn.Linear(txt_hsz // 2, txt_hsz) 36 | self.d_h = txt_hsz // 2 37 | 38 | def forward(self, img, txt): 39 | q, k, v = self.proj_q(txt), self.proj_k(img), self.proj_v(img) 40 | score = torch.softmax(torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.d_h), dim=-1) 41 | o = torch.matmul(score, v) 42 | return self.proj_o(o) 43 | 44 | 45 | class T5Stack(T5PreTrainedModel): 46 | def __init__(self, config, embed_tokens=None, img_hsz=512): 47 | super().__init__(config) 48 | 49 | self.embed_tokens = embed_tokens 50 | self.is_decoder = config.is_decoder 51 | 52 | 
self.block = nn.ModuleList( 53 | [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] 54 | ) 55 | self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) 56 | self.dropout = nn.Dropout(config.dropout_rate) 57 | 58 | if not self.is_decoder: 59 | self.ca_block = nn.ModuleList( 60 | [SimpleCrossAttention(img_hsz, config.hidden_size) for _ in range(config.num_layers)] 61 | ) 62 | 63 | # Initialize weights and apply final processing 64 | self.post_init() 65 | # Model parallel 66 | self.model_parallel = False 67 | self.device_map = None 68 | self.gradient_checkpointing = False 69 | self.image_ids = None 70 | 71 | def parallelize(self, device_map=None): 72 | # Check validity of device_map 73 | self.device_map = ( 74 | get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map 75 | ) 76 | assert_device_map(self.device_map, len(self.block)) 77 | self.model_parallel = True 78 | self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) 79 | self.last_device = "cuda:" + str(max(self.device_map.keys())) 80 | # Load onto devices 81 | for k, v in self.device_map.items(): 82 | for layer in v: 83 | cuda_device = "cuda:" + str(k) 84 | self.block[layer] = self.block[layer].to(cuda_device) 85 | 86 | # Set embed_tokens to first layer 87 | self.embed_tokens = self.embed_tokens.to(self.first_device) 88 | # Set final layer norm to last device 89 | self.final_layer_norm = self.final_layer_norm.to(self.last_device) 90 | 91 | def deparallelize(self): 92 | self.model_parallel = False 93 | self.device_map = None 94 | self.first_device = "cpu" 95 | self.last_device = "cpu" 96 | for i in range(len(self.block)): 97 | self.block[i] = self.block[i].to("cpu") 98 | self.embed_tokens = self.embed_tokens.to("cpu") 99 | self.final_layer_norm = self.final_layer_norm.to("cpu") 100 | torch.cuda.empty_cache() 101 | 102 | def get_input_embeddings(self): 103 | return self.embed_tokens 104 | 105 | def set_input_embeddings(self, new_embeddings): 106 | self.embed_tokens = new_embeddings 107 | 108 | def forward( 109 | self, 110 | input_ids=None, 111 | attention_mask=None, 112 | encoder_hidden_states=None, 113 | encoder_attention_mask=None, 114 | inputs_embeds=None, 115 | head_mask=None, 116 | cross_attn_head_mask=None, 117 | past_key_values=None, 118 | use_cache=None, 119 | output_attentions=None, 120 | output_hidden_states=None, 121 | return_dict=None, 122 | image_ids=None, 123 | ): 124 | if image_ids is None: 125 | image_ids = self.image_ids 126 | # Model parallel 127 | if self.model_parallel: 128 | torch.cuda.set_device(self.first_device) 129 | self.embed_tokens = self.embed_tokens.to(self.first_device) 130 | use_cache = use_cache if use_cache is not None else self.config.use_cache 131 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 132 | output_hidden_states = ( 133 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 134 | ) 135 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 136 | 137 | if input_ids is not None and inputs_embeds is not None: 138 | err_msg_prefix = "decoder_" if self.is_decoder else "" 139 | raise ValueError( 140 | f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" 141 | ) 142 | elif input_ids is not None: 143 | input_shape = input_ids.size() 144 | input_ids = 
input_ids.view(-1, input_shape[-1]) 145 | elif inputs_embeds is not None: 146 | input_shape = inputs_embeds.size()[:-1] 147 | else: 148 | err_msg_prefix = "decoder_" if self.is_decoder else "" 149 | raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") 150 | 151 | if inputs_embeds is None: 152 | assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" 153 | inputs_embeds = self.embed_tokens(input_ids) 154 | 155 | batch_size, seq_length = input_shape 156 | 157 | # required mask seq length can be calculated via length of past 158 | mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length 159 | 160 | if use_cache is True: 161 | assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder" 162 | 163 | if attention_mask is None: 164 | attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) 165 | if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: 166 | encoder_seq_length = encoder_hidden_states.shape[1] 167 | encoder_attention_mask = torch.ones( 168 | batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long 169 | ) 170 | 171 | # initialize past_key_values with `None` if past does not exist 172 | if past_key_values is None: 173 | past_key_values = [None] * len(self.block) 174 | 175 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 176 | # ourselves in which case we just need to make it broadcastable to all heads. 177 | extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) 178 | 179 | # If a 2D or 3D attention mask is provided for the cross-attention 180 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 181 | if self.is_decoder and encoder_hidden_states is not None: 182 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 183 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 184 | if encoder_attention_mask is None: 185 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) 186 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 187 | else: 188 | encoder_extended_attention_mask = None 189 | 190 | # Prepare head mask if needed 191 | head_mask = self.get_head_mask(head_mask, self.config.num_layers) 192 | cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) 193 | present_key_value_states = () if use_cache else None 194 | all_hidden_states = () if output_hidden_states else None 195 | all_attentions = () if output_attentions else None 196 | all_cross_attentions = () if (output_attentions and self.is_decoder) else None 197 | position_bias = None 198 | encoder_decoder_position_bias = None 199 | 200 | hidden_states = self.dropout(inputs_embeds) 201 | 202 | for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): 203 | layer_head_mask = head_mask[i] 204 | cross_attn_layer_head_mask = cross_attn_head_mask[i] 205 | # Model parallel 206 | if self.model_parallel: 207 | torch.cuda.set_device(hidden_states.device) 208 | # Ensure that attention_mask is always on the same device as hidden_states 209 | if attention_mask is not None: 210 | attention_mask = attention_mask.to(hidden_states.device) 211 | if position_bias is not None: 212 | position_bias = 
position_bias.to(hidden_states.device) 213 | if encoder_hidden_states is not None: 214 | encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) 215 | if encoder_extended_attention_mask is not None: 216 | encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) 217 | if encoder_decoder_position_bias is not None: 218 | encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) 219 | if layer_head_mask is not None: 220 | layer_head_mask = layer_head_mask.to(hidden_states.device) 221 | if cross_attn_layer_head_mask is not None: 222 | cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) 223 | if output_hidden_states: 224 | all_hidden_states = all_hidden_states + (hidden_states,) 225 | 226 | if self.gradient_checkpointing and self.training: 227 | if use_cache: 228 | logger.warning( 229 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 230 | ) 231 | use_cache = False 232 | 233 | def create_custom_forward(module): 234 | def custom_forward(*inputs): 235 | return tuple(module(*inputs, use_cache, output_attentions)) 236 | 237 | return custom_forward 238 | 239 | layer_outputs = checkpoint( 240 | create_custom_forward(layer_module), 241 | hidden_states, 242 | extended_attention_mask, 243 | position_bias, 244 | encoder_hidden_states, 245 | encoder_extended_attention_mask, 246 | encoder_decoder_position_bias, 247 | layer_head_mask, 248 | cross_attn_layer_head_mask, 249 | None, # past_key_value is always None with gradient checkpointing 250 | ) 251 | else: 252 | layer_outputs = layer_module( 253 | hidden_states, 254 | attention_mask=extended_attention_mask, 255 | position_bias=position_bias, 256 | encoder_hidden_states=encoder_hidden_states, 257 | encoder_attention_mask=encoder_extended_attention_mask, 258 | encoder_decoder_position_bias=encoder_decoder_position_bias, 259 | layer_head_mask=layer_head_mask, 260 | cross_attn_layer_head_mask=cross_attn_layer_head_mask, 261 | past_key_value=past_key_value, 262 | use_cache=use_cache, 263 | output_attentions=output_attentions, 264 | ) 265 | 266 | # layer_outputs is a tuple with: 267 | # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) 268 | if use_cache is False: 269 | layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] 270 | 271 | hidden_states, present_key_value_state = layer_outputs[:2] 272 | 273 | ###################### 274 | # if not self.is_decoder: 275 | # hidden_states = hidden_states + self.ca_block[i](image_ids, hidden_states) 276 | ##################### 277 | 278 | # We share the position biases between the layers - the first layer store them 279 | # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), 280 | # (cross-attention position bias), (cross-attention weights) 281 | position_bias = layer_outputs[2] 282 | if self.is_decoder and encoder_hidden_states is not None: 283 | encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] 284 | # append next layer key value states 285 | if use_cache: 286 | present_key_value_states = present_key_value_states + (present_key_value_state,) 287 | 288 | if output_attentions: 289 | all_attentions = all_attentions + (layer_outputs[3],) 290 | if self.is_decoder: 291 | all_cross_attentions = all_cross_attentions + (layer_outputs[5],) 292 | 293 | # Model Parallel: If it's the last layer 
for that device, put things on the next device 294 | if self.model_parallel: 295 | for k, v in self.device_map.items(): 296 | if i == v[-1] and "cuda:" + str(k) != self.last_device: 297 | hidden_states = hidden_states.to("cuda:" + str(k + 1)) 298 | 299 | hidden_states = self.final_layer_norm(hidden_states) 300 | hidden_states = self.dropout(hidden_states) 301 | 302 | # Add last layer 303 | if output_hidden_states: 304 | all_hidden_states = all_hidden_states + (hidden_states,) 305 | 306 | if not return_dict: 307 | return tuple( 308 | v 309 | for v in [ 310 | hidden_states, 311 | present_key_value_states, 312 | all_hidden_states, 313 | all_attentions, 314 | all_cross_attentions, 315 | ] 316 | if v is not None 317 | ) 318 | return BaseModelOutputWithPastAndCrossAttentions( 319 | last_hidden_state=hidden_states, 320 | past_key_values=present_key_value_states, 321 | hidden_states=all_hidden_states, 322 | attentions=all_attentions, 323 | cross_attentions=all_cross_attentions, 324 | ) 325 | 326 | def update_image_ids(self, img): 327 | self.image_ids = img 328 | 329 | 330 | class T5ForMultimodalGeneration(T5ForConditionalGeneration): 331 | _keys_to_ignore_on_load_missing = [ 332 | r"encoder.embed_tokens.weight", 333 | r"decoder.embed_tokens.weight", 334 | r"lm_head.weight", 335 | ] 336 | _keys_to_ignore_on_load_unexpected = [ 337 | r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", 338 | ] 339 | 340 | def __init__(self, config: T5Config, img_hsz=512): 341 | super().__init__(config) 342 | self.model_dim = config.d_model 343 | 344 | self.shared = nn.Embedding(config.vocab_size, config.d_model) 345 | 346 | # ############ 347 | # self.ca = SimpleCrossAttention(img_hsz, config.hidden_size) 348 | # ############ 349 | 350 | encoder_config = copy.deepcopy(config) 351 | encoder_config.is_decoder = False 352 | encoder_config.use_cache = False 353 | encoder_config.is_encoder_decoder = False 354 | self.encoder = T5Stack(encoder_config, self.shared, img_hsz=img_hsz) 355 | 356 | decoder_config = copy.deepcopy(config) 357 | decoder_config.is_decoder = True 358 | decoder_config.is_encoder_decoder = False 359 | decoder_config.num_layers = config.num_decoder_layers 360 | self.decoder = T5Stack(decoder_config, self.shared) 361 | 362 | self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) 363 | 364 | # Initialize weights and apply final processing 365 | self.post_init() 366 | 367 | # Model parallel 368 | self.model_parallel = False 369 | self.device_map = None 370 | 371 | self.image_ids = None 372 | 373 | def update_image_ids(self, img): 374 | self.image_ids = img 375 | 376 | def forward( 377 | self, 378 | input_ids: Optional[torch.LongTensor] = None, 379 | image_ids=None, 380 | attention_mask: Optional[torch.FloatTensor] = None, 381 | decoder_input_ids: Optional[torch.LongTensor] = None, 382 | decoder_attention_mask: Optional[torch.BoolTensor] = None, 383 | head_mask: Optional[torch.FloatTensor] = None, 384 | decoder_head_mask: Optional[torch.FloatTensor] = None, 385 | cross_attn_head_mask: Optional[torch.Tensor] = None, 386 | encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, 387 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 388 | inputs_embeds: Optional[torch.FloatTensor] = None, 389 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None, 390 | labels: Optional[torch.LongTensor] = None, 391 | use_cache: Optional[bool] = None, 392 | output_attentions: Optional[bool] = None, 393 | output_hidden_states: Optional[bool] = None, 394 | 
return_dict: Optional[bool] = None, 395 | ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: 396 | use_cache = use_cache if use_cache is not None else self.config.use_cache 397 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 398 | 399 | # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask 400 | if head_mask is not None and decoder_head_mask is None: 401 | if self.config.num_layers == self.config.num_decoder_layers: 402 | warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) 403 | decoder_head_mask = head_mask 404 | 405 | ############### 406 | if image_ids is None: 407 | image_ids = self.image_ids 408 | ############## 409 | 410 | # Encode if needed (training, first prediction pass) 411 | if encoder_outputs is None: 412 | # Convert encoder inputs in embeddings if needed 413 | encoder_outputs = self.encoder( 414 | input_ids=input_ids, 415 | attention_mask=attention_mask, 416 | inputs_embeds=inputs_embeds, 417 | head_mask=head_mask, 418 | output_attentions=output_attentions, 419 | output_hidden_states=output_hidden_states, 420 | return_dict=return_dict, 421 | image_ids=image_ids, 422 | ) 423 | 424 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 425 | encoder_outputs = BaseModelOutput( 426 | last_hidden_state=encoder_outputs[0], 427 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 428 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 429 | ) 430 | 431 | 432 | hidden_states = encoder_outputs[0] 433 | 434 | # ########### 435 | # if image_ids is None: 436 | # image_ids = self.image_ids 437 | # hidden_states = hidden_states + self.ca(image_ids, hidden_states) 438 | # ########### 439 | 440 | if self.model_parallel: 441 | torch.cuda.set_device(self.decoder.first_device) 442 | 443 | if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: 444 | # get decoder inputs from shifting lm labels to the right 445 | decoder_input_ids = self._shift_right(labels) 446 | 447 | # Set device for model parallelism 448 | if self.model_parallel: 449 | torch.cuda.set_device(self.decoder.first_device) 450 | hidden_states = hidden_states.to(self.decoder.first_device) 451 | if decoder_input_ids is not None: 452 | decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) 453 | if attention_mask is not None: 454 | attention_mask = attention_mask.to(self.decoder.first_device) 455 | if decoder_attention_mask is not None: 456 | decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) 457 | 458 | # Decode 459 | decoder_outputs = self.decoder( 460 | input_ids=input_ids, 461 | attention_mask=attention_mask, 462 | inputs_embeds=decoder_inputs_embeds, 463 | past_key_values=past_key_values, 464 | encoder_hidden_states=hidden_states, 465 | encoder_attention_mask=attention_mask, 466 | head_mask=decoder_head_mask, 467 | cross_attn_head_mask=cross_attn_head_mask, 468 | use_cache=use_cache, 469 | output_attentions=output_attentions, 470 | output_hidden_states=output_hidden_states, 471 | return_dict=return_dict, 472 | ) 473 | 474 | sequence_output = decoder_outputs[0] 475 | return sequence_output 476 | -------------------------------------------------------------------------------- /src/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import ( 2 | pixelbert_transform, 3 | pixelbert_transform_randaug, 4 | vit_transform, 5 | vit_transform_randaug, 6 | 
imagenet_transform, 7 | imagenet_transform_randaug, 8 | clip_transform, 9 | clip_transform_randaug, 10 | ) 11 | 12 | _transforms = { 13 | "pixelbert": pixelbert_transform, 14 | "pixelbert_randaug": pixelbert_transform_randaug, 15 | "vit": vit_transform, 16 | "vit_randaug": vit_transform_randaug, 17 | "imagenet": imagenet_transform, 18 | "imagenet_randaug": imagenet_transform_randaug, 19 | "clip": clip_transform, 20 | "clip_randaug": clip_transform_randaug, 21 | } 22 | 23 | def keys_to_transforms(keys: list, size=224): 24 | return [_transforms[key](size=size) for key in keys] 25 | -------------------------------------------------------------------------------- /src/transforms/randaug.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from rpmcruz/autoaugment 2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 3 | import random 4 | 5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | 10 | 11 | def ShearX(img, v): # [-0.3, 0.3] 12 | assert -0.3 <= v <= 0.3 13 | if random.random() > 0.5: 14 | v = -v 15 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 16 | 17 | 18 | def ShearY(img, v): # [-0.3, 0.3] 19 | assert -0.3 <= v <= 0.3 20 | if random.random() > 0.5: 21 | v = -v 22 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 23 | 24 | 25 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 26 | assert -0.45 <= v <= 0.45 27 | if random.random() > 0.5: 28 | v = -v 29 | v = v * img.size[0] 30 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 31 | 32 | 33 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 34 | assert 0 <= v 35 | if random.random() > 0.5: 36 | v = -v 37 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 38 | 39 | 40 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 41 | assert -0.45 <= v <= 0.45 42 | if random.random() > 0.5: 43 | v = -v 44 | v = v * img.size[1] 45 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 46 | 47 | 48 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 49 | assert 0 <= v 50 | if random.random() > 0.5: 51 | v = -v 52 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 53 | 54 | 55 | def Rotate(img, v): # [-30, 30] 56 | assert -30 <= v <= 30 57 | if random.random() > 0.5: 58 | v = -v 59 | return img.rotate(v) 60 | 61 | 62 | def AutoContrast(img, _): 63 | return PIL.ImageOps.autocontrast(img) 64 | 65 | 66 | def Invert(img, _): 67 | return PIL.ImageOps.invert(img) 68 | 69 | 70 | def Equalize(img, _): 71 | return PIL.ImageOps.equalize(img) 72 | 73 | 74 | def Flip(img, _): # not from the paper 75 | return PIL.ImageOps.mirror(img) 76 | 77 | 78 | def Solarize(img, v): # [0, 256] 79 | assert 0 <= v <= 256 80 | return PIL.ImageOps.solarize(img, v) 81 | 82 | 83 | def SolarizeAdd(img, addition=0, threshold=128): 84 | img_np = np.array(img).astype(np.int) 85 | img_np = img_np + addition 86 | img_np = np.clip(img_np, 0, 255) 87 | img_np = img_np.astype(np.uint8) 88 | img = Image.fromarray(img_np) 89 | return PIL.ImageOps.solarize(img, threshold) 90 | 91 | 92 | def Posterize(img, v): # [4, 8] 93 | v = int(v) 94 | v = max(1, v) 95 | return PIL.ImageOps.posterize(img, v) 96 | 97 | 98 | def Contrast(img, v): # [0.1,1.9] 99 | assert 0.1 <= v <= 1.9 100 | return PIL.ImageEnhance.Contrast(img).enhance(v) 101 | 102 | 
103 | def Color(img, v): # [0.1,1.9] 104 | assert 0.1 <= v <= 1.9 105 | return PIL.ImageEnhance.Color(img).enhance(v) 106 | 107 | 108 | def Brightness(img, v): # [0.1,1.9] 109 | assert 0.1 <= v <= 1.9 110 | return PIL.ImageEnhance.Brightness(img).enhance(v) 111 | 112 | 113 | def Sharpness(img, v): # [0.1,1.9] 114 | assert 0.1 <= v <= 1.9 115 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 116 | 117 | 118 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 119 | assert 0.0 <= v <= 0.2 120 | if v <= 0.0: 121 | return img 122 | 123 | v = v * img.size[0] 124 | return CutoutAbs(img, v) 125 | 126 | 127 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 128 | # assert 0 <= v <= 20 129 | if v < 0: 130 | return img 131 | w, h = img.size 132 | x0 = np.random.uniform(w) 133 | y0 = np.random.uniform(h) 134 | 135 | x0 = int(max(0, x0 - v / 2.0)) 136 | y0 = int(max(0, y0 - v / 2.0)) 137 | x1 = min(w, x0 + v) 138 | y1 = min(h, y0 + v) 139 | 140 | xy = (x0, y0, x1, y1) 141 | color = (125, 123, 114) 142 | # color = (0, 0, 0) 143 | img = img.copy() 144 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 145 | return img 146 | 147 | 148 | def SamplePairing(imgs): # [0, 0.4] 149 | def f(img1, v): 150 | i = np.random.choice(len(imgs)) 151 | img2 = PIL.Image.fromarray(imgs[i]) 152 | return PIL.Image.blend(img1, img2, v) 153 | 154 | return f 155 | 156 | 157 | def Identity(img, v): 158 | return img 159 | 160 | 161 | def augment_list(): # 16 oeprations and their ranges 162 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 163 | # l = [ 164 | # (Identity, 0., 1.0), 165 | # (ShearX, 0., 0.3), # 0 166 | # (ShearY, 0., 0.3), # 1 167 | # (TranslateX, 0., 0.33), # 2 168 | # (TranslateY, 0., 0.33), # 3 169 | # (Rotate, 0, 30), # 4 170 | # (AutoContrast, 0, 1), # 5 171 | # (Invert, 0, 1), # 6 172 | # (Equalize, 0, 1), # 7 173 | # (Solarize, 0, 110), # 8 174 | # (Posterize, 4, 8), # 9 175 | # # (Contrast, 0.1, 1.9), # 10 176 | # (Color, 0.1, 1.9), # 11 177 | # (Brightness, 0.1, 1.9), # 12 178 | # (Sharpness, 0.1, 1.9), # 13 179 | # # (Cutout, 0, 0.2), # 14 180 | # # (SamplePairing(imgs), 0, 0.4), # 15 181 | # ] 182 | 183 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 184 | l = [ 185 | (AutoContrast, 0, 1), 186 | (Equalize, 0, 1), 187 | # (Invert, 0, 1), 188 | (Rotate, 0, 30), 189 | (Posterize, 0, 4), 190 | (Solarize, 0, 256), 191 | (SolarizeAdd, 0, 110), 192 | (Color, 0.1, 1.9), 193 | (Contrast, 0.1, 1.9), 194 | (Brightness, 0.1, 1.9), 195 | (Sharpness, 0.1, 1.9), 196 | (ShearX, 0.0, 0.3), 197 | (ShearY, 0.0, 0.3), 198 | # (CutoutAbs, 0, 40), 199 | (TranslateXabs, 0.0, 100), 200 | (TranslateYabs, 0.0, 100), 201 | ] 202 | 203 | return l 204 | 205 | 206 | class Lighting(object): 207 | """Lighting noise(AlexNet - style PCA - based noise)""" 208 | 209 | def __init__(self, alphastd, eigval, eigvec): 210 | self.alphastd = alphastd 211 | self.eigval = torch.Tensor(eigval) 212 | self.eigvec = torch.Tensor(eigvec) 213 | 214 | def __call__(self, img): 215 | if self.alphastd == 0: 216 | return img 217 | 218 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 219 | rgb = ( 220 | self.eigvec.type_as(img) 221 | .clone() 222 | .mul(alpha.view(1, 3).expand(3, 3)) 223 | .mul(self.eigval.view(1, 3).expand(3, 3)) 224 | .sum(1) 225 | .squeeze() 226 | ) 227 | 228 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 229 | 230 | 231 | class CutoutDefault(object): 232 | """ 233 | Reference : 
https://github.com/quark0/darts/blob/master/cnn/utils.py 234 | """ 235 | 236 | def __init__(self, length): 237 | self.length = length 238 | 239 | def __call__(self, img): 240 | h, w = img.size(1), img.size(2) 241 | mask = np.ones((h, w), np.float32) 242 | y = np.random.randint(h) 243 | x = np.random.randint(w) 244 | 245 | y1 = np.clip(y - self.length // 2, 0, h) 246 | y2 = np.clip(y + self.length // 2, 0, h) 247 | x1 = np.clip(x - self.length // 2, 0, w) 248 | x2 = np.clip(x + self.length // 2, 0, w) 249 | 250 | mask[y1:y2, x1:x2] = 0.0 251 | mask = torch.from_numpy(mask) 252 | mask = mask.expand_as(img) 253 | img *= mask 254 | return img 255 | 256 | 257 | class RandAugment: 258 | def __init__(self, n, m): 259 | self.n = n 260 | self.m = m # [0, 30] 261 | self.augment_list = augment_list() 262 | 263 | def __call__(self, img): 264 | ops = random.choices(self.augment_list, k=self.n) 265 | for op, minval, maxval in ops: 266 | val = (float(self.m) / 30) * float(maxval - minval) + minval 267 | img = op(img, val) 268 | 269 | return img 270 | -------------------------------------------------------------------------------- /src/transforms/transform.py: -------------------------------------------------------------------------------- 1 | from .utils import ( 2 | inception_normalize, 3 | imagenet_normalize, 4 | MinMaxResize, 5 | ) 6 | from PIL import Image 7 | from torchvision import transforms 8 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 9 | from .randaug import RandAugment 10 | 11 | 12 | def pixelbert_transform(size=800): 13 | longer = int((1333 / 800) * size) 14 | return transforms.Compose( 15 | [ 16 | MinMaxResize(shorter=size, longer=longer), 17 | transforms.ToTensor(), 18 | inception_normalize, 19 | ] 20 | ) 21 | 22 | def pixelbert_transform_randaug(size=800): 23 | longer = int((1333 / 800) * size) 24 | trs = transforms.Compose( 25 | [ 26 | MinMaxResize(shorter=size, longer=longer), 27 | transforms.ToTensor(), 28 | inception_normalize, 29 | ] 30 | ) 31 | trs.transforms.insert(0, RandAugment(2, 9)) 32 | return trs 33 | 34 | def imagenet_transform(size=800): 35 | return transforms.Compose( 36 | [ 37 | Resize(size, interpolation=Image.BICUBIC), 38 | CenterCrop(size), 39 | transforms.ToTensor(), 40 | imagenet_normalize, 41 | ] 42 | ) 43 | 44 | def imagenet_transform_randaug(size=800): 45 | trs = transforms.Compose( 46 | [ 47 | Resize(size, interpolation=Image.BICUBIC), 48 | CenterCrop(size), 49 | transforms.ToTensor(), 50 | imagenet_normalize, 51 | ] 52 | ) 53 | trs.transforms.insert(0, RandAugment(2, 9)) 54 | return trs 55 | 56 | def vit_transform(size=800): 57 | return transforms.Compose( 58 | [ 59 | Resize(size, interpolation=Image.BICUBIC), 60 | CenterCrop(size), 61 | lambda image: image.convert("RGB"), 62 | transforms.ToTensor(), 63 | inception_normalize, 64 | ] 65 | ) 66 | 67 | def vit_transform_randaug(size=800): 68 | trs = transforms.Compose( 69 | [ 70 | Resize(size, interpolation=Image.BICUBIC), 71 | CenterCrop(size), 72 | lambda image: image.convert("RGB"), 73 | transforms.ToTensor(), 74 | inception_normalize, 75 | ] 76 | ) 77 | trs.transforms.insert(0, lambda image: image.convert('RGBA')) 78 | trs.transforms.insert(0, RandAugment(2, 9)) 79 | trs.transforms.insert(0, lambda image: image.convert('RGB')) 80 | return trs 81 | 82 | def clip_transform(size): 83 | return Compose([ 84 | Resize(size, interpolation=Image.BICUBIC), 85 | CenterCrop(size), 86 | lambda image: image.convert("RGB"), 87 | ToTensor(), 88 | Normalize((0.48145466, 0.4578275, 
0.40821073), (0.26862954, 0.26130258, 0.27577711)), 89 | ]) 90 | 91 | def clip_transform_randaug(size): 92 | trs = Compose([ 93 | Resize(size, interpolation=Image.BICUBIC), 94 | CenterCrop(size), 95 | lambda image: image.convert("RGB"), 96 | ToTensor(), 97 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 98 | ]) 99 | trs.transforms.insert(0, lambda image: image.convert('RGBA')) 100 | trs.transforms.insert(0, RandAugment(2, 9)) 101 | trs.transforms.insert(0, lambda image: image.convert('RGB')) 102 | return trs 103 | 104 | -------------------------------------------------------------------------------- /src/transforms/utils.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image 3 | 4 | 5 | class MinMaxResize: 6 | def __init__(self, shorter=800, longer=1333): 7 | self.min = shorter 8 | self.max = longer 9 | 10 | def __call__(self, x): 11 | w, h = x.size 12 | scale = self.min / min(w, h) 13 | if h < w: 14 | newh, neww = self.min, scale * w 15 | else: 16 | newh, neww = scale * h, self.min 17 | 18 | if max(newh, neww) > self.max: 19 | scale = self.max / max(newh, neww) 20 | newh = newh * scale 21 | neww = neww * scale 22 | 23 | newh, neww = int(newh + 0.5), int(neww + 0.5) 24 | newh, neww = newh // 32 * 32, neww // 32 * 32 25 | 26 | return x.resize((neww, newh), resample=Image.BICUBIC) 27 | 28 | 29 | class UnNormalize(object): 30 | def __init__(self, mean, std): 31 | self.mean = mean 32 | self.std = std 33 | 34 | def __call__(self, tensor): 35 | """ 36 | Args: 37 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 38 | Returns: 39 | Tensor: Normalized image. 40 | """ 41 | for t, m, s in zip(tensor, self.mean, self.std): 42 | t.mul_(s).add_(m) 43 | # The normalize code -> t.sub_(m).div_(s) 44 | return tensor 45 | 46 | 47 | # This is simple maximum entropy normalization performed in Inception paper 48 | inception_normalize = transforms.Compose( 49 | [transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 50 | ) 51 | 52 | # ViT uses simple non-biased inception normalization 53 | # https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py#L132 54 | inception_unnormalize = transforms.Compose( 55 | [UnNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 56 | ) 57 | 58 | # ImageNet normalize 59 | imagenet_normalize = transforms.Compose( 60 | [transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])] 61 | ) 62 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKBUNLP/ExplainHM-WWW2024/5c8742286a416b687fe8ce8acba32fc026e91c3e/src/utils/__init__.py --------------------------------------------------------------------------------
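Usage note (illustrative, not part of the repository): the sketch below shows one way a single batch for MMTransformerSS.infer (src/modules/mm_module.py) could be assembled by hand. The checkpoint names ("openai/clip-vit-base-patch32", "t5-base"), the max_length of 128, the binary label convention, and the random tensor standing in for a CLIP-preprocessed image are assumptions for illustration only, not values taken from this repo.

# Hypothetical sketch: construct the model and run infer() on one hand-built batch.
# Checkpoint names, sizes, and the label convention are assumptions, not repo values.
import torch
from transformers import T5Tokenizer
from src.modules.mm_module import MMTransformerSS

config = {
    "vit": "openai/clip-vit-base-patch32",        # assumed CLIP vision backbone (hidden size 768)
    "tokenizer": "t5-base",                        # assumed T5 checkpoint (d_model 768)
    "input_image_embed_size": 768,
    "input_text_embed_size": 768,
    "loss_names": {"clf": 1},
    "load_path": "",
}
model = MMTransformerSS(config).eval()

tokenizer = T5Tokenizer.from_pretrained(config["tokenizer"])
enc = tokenizer(["an example meme caption"], padding="max_length",
                max_length=128, truncation=True, return_tensors="pt")

batch = {
    "text_ids": enc.input_ids,                      # (1, 128) T5 encoder token ids
    "text_masks": enc.attention_mask,               # (1, 128) attention mask
    "labels": torch.tensor([0]),                    # assumed 0/1 binary label
    "image_features": torch.randn(1, 3, 224, 224),  # stands in for a clip_transform output
}

with torch.no_grad():
    out = model.infer(batch)
print(out["logits"].shape)  # torch.Size([1, 2]), the two-way classification head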
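Similarly, a small sketch of how the transform registry in src/transforms/__init__.py resolves keys into the pipelines defined in src/transforms/transform.py; the image path below is a hypothetical placeholder. (Caveat: if the SolarizeAdd op in randaug.py is sampled, it relies on np.int, which requires NumPy older than 1.24.)

# Hypothetical sketch: resolve transforms by key and apply them to a PIL image.
from PIL import Image
from src.transforms import keys_to_transforms

# "clip_randaug" prepends RandAugment(2, 9) to the CLIP resize/crop/normalize pipeline;
# "clip" is the plain pipeline without augmentation.
train_tf, eval_tf = keys_to_transforms(["clip_randaug", "clip"], size=224)

img = Image.open("example_meme.png")  # placeholder path, not from the repo
train_tensor = train_tf(img)          # (3, 224, 224) tensor, CLIP-normalized, augmented
eval_tensor = eval_tf(img)            # (3, 224, 224) tensor, CLIP-normalized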