├── .gitignore ├── LICENSE ├── README.md ├── docs ├── examples │ ├── 0000001.txt │ └── annt.json └── figures │ ├── arch.png │ ├── diagram.png │ ├── image_generation.png │ ├── interleaved.png │ ├── segment.png │ └── text_generation.png ├── evaluate.py ├── inference.py ├── mm_interleaved ├── __init__.py ├── configs │ └── release │ │ ├── deepspeed_zero1.json │ │ ├── mm_eval.yaml │ │ ├── mm_inference.yaml │ │ └── mm_pretrain.yaml ├── custom_datasets │ ├── __init__.py │ ├── ade20k.py │ ├── ade20k_preparation.py │ ├── caption_datasets.py │ ├── clip_itp.py │ ├── collator.py │ ├── collator_sft.py │ ├── flintstones.py │ ├── grounding_datasets.py │ ├── image2paragraph.py │ ├── laion_wds.py │ ├── lncoco.py │ ├── loader.py │ ├── mix_dataset.py │ ├── mmc4_wds.py │ ├── mscoco.py │ ├── mscoco_karpathy.py │ ├── pororo.py │ ├── sft_datasets.py │ ├── utils.py │ ├── visdial_dense.py │ ├── vist.py │ ├── vqa_datasets.py │ └── wds_utils.py ├── engine │ └── lmm_trainer.py ├── models │ ├── __init__.py │ ├── decoders │ │ ├── decoder_image.py │ │ ├── decoder_text.py │ │ ├── modeling_llama_mmfs.py │ │ ├── perceiver.py │ │ ├── sd.py │ │ └── sd_mmfs.py │ ├── encoders │ │ ├── visual_tokenizer.py │ │ └── vit_adapter │ │ │ ├── __init__.py │ │ │ ├── adapter_modules.py │ │ │ ├── clip_vit_hf.py │ │ │ ├── ops │ │ │ ├── functions │ │ │ │ ├── __init__.py │ │ │ │ └── ms_deform_attn_func.py │ │ │ └── modules │ │ │ │ ├── __init__.py │ │ │ │ └── ms_deform_attn.py │ │ │ ├── vit_adapter_hf.py │ │ │ └── xattn.py │ ├── mm_interleaved.py │ └── utils │ │ ├── causal_lm_cascade.py │ │ ├── monkey_patch │ │ ├── __init__.py │ │ ├── beam_search_monkey_patch.py │ │ ├── blip2_qknorm_monkey_patch.py │ │ ├── llama_flash_attn_train_monkey_patch.py │ │ ├── sd_pipeline_monkey_patch.py │ │ └── sd_unet_forward_monkey_patch.py │ │ ├── ops │ │ ├── forward_backward_error.py │ │ ├── functions │ │ │ ├── __init__.py │ │ │ └── ms_deform_attn_func.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ └── mmfs.py │ │ ├── setup.py │ │ ├── src │ │ │ ├── cpu │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ ├── cuda │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ ├── ms_deform_attn.h │ │ │ └── vision.cpp │ │ └── tests │ │ │ ├── __init__.py │ │ │ ├── compare_with_data.py │ │ │ ├── create_data.py │ │ │ ├── forward_backward_error.py │ │ │ ├── skip_forward_error.py │ │ │ └── speed_test.py │ │ └── pos_embed.py ├── scripts │ └── download_hf_models.py └── utils │ ├── __init__.py │ ├── caption_collect.py │ ├── clip_sim_score.py │ ├── coco_cap_score.py │ ├── fid_score.py │ ├── grounding_score.py │ ├── inception.py │ ├── misc.py │ ├── parse_args.py │ ├── segm_eval.py │ ├── visdial_metrics.py │ ├── vizwiz_metrics_src │ ├── __init__.py │ ├── vqa.py │ └── vqaEval.py │ ├── vqa_collect.py │ ├── vqa_score.py │ └── vqav2_metrics_src │ ├── __init__.py │ ├── vqa.py │ └── vqaEval.py ├── requirements.txt ├── slurm_run.sh └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | *.pth 3 | batchscript-* 4 | phoenix-slurm-* 5 | .ipynb_checkpoints/ 6 | .idea/ 7 | .vscode/ 8 | 9 | OUTPUT/ 10 | *.tmp 11 | tmp* 12 | ckpts/ 13 | assets 14 | wandb/ 15 | configs/deprecated/ 16 | sh*/ 17 | temp*/ 18 | *ipynb 19 | 20 | *.pt 21 | *_script.py 22 | 23 | -------------------------------------------------------------------------------- /docs/examples/annt.json: -------------------------------------------------------------------------------- 1 | [ 2 
| { 3 | "sentences": [ 4 | "a kitchen is shown with a variety of items on the counters." 5 | ], 6 | "images": [ 7 | "./assets/dataset/coco/val2014/COCO_val2014_000000384213.jpg" 8 | ], 9 | "sentence_ixs": [ 10 | 0 11 | ], 12 | "image_first": [ 13 | false 14 | ], 15 | "generate_mode": "generate_images", 16 | "num_iter": 1 17 | }, 18 | 19 | { 20 | "sentences": [ 21 | "A photo of" 22 | ], 23 | "images": [ 24 | "./assets/dataset/coco/val2014/COCO_val2014_000000384213.jpg" 25 | ], 26 | "sentence_ixs": [ 27 | 0 28 | ], 29 | "image_first": [ 30 | true 31 | ], 32 | "generate_mode": "generate_texts", 33 | "num_iter": 1 34 | } 35 | ] 36 | -------------------------------------------------------------------------------- /docs/figures/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/MM-Interleaved/ac539d00ead0c438328ac1788849d560703a6b15/docs/figures/arch.png -------------------------------------------------------------------------------- /docs/figures/diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/MM-Interleaved/ac539d00ead0c438328ac1788849d560703a6b15/docs/figures/diagram.png -------------------------------------------------------------------------------- /docs/figures/image_generation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/MM-Interleaved/ac539d00ead0c438328ac1788849d560703a6b15/docs/figures/image_generation.png -------------------------------------------------------------------------------- /docs/figures/interleaved.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/MM-Interleaved/ac539d00ead0c438328ac1788849d560703a6b15/docs/figures/interleaved.png -------------------------------------------------------------------------------- /docs/figures/segment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/MM-Interleaved/ac539d00ead0c438328ac1788849d560703a6b15/docs/figures/segment.png -------------------------------------------------------------------------------- /docs/figures/text_generation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/MM-Interleaved/ac539d00ead0c438328ac1788849d560703a6b15/docs/figures/text_generation.png -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from mm_interleaved.models.utils.monkey_patch import ( 5 | replace_llama_attn_with_flash_attn, 6 | replace_blip2_attn_with_qknorm_attn, 7 | replace_beam_search, 8 | replace_stable_diffusion_pipeline_call, 9 | replace_stable_diffusion_unet_forward, 10 | ) 11 | 12 | replace_beam_search() 13 | replace_blip2_attn_with_qknorm_attn() 14 | replace_stable_diffusion_unet_forward() 15 | replace_stable_diffusion_pipeline_call() 16 | IS_TRAIN = False 17 | if IS_TRAIN: 18 | replace_llama_attn_with_flash_attn() 19 | 20 | 21 | from mm_interleaved.models import MMInterleaved 22 | from mm_interleaved.custom_datasets.utils import build_dataset 23 | from mm_interleaved.engine.lmm_trainer import LMMTrainer 24 | from mm_interleaved.utils import ArgumentParser, TrainingArguments, 
init_distributed_mode, load_model_weights 25 | 26 | 27 | def evaluate(trainer: LMMTrainer, config): 28 | print("Eval Start") 29 | if isinstance(trainer.eval_dataset, dict): 30 | eval_datasets = trainer.eval_dataset 31 | else: 32 | eval_datasets = {config.data.val.name: trainer.eval_dataset} 33 | 34 | metrics = {} 35 | for eval_dataset_name, eval_dataset in eval_datasets.items(): 36 | dataset_metrics = trainer.evaluate( 37 | eval_dataset=eval_dataset, 38 | metric_key_prefix=f"eval_{eval_dataset_name}", 39 | ) 40 | print(eval_dataset_name) 41 | print(dataset_metrics) 42 | print("-" * 100) 43 | metrics.update(dataset_metrics) 44 | print("=" * 100) 45 | 46 | if trainer.args.should_save: 47 | metrics_to_save = { 48 | **metrics, 49 | **{"step": trainer.state.global_step}, 50 | } 51 | if trainer.state.epoch is not None: 52 | metrics_to_save["epoch"] = round(trainer.state.epoch, 2) 53 | metrics_save_path = os.path.join(trainer.args.output_dir, "eval_metrics.jsonl") 54 | json_string = json.dumps(metrics_to_save, indent=2, sort_keys=True) + "\n" 55 | with open(metrics_save_path, "a+", encoding="utf-8") as f: 56 | f.write(json_string) 57 | 58 | print("All Finished") 59 | 60 | 61 | def main(): 62 | parser = ArgumentParser(TrainingArguments) 63 | init_distributed_mode() 64 | args = parser.parse_args_with_config_file_into_dataclasses() 65 | train_args, config = args 66 | print(train_args) 67 | print(config) 68 | 69 | print("Data Loading Start") 70 | eval_dataset = build_dataset(config.data.val) 71 | print(eval_dataset) 72 | 73 | print("Model Init Start") 74 | model = MMInterleaved(**config.model) 75 | print(model) 76 | 77 | print("Trainer Init Start") 78 | if isinstance(eval_dataset, dict): 79 | tokenizer = list(eval_dataset.values())[0].tokenizer 80 | else: 81 | tokenizer = eval_dataset.tokenizer 82 | trainer = LMMTrainer( 83 | model=model, 84 | tokenizer=tokenizer, 85 | config=config, 86 | args=train_args, 87 | eval_dataset=eval_dataset, 88 | ) 89 | 90 | if getattr(config, "load_from", None): 91 | load_model_weights(trainer.model, config.load_from) 92 | 93 | evaluate(trainer, config) 94 | 95 | 96 | if __name__ == "__main__": 97 | main() 98 | -------------------------------------------------------------------------------- /mm_interleaved/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/MM-Interleaved/ac539d00ead0c438328ac1788849d560703a6b15/mm_interleaved/__init__.py -------------------------------------------------------------------------------- /mm_interleaved/configs/release/deepspeed_zero1.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 1, 4 | "allgather_partitions": true, 5 | "reduce_scatter": false, 6 | "allgather_bucket_size": 1e9, 7 | "reduce_bucket_size": 1e9, 8 | "overlap_comm": true, 9 | "contiguous_gradients": true, 10 | "ignore_unused_parameters": true 11 | }, 12 | "fp16": { 13 | "enabled": "auto", 14 | "auto_cast": true, 15 | "loss_scale": 0.0, 16 | "initial_scale_power": 16, 17 | "loss_scale_window": 250, 18 | "min_loss_scale": 1 19 | }, 20 | "train_batch_size": "auto", 21 | "train_micro_batch_size_per_gpu": "auto", 22 | "wall_clock_breakdown": false, 23 | "gradient_clipping": "auto", 24 | "prescale_gradients": true 25 | } -------------------------------------------------------------------------------- /mm_interleaved/configs/release/mm_eval.yaml: 
-------------------------------------------------------------------------------- 1 | # Training Arguments 2 | 3 | load_from: OUTPUT/mm_interleaved_pretrain 4 | 5 | fp16: True 6 | per_device_eval_batch_size: 2 7 | dataloader_num_workers: &num_workers 8 8 | data_seed: &data_seed 0 9 | seed: 32 10 | 11 | ## logging 12 | 13 | report_to: ['tensorboard'] 14 | 15 | 16 | # MODEL 17 | 18 | model: 19 | llm_model_path: &tokenizer_path ./assets/lmsys/vicuna-13b-v1.3 20 | num_img_token: &img_len 64 21 | 22 | visual_tokenizer_config: 23 | encoder_model_path: ./assets/openai/clip-vit-large-patch14 24 | perceiver_config: 25 | num_queries: 64 26 | hidden_size: 768 27 | encoder_hidden_size: 1024 28 | cross_attention_frequency: 2 29 | num_hidden_layers: 12 30 | num_attention_heads: 12 31 | qk_normalization: True 32 | image_decoder_config: 33 | pretrained_model_name_or_path: ./assets/stabilityai/stable-diffusion-2-1-base 34 | sd_base_seed: 30_000 35 | sd_use_random_seed: True 36 | perceiver_config: 37 | num_queries: 77 38 | hidden_size: 1024 39 | encoder_hidden_size: 5120 40 | cross_attention_frequency: 1 41 | num_hidden_layers: 1 42 | num_attention_heads: 16 43 | hidden_dropout_prob: 0. 44 | attention_probs_dropout_prob: 0. 45 | 46 | # DATA 47 | 48 | data: 49 | val: 50 | - name: coco_karpathy 51 | data_root: assets/datasets/coco 52 | annt_root: assets/datasets/coco 53 | phase: test 54 | year: 2014 55 | 56 | collator: ImageTextPairCollator 57 | num_img_token: *img_len 58 | tokenizer_path: *tokenizer_path 59 | collate_mode: generate_texts 60 | 61 | transform: 62 | aug_type: 'numpy' 63 | resolution: &image_size 224 64 | 65 | - name: flickr30k 66 | data_root: assets/datasets/flickr30k/flickr30k-images 67 | annt_file: assets/datasets/flickr30k/test1k.token.coco_format 68 | 69 | collator: ImageTextPairCollator 70 | num_img_token: *img_len 71 | tokenizer_path: *tokenizer_path 72 | collate_mode: generate_texts 73 | 74 | transform: 75 | aug_type: 'numpy' 76 | resolution: *image_size 77 | 78 | - name: nocaps 79 | data_root: assets/datasets/nocaps/images 80 | annt_file: assets/datasets/nocaps/nocaps_val_4500_captions.json 81 | 82 | collator: ImageTextPairCollator 83 | num_img_token: *img_len 84 | tokenizer_path: *tokenizer_path 85 | collate_mode: generate_texts 86 | 87 | transform: 88 | aug_type: 'numpy' 89 | resolution: *image_size 90 | 91 | - name: image2paragraph 92 | data_root: ./assets/datasets/image2paragraph/images/ 93 | annt_root: ./assets/datasets/image2paragraph 94 | phase: test 95 | 96 | collator: ImageTextPairCollator 97 | num_img_token: *img_len 98 | tokenizer_path: *tokenizer_path 99 | collate_mode: generate_texts 100 | generation_kwargs: 101 | max_length: 90 102 | min_length: 90 103 | repetition_penalty: 1.2 104 | instr_prompts: 105 | image: [] 106 | text: [ 107 | "The image depicts", 108 | "{image}Please describe the image in detail.", 109 | "", 110 | ] 111 | 112 | transform: 113 | aug_type: 'numpy' 114 | resolution: *image_size 115 | 116 | - name: visdial 117 | data_root: assets/datasets/visdial 118 | annt_root: assets/datasets/visdial 119 | phase: val 120 | 121 | collator: VisDialCollator 122 | num_img_token: *img_len 123 | tokenizer_path: *tokenizer_path 124 | collate_mode: generate_scores 125 | 126 | transform: 127 | aug_type: 'numpy' 128 | resolution: *image_size 129 | 130 | - name: coco 131 | data_root: assets/datasets/coco 132 | annt_root: assets/datasets/coco 133 | phase: val 134 | year: 2014 135 | total_length: 30_000 136 | rerank_by_clip: True 137 | 138 | collator: ImageTextPairCollator 139 | 
num_img_token: *img_len 140 | tokenizer_path: *tokenizer_path 141 | collate_mode: generate_images 142 | generation_kwargs: 143 | guidance_scale: 3.5 144 | num_inference_steps: 250 145 | num_validation_images: 8 146 | 147 | transform: 148 | aug_type: 'numpy' 149 | resolution: *image_size 150 | 151 | - name: lncoco 152 | data_root: assets/datasets/coco 153 | annt_root: assets/datasets/lncoco 154 | phase: val 155 | total_length: 30_000 156 | 157 | collator: ImageTextPairCollator 158 | num_img_token: *img_len 159 | tokenizer_path: *tokenizer_path 160 | collate_mode: generate_images 161 | generation_kwargs: 162 | guidance_scale: 3.5 163 | num_inference_steps: 250 164 | num_validation_images: 1 165 | 166 | transform: 167 | aug_type: 'numpy' 168 | resolution: *image_size 169 | 170 | - name: vqav2 171 | data_root: assets/datasets/coco 172 | annt_root: assets/datasets/VQAv2 173 | phase: val 174 | 175 | collator: VQACollator 176 | num_img_token: *img_len 177 | tokenizer_path: *tokenizer_path 178 | collate_mode: generate_vqa 179 | 180 | transform: 181 | aug_type: 'numpy' 182 | resolution: *image_size 183 | 184 | - name: okvqa 185 | data_root: assets/datasets/coco 186 | annt_root: assets/datasets/OK-VQA 187 | phase: val 188 | 189 | collator: VQACollator 190 | num_img_token: *img_len 191 | tokenizer_path: *tokenizer_path 192 | collate_mode: generate_vqa 193 | 194 | transform: 195 | aug_type: 'numpy' 196 | resolution: *image_size 197 | 198 | - name: vizwiz_vqa 199 | data_root: assets/datasets/VizWiz 200 | annt_root: assets/datasets/VizWiz 201 | phase: val 202 | 203 | collator: VQACollator 204 | num_img_token: *img_len 205 | tokenizer_path: *tokenizer_path 206 | collate_mode: generate_vqa 207 | instr_prompts: [ 208 | "The answer is:", 209 | "Based on the image, please answer the question. {image}{question} When the provided information is insufficient, respond with 'Unanswerable'. Please provide an accurate answer within one word.", 210 | "", 211 | ] 212 | 213 | transform: 214 | aug_type: 'numpy' 215 | resolution: *image_size 216 | 217 | - name: textvqa 218 | data_root: assets/datasets/textvqa/train_images 219 | annt_root: assets/datasets/textvqa/TextVQA 220 | phase: val 221 | 222 | collator: VQACollator 223 | num_img_token: *img_len 224 | tokenizer_path: *tokenizer_path 225 | collate_mode: generate_vqa 226 | 227 | transform: 228 | aug_type: 'numpy' 229 | resolution: *image_size 230 | -------------------------------------------------------------------------------- /mm_interleaved/configs/release/mm_inference.yaml: -------------------------------------------------------------------------------- 1 | load_from: ./OUTPUT/mm_interleaved_pretrain 2 | annt_path: ./docs/examples/annt.json 3 | output_dir: ./OUTPUT/mm_inference 4 | 5 | # MODEL 6 | 7 | model: 8 | llm_model_path: &tokenizer_path ./assets/lmsys/vicuna-13b-v1.3 9 | num_img_token: &img_len 64 10 | 11 | visual_tokenizer_config: 12 | encoder_model_path: ./assets/openai/clip-vit-large-patch14 13 | perceiver_config: 14 | num_queries: 64 15 | hidden_size: 768 16 | encoder_hidden_size: 1024 17 | cross_attention_frequency: 2 18 | num_hidden_layers: 12 19 | num_attention_heads: 12 20 | qk_normalization: True 21 | image_decoder_config: 22 | pretrained_model_name_or_path: ./assets/stabilityai/stable-diffusion-2-1-base 23 | sd_base_seed: 42 24 | perceiver_config: 25 | num_queries: 77 26 | hidden_size: 1024 27 | encoder_hidden_size: 5120 28 | cross_attention_frequency: 1 29 | num_hidden_layers: 1 30 | num_attention_heads: 16 31 | hidden_dropout_prob: 0. 
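      # (added note, not part of the original config) These perceiver settings appear
      # to bridge the LLM and the diffusion decoder: encoder_hidden_size 5120 matches
      # the Vicuna-13B hidden size, while the 77 queries of width 1024 match the
      # 77-token, 1024-dim text-conditioning space expected by Stable Diffusion 2.1-base.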
32 | attention_probs_dropout_prob: 0. 33 | 34 | # INFERENCE 35 | 36 | inference: 37 | tokenizer_path: *tokenizer_path 38 | num_img_token: *img_len 39 | generate_mode: generate_texts 40 | force_gen_image_next: False 41 | force_replace_gen_text: False 42 | auto_end: False 43 | num_iter: 2 44 | 45 | transform: 46 | aug_type: numpy 47 | resolution: 224 48 | 49 | generation_kwargs: 50 | max_length: 90 51 | min_length: 8 52 | num_beams: 1 53 | use_nucleus_sampling: True 54 | repetition_penalty: 1.3 55 | guidance_scale: 7.5 56 | num_inference_steps: 30 57 | num_validation_images: 1 58 | 59 | -------------------------------------------------------------------------------- /mm_interleaved/configs/release/mm_pretrain.yaml: -------------------------------------------------------------------------------- 1 | # Training Arguments 2 | 3 | fp16: True 4 | max_steps: 15_000 5 | per_device_train_batch_size: &per_device_train_batch_size 4 6 | per_device_eval_batch_size: 2 7 | dataloader_num_workers: &num_workers 8 8 | data_seed: &data_seed 0 9 | seed: 32 10 | 11 | ## optimizer & scheduler 12 | 13 | optim: adamw_torch 14 | learning_rate: 1.0e-4 15 | weight_decay: 0.05 16 | adam_beta1: 0.9 17 | adam_beta2: 0.995 18 | adam_epsilon: 1.0e-6 19 | lr_for_random_params_list: [1.0e-4, 1.0e-5, 1.0e-4, 1.0e-5] 20 | wd_for_random_params_list: [0.0, 0.0, null, null] 21 | random_params_list: [llama_cross_attn.gate, sampling_offsets, llama_cross_attn, image_decoder.decoder.unet] 22 | 23 | lr_scheduler_type: "cosine" 24 | warmup_steps: 1_000 25 | 26 | ## evaluation & saving 27 | 28 | evaluation_strategy: "steps" 29 | eval_steps: 1_000 30 | save_strategy: "steps" 31 | save_steps: 1_000 32 | save_total_limit: 5 33 | fp16_full_eval: false 34 | 35 | generate_mode: generate_both 36 | 37 | ## logging 38 | 39 | report_to: ['tensorboard'] 40 | logging_steps: 10 41 | disable_tqdm: False 42 | log_level: info 43 | 44 | ## misc 45 | 46 | tf32: True 47 | ddp_find_unused_parameters: False 48 | 49 | ## deepspeed 50 | 51 | deepspeed: './mm_interleaved/configs/release/deepspeed_zero1.json' 52 | 53 | 54 | # MODEL 55 | 56 | model: 57 | llm_model_path: &tokenizer_path ./assets/lmsys/vicuna-13b-v1.3 58 | num_img_token: &img_len 64 59 | cross_attention_frequency: 4 60 | 61 | dataset_to_ignore_noimage_cond_loss: [laion_en, laion_coco] 62 | 63 | visual_tokenizer_config: 64 | encoder_model_path: ./assets/openai/clip-vit-large-patch14 65 | perceiver_config: 66 | num_queries: 64 67 | hidden_size: 768 68 | encoder_hidden_size: 1024 69 | cross_attention_frequency: 2 70 | num_hidden_layers: 12 71 | num_attention_heads: 12 72 | qk_normalization: True 73 | image_decoder_config: 74 | pretrained_model_name_or_path: './assets/stabilityai/stable-diffusion-2-1-base' 75 | sd_base_seed: 0 76 | sd_use_random_seed: False 77 | perceiver_config: 78 | num_queries: 77 79 | hidden_size: 1024 80 | encoder_hidden_size: 5120 81 | cross_attention_frequency: 1 82 | num_hidden_layers: 1 83 | num_attention_heads: 16 84 | hidden_dropout_prob: 0. 85 | attention_probs_dropout_prob: 0. 86 | 87 | # DATA 88 | 89 | data: 90 | train: 91 | name: random_mix 92 | probs: [1., 1., 1., 2.] 
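    # (added note, not part of the original config) probs are relative sampling
    # weights for the entries in dataset_names below, so with [1, 1, 1, 2] the mmc4
    # interleaved documents are drawn roughly twice as often as each paired
    # image-text source; sampling_type "longest" restarts exhausted sources and stops
    # only once every source has completed at least one pass (see random_samples in
    # mm_interleaved/custom_datasets/mix_dataset.py).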
93 | sampling_type: longest 94 | seed: *data_seed 95 | dataset_names: [blip2, laion_en, laion_coco, mmc4] 96 | 97 | datasets: 98 | 99 | - name: laion_wds 100 | data_root: "[THE IMAGE DIRECTORY OF BLIP-2 ANNOTATED DATA]" 101 | annt_root: "[THE ANNOTATION DIRECTORY OF BLIP-2 ANNOTATED DATA]" 102 | tokenizer_path: *tokenizer_path 103 | 104 | per_device_batch_size: 2 105 | input_shards: "[SHARDED FILE NAMES]" # e.g '{0000000..0020000}.txt' 106 | num_samples: "[ESTIMATED TOTAL NUMBER OF TRAINING SAMPLES]" # e.g. 10_000 107 | seed: *data_seed 108 | num_workers: *num_workers 109 | 110 | num_img_token: *img_len 111 | max_num_images_per_seq: 30 112 | 113 | transform: &train_transform 114 | aug_type: 'dual_numpy' 115 | resolution: &image_size 224 116 | resolution2: &image_size_dec 512 117 | 118 | - name: laion_wds 119 | data_root: "[THE IMAGE DIRECTORY OF LAION-EN]" 120 | annt_root: "[THE ANNOTATION DIRECTORY OF LAION-EN]" 121 | tokenizer_path: *tokenizer_path 122 | 123 | per_device_batch_size: 2 124 | input_shards: "[SHARDED FILE NAMES]" # e.g '{0000000..0020000}.txt' 125 | num_samples: "[ESTIMATED TOTAL NUMBER OF TRAINING SAMPLES]" # e.g. 10_000 126 | seed: *data_seed 127 | num_workers: *num_workers 128 | 129 | num_img_token: *img_len 130 | max_num_images_per_seq: 30 131 | 132 | transform: *train_transform 133 | 134 | - name: laion_wds 135 | data_root: "[THE IMAGE DIRECTORY OF LAION-COCO]" 136 | annt_root: "[THE ANNOTATION DIRECTORY OF LAION-COCO]" 137 | tokenizer_path: *tokenizer_path 138 | 139 | per_device_batch_size: 2 140 | input_shards: "[SHARDED FILE NAMES]" # e.g '{0000000..0020000}.txt' 141 | num_samples: "[ESTIMATED TOTAL NUMBER OF TRAINING SAMPLES]" # e.g. 10_000 142 | seed: *data_seed 143 | num_workers: *num_workers 144 | 145 | num_img_token: *img_len 146 | max_num_images_per_seq: 30 147 | 148 | transform: *train_transform 149 | 150 | - name: mmc4_wds 151 | data_root: "[THE IMAGE DIRECTORY OF MMC4]" # e.g. './assets/datasets/mmc4/ai2-jackh-mmc4-gated-public-41423/images/' 152 | annt_root: "[THE ANNOTATION DIRECTORY OF MMC4]" # e.g. './assets/datasets/mmc4/ai2-jackh-mmc4-gated-public-41423/data/' 153 | tokenizer_path: *tokenizer_path 154 | 155 | per_device_batch_size: 4 156 | input_shards: "[SHARDED FILE NAMES]" # 'docs_shard_{0..23099}_v2.jsonl' 157 | num_samples: "[ESTIMATED TOTAL NUMBER OF TRAINING SAMPLES]" # e.g. 
10_000 158 | seed: *data_seed 159 | num_workers: *num_workers 160 | 161 | num_img_token: *img_len 162 | max_num_images_per_seq: 15 163 | 164 | transform: *train_transform 165 | 166 | val: 167 | - name: coco_karpathy 168 | data_root: assets/datasets/coco 169 | annt_root: assets/datasets/coco 170 | phase: test 171 | year: 2014 172 | 173 | collator: ImageTextPairCollator 174 | num_img_token: *img_len 175 | tokenizer_path: *tokenizer_path 176 | collate_mode: generate_both 177 | 178 | transform: 179 | aug_type: 'numpy' 180 | resolution: *image_size 181 | 182 | - name: okvqa 183 | data_root: assets/datasets/coco 184 | annt_root: assets/datasets/OK-VQA 185 | phase: val 186 | 187 | collator: VQACollator 188 | num_img_token: *img_len 189 | tokenizer_path: *tokenizer_path 190 | collate_mode: 'generate_vqa' 191 | 192 | transform: 193 | aug_type: 'numpy' 194 | resolution: *image_size 195 | 196 | - name: textvqa 197 | data_root: assets/datasets/textvqa/train_images 198 | annt_root: assets/datasets/textvqa/TextVQA 199 | phase: val 200 | 201 | collator: VQACollator 202 | num_img_token: *img_len 203 | tokenizer_path: *tokenizer_path 204 | collate_mode: 'generate_vqa' 205 | 206 | transform: 207 | aug_type: 'numpy' 208 | resolution: *image_size 209 | -------------------------------------------------------------------------------- /mm_interleaved/custom_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import build_dataset 2 | -------------------------------------------------------------------------------- /mm_interleaved/custom_datasets/ade20k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | 5 | from .loader import BaseDataset 6 | from functools import cached_property 7 | 8 | 9 | class ADE20kDataset(BaseDataset): 10 | def __init__( 11 | self, 12 | data_root, 13 | annt_root, 14 | transform, 15 | total_length=None, 16 | phase="training", 17 | collate_mode="generate_segm", 18 | add_eos="", 19 | num_img_token=32, 20 | add_soi_token=True, 21 | text_first=False, 22 | context_type="current", 23 | ): 24 | super().__init__() 25 | 26 | self.transform = transform 27 | self.data_root = data_root 28 | self.annt_root = annt_root 29 | 30 | assert phase in ["training", "validation"] 31 | self.phase = phase 32 | 33 | assert collate_mode in ["train", "generate_segm"] 34 | self.collate_mode = collate_mode 35 | self.add_eos = add_eos 36 | self.text_first = text_first 37 | 38 | assert context_type in [ 39 | "multi_modal", 40 | "image_only", 41 | "text_only", 42 | ] 43 | self.context_type = context_type 44 | 45 | self.num_img_token = num_img_token 46 | self.add_soi_token = add_soi_token 47 | 48 | self.image_subseq = "<|image|>" * self.num_img_token 49 | if self.add_soi_token: 50 | self.image_subseq = "<|beginofimage|>" + self.image_subseq 51 | 52 | annt_file = os.path.join(annt_root, f"{phase}.json") 53 | self.annt_file = annt_file 54 | self.load_database() 55 | 56 | if total_length is not None: 57 | self.annts = self.annts[:total_length] 58 | 59 | print(f"length of the dataset is {len(self.annts)}") 60 | 61 | def load_database(self): 62 | with open(self.annt_file, "r") as rf: 63 | self.annts = json.load(rf) 64 | 65 | def __repr__(self) -> str: 66 | return ( 67 | f"ADE20k Dataset phase={self.phase}\n" 68 | f"annotation_root={self.annt_root} data_root={self.data_root}\n" 69 | f"transform={self.transform}" 70 | ) 71 | 72 | def __len__(self): 73 | return len(self.annts) 74 
| 75 | def _get_image(self, image_id, return_image_path=False): 76 | try: 77 | image_path = os.path.join( 78 | self.data_root, "images", self.phase, f"{image_id}.jpg" 79 | ) 80 | image = self.loader(image_path).convert("RGB") 81 | image = self.transform(image) 82 | except Exception as e: 83 | print(e) 84 | print(image_path) 85 | image = None 86 | 87 | if return_image_path: 88 | return image, image_path 89 | return image 90 | 91 | def _get_annt(self, image_id, return_image_path=False): 92 | try: 93 | image_path = os.path.join( 94 | self.data_root, "annotations_with_color", self.phase, f"{image_id}.png" 95 | ) 96 | image = self.loader(image_path) 97 | image = self.transform(image) 98 | except Exception as e: 99 | print(e) 100 | print(image_path) 101 | image = None 102 | 103 | if return_image_path: 104 | return image, image_path 105 | return image 106 | 107 | def __getitem__(self, index): 108 | item = self.annts[index] 109 | meta = [index] 110 | 111 | images_tensor = [] 112 | text = "" 113 | if self.collate_mode == "train": 114 | assert self.phase == "training" 115 | 116 | annt, _ = self._get_annt(item["image_id"]) 117 | image, image_dec = self._get_image(item["image_id"]) 118 | 119 | if np.random.random() < 0.5: 120 | image = np.ascontiguousarray(image[:, ::-1]) 121 | annt = np.ascontiguousarray(annt[:, ::-1]) 122 | 123 | if self.text_first: 124 | text += f"{item['caption']}.{self.image_subseq}{self.image_subseq}" 125 | else: 126 | text += f"{self.image_subseq}{item['caption']}.{self.image_subseq}" 127 | 128 | images_tensor.append((annt, image_dec)) 129 | images_tensor.append((image, image_dec)) 130 | 131 | else: 132 | assert self.phase != "train" 133 | assert self.collate_mode == "generate_segm" 134 | 135 | annt = self._get_annt(item["image_id"]) 136 | 137 | if self.text_first: 138 | text += f"{item['caption']}.{self.image_subseq}" 139 | else: 140 | text += f"{self.image_subseq}{item['caption']}." 
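            # (added note, not in the original source) `annt`, appended just below, is
            # the color-encoded segmentation map produced by ade20k_preparation.py
            # (loaded from "annotations_with_color"), so segmentation enters the
            # interleaved sequence as an ordinary image; the RGB photo is then added
            # under "# prepare target" as the final image slot of the sample.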
141 | 142 | images_tensor.append(annt) 143 | 144 | # prepare target 145 | 146 | image = self._get_image(item["image_id"]) 147 | text += self.image_subseq 148 | images_tensor.append(image) 149 | 150 | meta.append(item["caption"]) 151 | text = text.strip() 152 | if self.add_eos: 153 | text += self.add_eos 154 | 155 | return dict(text=text, images_tensor=images_tensor, meta=meta) 156 | 157 | @property 158 | def task_prefix(self): 159 | return f"_{self.context_type}" 160 | 161 | def image_id_to_path(self, idx): 162 | image_id = self.annts[idx]["image_id"] 163 | 164 | image_path = os.path.join( 165 | self.data_root, "images", self.phase, f"{image_id}.jpg" 166 | ) 167 | return image_path 168 | 169 | def gt_id_to_path(self, idx): 170 | image_id = self.annts[idx]["image_id"] 171 | 172 | image_path = os.path.join( 173 | self.data_root, "annotations", self.phase, f"{image_id}.png" 174 | ) 175 | return image_path 176 | 177 | @cached_property 178 | def palette(self): 179 | return [ 180 | 0,0,0,120,120,120,180,120,120,6,230,230,80,50,50,4,200, 181 | 3,120,120,80,140,140,140,204,5,255,230,230,230,4,250,7,224, 182 | 5,255,235,255,7,150,5,61,120,120,70,8,255,51,255,6,82,143, 183 | 255,140,204,255,4,255,51,7,204,70,3,0,102,200,61,230,250,255, 184 | 6,51,11,102,255,255,7,71,255,9,224,9,7,230,220,220,220,255,9, 185 | 92,112,9,255,8,255,214,7,255,224,255,184,6,10,255,71,255,41, 186 | 10,7,255,255,224,255,8,102,8,255,255,61,6,255,194,7,255,122,8, 187 | 0,255,20,255,8,41,255,5,153,6,51,255,235,12,255,160,150,20,0, 188 | 163,255,140,140,140,250,10,15,20,255,0,31,255,0,255,31,0,255,224, 189 | 0,153,255,0,0,0,255,255,71,0,0,235,255,0,173,255,31,0,255,11,200, 190 | 200,255,82,0,0,255,245,0,61,255,0,255,112,0,255,133,255,0,0,255, 191 | 163,0,255,102,0,194,255,0,0,143,255,51,255,0,0,82,255,0,255,41,0, 192 | 255,173,10,0,255,173,255,0,0,255,153,255,92,0,255,0,255,255,0,245, 193 | 255,0,102,255,173,0,255,0,20,255,184,184,0,31,255,0,255,61,0,71,255, 194 | 255,0,204,0,255,194,0,255,82,0,10,255,0,112,255,51,0,255,0,194,255,0, 195 | 122,255,0,255,163,255,153,0,0,255,10,255,112,0,143,255,0,82,0,255,163, 196 | 255,0,255,235,0,8,184,170,133,0,255,0,255,92,184,0,255,255,0,31,0,184, 197 | 255,0,214,255,255,0,112,92,255,0,0,224,255,112,224,255,70,184,160,163, 198 | 0,255,153,0,255,71,255,0,255,0,163,255,204,0,255,0,143,0,255,235,133,255, 199 | 0,255,0,235,245,0,255,255,0,122,255,245,0,10,190,212,214,255,0,0,204,255, 200 | 20,0,255,255,255,0,0,153,255,0,41,255,0,255,204,41,0,255,41,255,0,173,0, 201 | 255,0,245,255,71,0,255,122,0,255,0,255,184,0,92,255,184,255,0,0,133,255, 202 | 255,214,0,25,194,194,102,255,0,92,0,255 203 | ] 204 | 205 | 206 | 207 | if __name__ == "__main__": 208 | from .utils import create_transform 209 | 210 | transform = create_transform( 211 | aug_type="flip", resolution=256, random_crop=False, random_flip=True 212 | ) 213 | 214 | dataset = ADE20kDataset( 215 | data_root="./asset/ade20k/ADEChallengeData2016/", 216 | annt_root="./asset/ade20k/ADEChallengeData2016/", 217 | transform=transform, 218 | phase="training", 219 | collate_mode="generate_images", 220 | num_img_token=32, 221 | add_soi_token=True, 222 | context_type="multi_modal", 223 | ) 224 | print(dataset) 225 | 226 | -------------------------------------------------------------------------------- /mm_interleaved/custom_datasets/ade20k_preparation.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Images Speak in Images: A Generalist Painter for 
In-Context Visual Learning (https://arxiv.org/abs/2212.02499) 3 | # Github source: https://github.com/baaivision/Painter 4 | # Copyright (c) 2022 Beijing Academy of Artificial Intelligence (BAAI) 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Xinlong Wang, Wen Wang 7 | # Based on MAE, BEiT, detectron2, Mask2Former, bts, mmcv, mmdetetection, mmpose, MIRNet, MPRNet, and Uformer codebases 8 | # --------------------------------------------------------' 9 | 10 | import os 11 | import glob 12 | import argparse 13 | import json 14 | import tqdm 15 | import sys 16 | sys.path.insert(0, "data") 17 | 18 | import numpy as np 19 | from PIL import Image 20 | 21 | 22 | def unique(ar, return_index=False, return_inverse=False, return_counts=False): 23 | "copied from https://github.com/CSAILVision/semantic-segmentation-pytorch/blob/master/mit_semseg/utils.py" 24 | ar = np.asanyarray(ar).flatten() 25 | 26 | optional_indices = return_index or return_inverse 27 | optional_returns = optional_indices or return_counts 28 | 29 | if ar.size == 0: 30 | if not optional_returns: 31 | ret = ar 32 | else: 33 | ret = (ar,) 34 | if return_index: 35 | ret += (np.empty(0, np.bool),) 36 | if return_inverse: 37 | ret += (np.empty(0, np.bool),) 38 | if return_counts: 39 | ret += (np.empty(0, np.intp),) 40 | return ret 41 | if optional_indices: 42 | perm = ar.argsort(kind='mergesort' if return_index else 'quicksort') 43 | aux = ar[perm] 44 | else: 45 | ar.sort() 46 | aux = ar 47 | flag = np.concatenate(([True], aux[1:] != aux[:-1])) 48 | 49 | if not optional_returns: 50 | ret = aux[flag] 51 | else: 52 | ret = (aux[flag],) 53 | if return_index: 54 | ret += (perm[flag],) 55 | if return_inverse: 56 | iflag = np.cumsum(flag) - 1 57 | inv_idx = np.empty(ar.shape, dtype=np.intp) 58 | inv_idx[perm] = iflag 59 | ret += (inv_idx,) 60 | if return_counts: 61 | idx = np.concatenate(np.nonzero(flag) + ([ar.size],)) 62 | ret += (np.diff(idx),) 63 | return ret 64 | 65 | 66 | def colorEncode(labelmap, colors, mode='RGB'): 67 | "Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch/blob/master/mit_semseg/utils.py" 68 | labelmap = labelmap.astype('int') 69 | labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3), 70 | dtype=np.uint8) 71 | 72 | for label in unique(labelmap): 73 | if label <= 0: 74 | continue 75 | # note the color_index = class_index - 1 76 | labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \ 77 | np.tile(np.array(colors[label-1], dtype=np.uint8), (labelmap.shape[0], labelmap.shape[1], 1)) 78 | 79 | if mode == 'BGR': 80 | return labelmap_rgb[:, :, ::-1] 81 | else: 82 | return labelmap_rgb 83 | 84 | 85 | def define_colors_per_location_mean_sep(): 86 | num_locations = 150 87 | num_sep_per_channel = int(num_locations ** (1 / 3)) + 1 # 19 88 | separation_per_channel = 256 // num_sep_per_channel 89 | 90 | color_list = [] 91 | for location in range(num_locations): 92 | num_seq_r = location // num_sep_per_channel ** 2 93 | num_seq_g = (location % num_sep_per_channel ** 2) // num_sep_per_channel 94 | num_seq_b = location % num_sep_per_channel 95 | assert (num_seq_r <= num_sep_per_channel) and (num_seq_g <= num_sep_per_channel) \ 96 | and (num_seq_b <= num_sep_per_channel) 97 | 98 | R = 255 - num_seq_r * separation_per_channel 99 | G = 255 - num_seq_g * separation_per_channel 100 | B = 255 - num_seq_b * separation_per_channel 101 | assert (R < 256) and (G < 256) and (B < 256) 102 | assert (R >= 0) and (G >= 0) and (B >= 0) 103 | assert (R, G, B) not in color_list 104 | 105 | 
color_list.append((R, G, B)) 106 | # print(location, (num_seq_r, num_seq_g, num_seq_b), (R, G, B)) 107 | 108 | return color_list 109 | 110 | 111 | PALETTE = define_colors_per_location_mean_sep() 112 | 113 | 114 | def get_args_parser(): 115 | parser = argparse.ArgumentParser('ADE20k semantic segmentation preparation', add_help=False) 116 | parser.add_argument('--split', type=str, help='dataset split', 117 | choices=['training', 'validation'], required=True) 118 | return parser.parse_args() 119 | 120 | 121 | if __name__ == '__main__': 122 | args = get_args_parser() 123 | 124 | image_dir = os.path.join("./asset/ade20k/ADEChallengeData2016/images", args.split) 125 | segm_dir = os.path.join("./asset/ade20k/ADEChallengeData2016/annotations", args.split) 126 | save_dir = os.path.join("./asset/ade20k/ADEChallengeData2016/annotations_with_color", args.split) 127 | if not os.path.exists(save_dir): 128 | os.makedirs(save_dir) 129 | 130 | color_list = define_colors_per_location_mean_sep() 131 | 132 | segm_path_list = glob.glob(os.path.join(segm_dir, '*.png')) 133 | for segm_path in tqdm.tqdm(segm_path_list): 134 | # check files 135 | file_name = os.path.basename(segm_path) 136 | # in ade20k, images are jpegs, while segms are pngs 137 | image_path = os.path.join(image_dir, file_name.replace('.png', '.jpg')) 138 | assert os.path.isfile(segm_path) 139 | assert os.path.isfile(image_path) 140 | 141 | # paint colors on segm 142 | segm = Image.open(segm_path) 143 | segm_color = colorEncode(labelmap=np.array(segm), colors=color_list).astype(np.uint8) 144 | segm_color = Image.fromarray(segm_color) 145 | segm_color.save(os.path.join(save_dir, file_name)) -------------------------------------------------------------------------------- /mm_interleaved/custom_datasets/caption_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | 5 | from .loader import BaseDataset 6 | 7 | 8 | class NoCapsDataset(BaseDataset): 9 | def __init__( 10 | self, 11 | data_root, 12 | annt_file, 13 | transform, 14 | image_only=False, 15 | total_length=None, 16 | collate_mode='generate_texts', 17 | add_eos=None, 18 | ) -> None: 19 | super().__init__() 20 | self.collate_mode = collate_mode 21 | self.transform = transform 22 | self.data_root = data_root 23 | self.image_only = image_only 24 | self.annts = self.load_annotations(annt_file) 25 | self.annt_file = annt_file 26 | if self.image_only: 27 | self.dedeup_image() 28 | if total_length is not None: 29 | self.annts = self.annts[:total_length] 30 | self.add_eos = add_eos 31 | print(f"length of the dataset is {len(self.annts)}") 32 | 33 | def load_annotations(self, annt_file): 34 | meta_info = json.load(open(annt_file, "r")) 35 | images = meta_info['images'] 36 | annotations = meta_info['annotations'] 37 | 38 | image_info = {} 39 | for image in images: 40 | image_info[image['id']] = image 41 | 42 | processed_annotations = [] 43 | for ann in annotations: 44 | image_id = ann['image_id'] 45 | file_name = image_info[image_id]['file_name'] 46 | caption = ann['caption'] 47 | 48 | processed_annotations.append({ 49 | 'image': file_name, 50 | 'caption': caption, 51 | 'image_id': image_id, 52 | }) 53 | 54 | return processed_annotations 55 | 56 | def dedeup_image(self): 57 | annts = {} 58 | for annt in self.annts: 59 | image_idx = annt["image_id"] 60 | if image_idx in annts: 61 | continue 62 | annts[image_idx] = annt 63 | self.annts = list(annts.values()) 64 | 65 | def __repr__(self) -> str: 66 | return "Nocaps 
Dataset" 67 | 68 | def __len__(self): 69 | return len(self.annts) 70 | 71 | def __getitem__(self, index): 72 | item = self.annts[index] 73 | caption = item["caption"] 74 | if isinstance(caption, list): # TODO, random choose one caption from the image 75 | caption = random.choice(caption) 76 | caption = caption.lower() 77 | if self.add_eos is not None: 78 | caption = caption + self.add_eos 79 | image_idx_int = item["image_id"] 80 | image_path = os.path.join(self.data_root, item["image"]) 81 | 82 | try: 83 | image = self.loader(image_path).convert("RGB") 84 | 85 | image = self.transform(image) 86 | except: 87 | print(image_path) 88 | index = random.randint(0, len(self) - 1) 89 | return self.__getitem__(index) 90 | 91 | return image, caption, image_idx_int 92 | 93 | 94 | class Flickr30KDataset(NoCapsDataset): 95 | def __repr__(self) -> str: 96 | return "Flickr30K Dataset" 97 | -------------------------------------------------------------------------------- /mm_interleaved/custom_datasets/clip_itp.py: -------------------------------------------------------------------------------- 1 | import os 2 | from transformers import CLIPProcessor 3 | 4 | from .loader import BaseDataset 5 | 6 | 7 | class CLIPImageTextPairDataset(BaseDataset): 8 | def __init__( 9 | self, 10 | image_root, 11 | caption_list, 12 | model_name="openai/clip-vit-large-patch14", 13 | ) -> None: 14 | super().__init__() 15 | 16 | self.model_name = model_name 17 | self.image_root = image_root 18 | self.caption_list = caption_list 19 | 20 | self.clip_processor = CLIPProcessor.from_pretrained(model_name) 21 | 22 | print(f"length of the dataset is {len(self.caption_list)}") 23 | 24 | def __repr__(self) -> str: 25 | return ( 26 | f"CLIPImageTextPair Dataset total_length={len(self)}\n" 27 | f"image_root={self.image_root}\nprocessor={self.clip_processor}" 28 | ) 29 | 30 | def __len__(self): 31 | return len(self.caption_list) 32 | 33 | def __getitem__(self, index): 34 | caption = self.caption_list[str(index)]["caption"] 35 | image_path = os.path.join(self.image_root, f"{index:05d}.png") 36 | 37 | image = self.loader(image_path).convert("RGB") 38 | data = self.clip_processor( 39 | images=image, 40 | text=caption, 41 | return_tensors="pt", 42 | padding="max_length", 43 | max_length=77, 44 | ) 45 | 46 | return data.pixel_values[0], data.input_ids[0], index 47 | 48 | 49 | class CLIPImagePairDataset(BaseDataset): 50 | def __init__( 51 | self, 52 | image_pair_list, 53 | model_name="openai/clip-vit-large-patch14", 54 | ) -> None: 55 | 56 | super().__init__() 57 | 58 | self.model_name = model_name 59 | self.image_pair_list = image_pair_list 60 | 61 | self.clip_processor = CLIPProcessor.from_pretrained(model_name) 62 | 63 | print(f"length of the dataset is {len(self.image_pair_list)}") 64 | 65 | def __repr__(self) -> str: 66 | return ( 67 | f"CLIPImagePairDataset total_length={len(self)}\n" 68 | f"processor={self.clip_processor}" 69 | ) 70 | 71 | def __len__(self): 72 | return len(self.image_pair_list) 73 | 74 | def __getitem__(self, index): 75 | image_path = self.image_pair_list[index]["image_path"] 76 | image = self.loader(image_path).convert("RGB") 77 | 78 | image = self.clip_processor( 79 | images=image, 80 | text=None, 81 | return_tensors="pt", 82 | ).pixel_values[0] 83 | 84 | image_path_gt = self.image_pair_list[index]["image_gt_path"] 85 | image_gt = self.loader(image_path_gt).convert("RGB") 86 | 87 | image_gt = self.clip_processor( 88 | images=image_gt, 89 | text=None, 90 | return_tensors="pt", 91 | ).pixel_values[0] 92 | 93 | return image, 
image_gt, index 94 | -------------------------------------------------------------------------------- /mm_interleaved/custom_datasets/image2paragraph.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | 5 | from .loader import BaseDataset 6 | 7 | 8 | class Image2ParagraphDataset(BaseDataset): 9 | def __init__( 10 | self, 11 | data_root, 12 | annt_root, 13 | transform, 14 | image_only=False, 15 | total_length=None, 16 | collate_mode="generate_texts", 17 | phase="train", 18 | add_eos=None, 19 | ) -> None: 20 | super().__init__() 21 | self.collate_mode = collate_mode 22 | self.transform = transform 23 | self.data_root = data_root 24 | self.annt_root = annt_root 25 | self.phase = phase 26 | self.image_only = image_only 27 | 28 | annt_file = os.path.join(annt_root, "annotations", f"paragraphs_coco.json") 29 | with open(annt_file, "r") as rf: 30 | data = json.load(rf) 31 | annts = {d["image_id"]: d for d in data["annotations"]} 32 | 33 | split_file = os.path.join(annt_root, "annotations", f"{phase}_split.json") 34 | with open(split_file, "r") as rf: 35 | split_idxs = set(json.load(rf)) 36 | annts = [v for k, v in annts.items() if k in split_idxs] 37 | 38 | self.annts = annts 39 | self.annt_file = annt_file 40 | if total_length is not None: 41 | self.annts = self.annts[:total_length] 42 | self.add_eos = add_eos 43 | print(f"length of the dataset is {len(self.annts)}") 44 | 45 | def __repr__(self) -> str: 46 | return ( 47 | f"Image2Paragraph Dataset phase={self.phase}\n" 48 | f"annotation_root={self.annt_root} data_root={self.data_root}\n" 49 | f"transform={self.transform}" 50 | ) 51 | 52 | def __len__(self): 53 | return len(self.annts) 54 | 55 | def __getitem__(self, index): 56 | item = self.annts[index] 57 | caption = item["caption"] 58 | # caption = caption.lower() 59 | if self.add_eos is not None: 60 | caption = caption + self.add_eos 61 | 62 | image_idx_int = item["image_id"] 63 | image_subpaths = item["url"].split("/")[-2:] 64 | image_path = os.path.join(self.data_root, *image_subpaths) 65 | 66 | try: 67 | image = self.loader(image_path).convert("RGB") 68 | 69 | image = self.transform(image) 70 | except: 71 | print(image_path) 72 | index = random.randint(0, len(self) - 1) 73 | return self.__getitem__(index) 74 | 75 | return image, caption, image_idx_int 76 | -------------------------------------------------------------------------------- /mm_interleaved/custom_datasets/laion_wds.py: -------------------------------------------------------------------------------- 1 | import io 2 | from PIL import Image 3 | import os.path as osp 4 | from typing import Tuple 5 | import json 6 | import numpy as np 7 | import functools 8 | 9 | import webdataset as wds 10 | from webdataset.utils import pytorch_worker_info 11 | 12 | from transformers import LlamaTokenizer 13 | 14 | 15 | from .loader import BaseLoader 16 | from .wds_utils import ( 17 | init_tokenizer, 18 | log_and_continue, 19 | ) 20 | from .mmc4_wds import build_interleaved_dataset 21 | 22 | """ 23 | Iterable Version of LAION, using webdataset for data processing 24 | """ 25 | 26 | Image.MAX_IMAGE_PIXELS = 1000000000 27 | N_CHANNELS = 3 28 | MIN_KB = 10 29 | 30 | 31 | def load_laion_database_nothrow( 32 | src, 33 | annt_root="", 34 | handler=log_and_continue, 35 | client=None, 36 | ): 37 | rank, world_size, worker, num_workers = pytorch_worker_info() 38 | for sample in src: 39 | assert isinstance(sample, dict), sample 40 | assert "url" in sample 41 | annt_fname = 
sample["url"] 42 | data_path = osp.join(annt_root, annt_fname) 43 | 44 | try: 45 | print( 46 | f"[Rank {rank:02d} Worker {worker:02d}] start load from {data_path}", 47 | force=True, 48 | ) 49 | 50 | with io.BytesIO(client.get(data_path)) as rf: 51 | lines = rf.readlines() 52 | 53 | for i, line in enumerate(lines): 54 | yield (line, f"{annt_fname}-{i}") 55 | 56 | print( 57 | f"[Rank {rank:02d} Worker {worker:02d}] finish load from {data_path}", 58 | force=True, 59 | ) 60 | 61 | except Exception as exn: 62 | import traceback 63 | 64 | traceback.print_stack() 65 | exn.args = exn.args + (data_path,) 66 | if handler(exn, force=True): 67 | continue 68 | else: 69 | break 70 | 71 | 72 | def _smart_join(str_or_list, delim): 73 | if isinstance(str_or_list, str): 74 | return str_or_list 75 | else: 76 | return delim.join(str_or_list) 77 | 78 | 79 | def preprocess_laion_data( 80 | sample: Tuple[str], 81 | data_root="", 82 | transform=None, 83 | base_loader=None, 84 | tokenizer: LlamaTokenizer = None, 85 | num_total_token=2048, 86 | num_img_token=32, 87 | img_first_prob=1.0, 88 | ): 89 | info, meta_info = json.loads(sample[0]), sample[-1] 90 | 91 | image_name = info["image"] 92 | image_path = osp.join(data_root, image_name) 93 | try: 94 | image = base_loader(image_path) 95 | image = image.convert("RGB") 96 | except: 97 | raise ValueError(f"Failed to load Image {image_path}") 98 | 99 | image_tensors = transform(image) 100 | 101 | if isinstance(image_tensors, tuple): 102 | image_tensors, image_tensors_dec = np.expand_dims( 103 | image_tensors[0], axis=0 104 | ), np.expand_dims(image_tensors[1], axis=0) 105 | else: 106 | image_tensors, image_tensors_dec = np.expand_dims(image_tensors, axis=0), None 107 | 108 | img_first = np.random.random() < img_first_prob 109 | caption = _smart_join(info["caption"], " ").lower() 110 | image_subseq = "<|beginofimage|>" + "<|image|>" * num_img_token 111 | 112 | if img_first: 113 | text = image_subseq + caption 114 | else: 115 | text = caption + image_subseq 116 | 117 | text = f"{text}{tokenizer.eos_token}" 118 | tokenizer.padding_side = "right" 119 | text_tensor = tokenizer( 120 | text, 121 | padding="do_not_pad", 122 | return_tensors="np", 123 | return_attention_mask=True, 124 | ) 125 | 126 | text_ids = text_tensor["input_ids"][0] 127 | text_attn_mask = text_tensor["attention_mask"][0] 128 | 129 | if len(text_ids) > num_total_token: 130 | if img_first: 131 | text_ids = text_ids[:num_total_token] 132 | text_attn_mask = text_attn_mask[:num_total_token] 133 | else: 134 | text_ids = np.concatenate( 135 | ( 136 | text_ids[: num_total_token - (num_img_token + 2)], 137 | text_ids[-(num_img_token + 2) :], 138 | ), 139 | axis=0, 140 | ) 141 | text_attn_mask = np.concatenate( 142 | ( 143 | text_attn_mask[: num_total_token - (num_img_token + 2)], 144 | text_attn_mask[-(num_img_token + 2) :], 145 | ), 146 | axis=0, 147 | ) 148 | 149 | data = dict( 150 | image_tensors=image_tensors, 151 | text_ids=text_ids, 152 | text_attn_mask=text_attn_mask, 153 | image_tensors_dec=image_tensors_dec, 154 | ) 155 | 156 | return data 157 | 158 | 159 | def build_laion_webdataset( 160 | annt_root="", 161 | data_root="", 162 | transform=None, 163 | tokenizer_path="", 164 | per_device_batch_size=32, 165 | input_shards="{0000000..0000010}.txt", 166 | num_samples=None, 167 | resampled=False, 168 | floor=False, 169 | seed=42, 170 | epoch=0, 171 | num_workers=12, 172 | num_total_token=2048, 173 | num_img_token=64, 174 | max_num_images_per_seq=-1, 175 | img_first_prob=0.5, 176 | loss_img_weight=None, 177 | 
loss_txt_weight=None, 178 | truncation_level="sample", 179 | use_few_shot_sample=[2,3,4,5,6,7,8], 180 | use_few_shot_prob=0.25, 181 | ): 182 | base_loader = BaseLoader() 183 | shard_to_sample_fn = functools.partial( 184 | load_laion_database_nothrow, 185 | annt_root=annt_root, 186 | client=base_loader.client, 187 | ) 188 | 189 | tokenizer = init_tokenizer(tokenizer_path) 190 | 191 | preprocess_fn = functools.partial( 192 | preprocess_laion_data, 193 | data_root=data_root, 194 | transform=transform, 195 | base_loader=base_loader, 196 | tokenizer=tokenizer, 197 | num_total_token=num_total_token, 198 | num_img_token=num_img_token, 199 | img_first_prob=img_first_prob, 200 | ) 201 | 202 | dataset = build_interleaved_dataset( 203 | shard_to_sample_fn, 204 | preprocess_fn, 205 | tokenizer, 206 | per_device_batch_size=per_device_batch_size, 207 | input_shards=input_shards, 208 | num_samples=num_samples, 209 | resampled=resampled, 210 | floor=floor, 211 | seed=seed, 212 | epoch=epoch, 213 | num_workers=num_workers, 214 | num_total_token=num_total_token, 215 | num_img_token=num_img_token, 216 | max_num_images_per_seq=max_num_images_per_seq, 217 | loss_img_weight=loss_img_weight, 218 | loss_txt_weight=loss_txt_weight, 219 | truncation_level=truncation_level, 220 | use_few_shot_sample=use_few_shot_sample, 221 | use_few_shot_prob=use_few_shot_prob, 222 | ) 223 | 224 | return dataset 225 | 226 | 227 | if __name__ == "__main__": 228 | from .utils import create_transform 229 | 230 | transform = create_transform( 231 | aug_type="numpy", 232 | resolution=256, 233 | resize=True, 234 | random_crop=False, 235 | random_flip=True, 236 | ) 237 | 238 | dataset = build_laion_webdataset( 239 | annt_root="./assets/laion5b/LaionEn", 240 | data_root="", 241 | transform=transform, 242 | tokenizer_path="./assets/openlm-research/open_llama_3b_v2", 243 | per_device_batch_size=32, 244 | input_shards="{0000000..0010336}.txt", 245 | num_samples=2_600_000, 246 | resampled=False, 247 | floor=False, 248 | seed=42, 249 | num_workers=1, 250 | num_img_token=32, 251 | max_num_images_per_seq=-1, 252 | num_total_token=2048, 253 | img_first_prob=0.5, 254 | ) 255 | 256 | assert isinstance(dataset, wds.DataPipeline) 257 | print(dataset) 258 | 259 | dataloader = wds.WebLoader( 260 | dataset, 261 | batch_size=None, 262 | shuffle=False, 263 | num_workers=0, 264 | persistent_workers=False, 265 | ) 266 | print(dataloader) 267 | 268 | for i, data in enumerate(dataloader): 269 | images_tensors, text_ids, text_attn_mask, num_images = ( 270 | data["image_tensors"], 271 | data["text_ids"], 272 | data["attention_mask"], 273 | data["num_image_per_seq"], 274 | ) 275 | texts = dataset.tokenizer.batch_decode(text_ids) 276 | 277 | print(images_tensors.shape) 278 | print(text_ids) 279 | print(num_images) 280 | print(data["meta"]) 281 | 282 | break 283 | -------------------------------------------------------------------------------- /mm_interleaved/custom_datasets/lncoco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import numpy as np 5 | from collections import Counter 6 | 7 | from .loader import BaseDataset 8 | 9 | 10 | class LNCOCODataset(BaseDataset): 11 | def __init__( 12 | self, 13 | data_root, 14 | annt_root, 15 | transform, 16 | image_only=False, 17 | total_length=None, 18 | collate_mode="generate_images", 19 | phase="val", 20 | add_eos=None, 21 | ) -> None: 22 | super().__init__() 23 | assert phase == "val" and collate_mode in ["generate_images"] 24 | 
self.collate_mode = collate_mode 25 | self.transform = transform 26 | self.data_root = data_root 27 | self.annt_root = annt_root 28 | self.phase = phase 29 | self.image_only = image_only 30 | 31 | annt_file = os.path.join(annt_root, "coco_val_captions.jsonl") 32 | with open(annt_file, "r") as rf: 33 | data = rf.readlines() 34 | self.annts = [json.loads(s) for s in data] 35 | self.annt_file = annt_file 36 | if self.image_only: 37 | self.dedeup_image() 38 | if total_length is not None: 39 | if total_length <= len(self.annts): 40 | self.annts = self.annts[:total_length] 41 | else: 42 | # over sampling 43 | cnter_image = Counter([a["image_id"] for a in self.annts]) 44 | annts_weight = [1./cnter_image[a["image_id"]] for a in self.annts] 45 | annts_weight = [w / sum(annts_weight) for w in annts_weight] 46 | annts_n = np.random.choice(self.annts, total_length - len(self.annts), p=annts_weight) 47 | self.annts += list(annts_n) 48 | self.add_eos = add_eos 49 | print(f"length of the dataset is {len(self.annts)}") 50 | 51 | def dedeup_image(self): 52 | annts = {} 53 | for annt in self.annts: 54 | image_idx = annt["image_id"] 55 | if image_idx in annts: 56 | continue 57 | annts[image_idx] = annt 58 | self.annts = list(annts.values()) 59 | 60 | def image_id_to_path(self, image_id): 61 | # coco-2017 62 | return os.path.join(self.data_root, "val2017", f"{image_id:012d}.jpg") 63 | 64 | def __repr__(self) -> str: 65 | return ( 66 | f"LNCOCO Dataset phase={self.phase}\n" 67 | f"annotation_root={self.annt_root} data_root={self.data_root}\n" 68 | f"transform={self.transform}" 69 | ) 70 | 71 | def __len__(self): 72 | return len(self.annts) 73 | 74 | def __getitem__(self, index): 75 | item = self.annts[index] 76 | caption = item["caption"] 77 | # caption = caption.lower() 78 | if self.add_eos is not None: 79 | caption = caption + self.add_eos 80 | 81 | image_idx_int = int(item["image_id"]) 82 | image_path = os.path.join(self.data_root, "val2017", f"{image_idx_int:012d}.jpg") 83 | 84 | try: 85 | image = self.loader(image_path).convert("RGB") 86 | 87 | image = self.transform(image) 88 | except: 89 | print(image_path) 90 | index = random.randint(0, len(self) - 1) 91 | return self.__getitem__(index) 92 | 93 | return image, caption, image_idx_int 94 | -------------------------------------------------------------------------------- /mm_interleaved/custom_datasets/loader.py: -------------------------------------------------------------------------------- 1 | import io 2 | from PIL import Image 3 | import cv2 4 | import numpy as np 5 | from torch.utils.data import Dataset, IterableDataset 6 | 7 | import logging 8 | import os 9 | 10 | LOG_LOADER = os.environ.get("LOG_LOADER", False) 11 | 12 | 13 | def pil_loader(img_str): 14 | buff = io.BytesIO(img_str) 15 | return Image.open(buff) 16 | 17 | 18 | def cv2_loader(img_bytes): 19 | # assert(img_bytes is not None) 20 | img_mem_view = memoryview(img_bytes) 21 | img_array = np.frombuffer(img_mem_view, np.uint8) 22 | imgcv2 = cv2.imdecode(img_array, cv2.IMREAD_COLOR) 23 | imgcv2 = cv2.cvtColor(imgcv2, cv2.COLOR_BGR2RGB) 24 | return Image.fromarray(imgcv2) 25 | 26 | 27 | class LocalClient(): 28 | def __init__(self, **kwargs) -> None: 29 | pass 30 | 31 | def get(self, url): 32 | with open(url, "rb") as rf: 33 | data = rf.read() 34 | return data 35 | 36 | 37 | class BaseLoader(object): 38 | def __init__(self): 39 | self.client = LocalClient() 40 | 41 | def __call__(self, fn): 42 | try: 43 | if self.client is not None: 44 | img_value_str = self.client.get(fn) 45 | img = 
pil_loader(img_value_str) 46 | else: 47 | img = Image.open(fn) 48 | except: 49 | try: 50 | img = cv2_loader(img_value_str) 51 | except Exception as exn: 52 | exn.args = exn.args + (fn,) 53 | if LOG_LOADER: 54 | logging.warning(f"Handling BaseLoader image reading error ({repr(exn)}). Ignoring.") 55 | # print('Read image failed ({})'.format(fn)) 56 | return None 57 | else: 58 | return img 59 | else: 60 | return img 61 | 62 | 63 | class BaseDataset(Dataset): 64 | def __init__(self) -> None: 65 | super().__init__() 66 | self.loader = BaseLoader() 67 | self.client = self.loader.client 68 | 69 | def __getitem__(self, index): 70 | raise NotImplementedError 71 | 72 | 73 | class IterableBaseDataset(IterableDataset): 74 | def __init__(self) -> None: 75 | super().__init__() 76 | self.loader = BaseLoader() 77 | self.client = self.loader.client 78 | 79 | def __iter__(self): 80 | raise NotImplementedError 81 | 82 | -------------------------------------------------------------------------------- /mm_interleaved/custom_datasets/mix_dataset.py: -------------------------------------------------------------------------------- 1 | """Classes for mixing samples from multiple sources.""" 2 | 3 | import math 4 | from itertools import permutations 5 | import numpy as np 6 | from typing import List 7 | import torch 8 | from torch.utils.data import IterableDataset 9 | 10 | from .wds_utils import WdsDataset, pytorch_worker_info 11 | 12 | 13 | def random_samples(datasets, probs=None, sampling_type="sum", seed=0, fix_seed=False, dataset_names=None): 14 | sources = [iter(d) for d in datasets] 15 | if probs is None: 16 | probs = [1] * len(sources) 17 | else: 18 | probs = list(probs) 19 | 20 | generator = torch.Generator() 21 | if not fix_seed: 22 | rank, world_size, worker, num_workers = pytorch_worker_info() 23 | seed += rank * num_workers + worker 24 | generator.manual_seed(seed) 25 | 26 | is_source_finished = [0] * len(sources) 27 | while len(sources) > 0 and sum(is_source_finished) < len(datasets): 28 | cum = (np.array(probs) / np.sum(probs)).cumsum() 29 | r = torch.rand(1, generator=generator).item() 30 | i = np.searchsorted(cum, r) 31 | try: 32 | data = next(sources[i]) 33 | 34 | if dataset_names is not None: 35 | assert "meta" in data and isinstance(data["meta"], dict) and len(dataset_names) == len(datasets) 36 | data["meta"]["dataset_name"] = dataset_names[i] 37 | 38 | yield data 39 | except StopIteration: 40 | if sampling_type == "sum": 41 | del sources[i] 42 | del probs[i] 43 | elif sampling_type == "longest": 44 | sources[i] = iter(datasets[i]) 45 | is_source_finished[i] = 1 46 | else: 47 | break 48 | 49 | 50 | class RandomMixWdsDataset(IterableDataset): 51 | def __init__( 52 | self, 53 | datasets: List[WdsDataset], 54 | probs=None, 55 | sampling_type="sum", 56 | seed=0, 57 | fix_sampling_ratio=False, 58 | dataset_names=None, 59 | ): 60 | self.dataset_names = dataset_names 61 | self.datasets = datasets 62 | for dataset in datasets: 63 | try: 64 | dataset_len = len(dataset) 65 | except: 66 | dataset_len = -1 67 | 68 | dataset_name = getattr(dataset, 'dataset_name', dataset.__class__.__name__) 69 | print(f'{dataset_name}: {dataset_len}') 70 | 71 | self.fix_sampling_ratio = fix_sampling_ratio 72 | if self.fix_sampling_ratio: 73 | assert ( 74 | probs is None 75 | ), "do not support setting different probs for each dataset when fixing sampling ratio." 
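# Note (added comment, not in the original file): when the sampling ratio is fixed,
# random_samples() keeps the same RNG seed on every rank and worker (fix_seed=True), so all
# ranks would otherwise draw the mixture sources in the same order. Permuting the dataset
# list per rank below decorrelates which source each rank reads at a given step.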
76 | self._permute_dataset_by_rank() 77 | 78 | if probs is None: 79 | probs = [1] * len(datasets) 80 | else: 81 | probs = list(probs) 82 | 83 | self.probs = probs 84 | assert sampling_type in ["longest", "shortest", "sum"] 85 | self.sampling_type = sampling_type 86 | self.seed = seed 87 | 88 | def _permute_dataset_by_rank(self): 89 | permute_list = list(permutations(range(len(self.datasets)))) 90 | rank, world_size, worker, num_workers = pytorch_worker_info() 91 | idx_list = permute_list[rank % len(permute_list)] 92 | self.datasets = [self.datasets[i] for i in idx_list] 93 | 94 | def __iter__(self): 95 | """Return an iterator over the sources.""" 96 | return random_samples( 97 | self.datasets, 98 | self.probs, 99 | self.sampling_type, 100 | self.seed, 101 | fix_seed=self.fix_sampling_ratio, 102 | dataset_names=self.dataset_names, 103 | ) 104 | 105 | def set_epoch(self, epoch): 106 | for d in self.datasets: 107 | d.set_epoch(epoch) 108 | 109 | def set_tokenizer(self, tokenizer): 110 | for d in self.datasets: 111 | d.set_tokenizer(tokenizer) 112 | 113 | @property 114 | def epoch(self): 115 | return self.datasets[0].epoch 116 | 117 | @property 118 | def tokenizer(self): 119 | return self.datasets[0].tokenizer 120 | 121 | def __repr__(self) -> str: 122 | repr_str = f"RandomMixDataset: probs={self.probs}; sampling_type={self.sampling_type}\n" 123 | for d in self.datasets: 124 | repr_str += repr(d) + "\n" 125 | return repr_str 126 | 127 | def __len__(self): 128 | try: 129 | lens_dataset = np.array([len(d) for d in self.datasets]) 130 | except: 131 | # raise NotImplementedError 132 | return None 133 | 134 | if self.sampling_type == "sum": 135 | return sum(lens_dataset) 136 | elif self.sampling_type == "longest": 137 | i = np.argmax(lens_dataset) 138 | return math.ceil(lens_dataset[i] / self.probs[i] * sum(self.probs)) 139 | else: 140 | i = np.argmin(lens_dataset) 141 | return math.ceil(lens_dataset[i] / self.probs[i] * sum(self.probs)) 142 | -------------------------------------------------------------------------------- /mm_interleaved/custom_datasets/mscoco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import numpy as np 5 | 6 | from .loader import BaseDataset 7 | 8 | 9 | class CocoCaptionDataset(BaseDataset): 10 | def __init__( 11 | self, 12 | data_root, 13 | annt_root, 14 | transform, 15 | image_only=False, 16 | total_length=None, 17 | collate_mode="generate_images", 18 | shuffle=False, 19 | rerank_by_clip=False, 20 | phase="train", 21 | year="2014", 22 | ) -> None: 23 | super().__init__() 24 | self.collate_mode = collate_mode 25 | self.transform = transform 26 | self.data_root = data_root 27 | self.annt_root = annt_root 28 | self.phase = phase 29 | self.year = year 30 | self.image_only = image_only 31 | self.rerank_by_clip = rerank_by_clip 32 | 33 | annt_file = os.path.join( 34 | annt_root, "annotations", f"captions_{phase}{year}.json" 35 | ) 36 | self.annt_file = annt_file 37 | self.annts = json.load(open(annt_file, "r"))["annotations"] 38 | if self.image_only: 39 | self.dedeup_image() 40 | if shuffle: 41 | np.random.shuffle(self.annts) 42 | if total_length is not None: 43 | self.annts = self.annts[:total_length] 44 | print(f"length of the dataset is {len(self.annts)}") 45 | 46 | def dedeup_image(self): 47 | annts = {} 48 | for annt in self.annts: 49 | image_idx = str(annt["image_id"]).zfill(12) 50 | if image_idx in annts: 51 | continue 52 | annts[image_idx] = annt 53 | self.annts = 
list(annts.values()) 54 | 55 | def image_id_to_path(self, image_id): 56 | # coco-2014 57 | image_idx = str(image_id).zfill(12) 58 | image_name = f"COCO_{self.phase}{self.year}_{image_idx}.jpg" 59 | image_path = os.path.join( 60 | self.data_root, f"{self.phase}{self.year}", image_name 61 | ) 62 | return image_path 63 | 64 | def __repr__(self) -> str: 65 | return ( 66 | f"MSCOCO-Caption Dataset year={self.year} phase={self.phase}\n" 67 | f"annotation_root={self.annt_root} data_root={self.data_root}\n" 68 | f"transform={self.transform}" 69 | ) 70 | 71 | def __len__(self): 72 | return len(self.annts) 73 | 74 | def __getitem__(self, index): 75 | item = self.annts[index] 76 | caption = item["caption"].lower() 77 | 78 | image_idx = str(item["image_id"]).zfill(12) 79 | image_name = f"COCO_{self.phase}{self.year}_{image_idx}.jpg" 80 | image_path = os.path.join( 81 | self.data_root, f"{self.phase}{self.year}", image_name 82 | ) 83 | try: 84 | image = self.loader(image_path).convert("RGB") 85 | 86 | image = self.transform(image) 87 | except: 88 | print(image_path) 89 | index = random.randint(0, len(self) - 1) 90 | return self.__getitem__(index) 91 | 92 | return image, caption, item["image_id"] 93 | -------------------------------------------------------------------------------- /mm_interleaved/custom_datasets/mscoco_karpathy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | 5 | from .loader import BaseDataset 6 | 7 | 8 | class CocoCaptionKarpathyDataset(BaseDataset): 9 | def __init__( 10 | self, 11 | data_root, 12 | annt_root, 13 | transform, 14 | image_only=False, 15 | total_length=None, 16 | collate_mode="generate_texts", 17 | phase="train", 18 | year="2014", 19 | add_eos=None, 20 | use_1st_sentence_only=True, 21 | rerank_by_clip=False, 22 | ) -> None: 23 | super().__init__() 24 | self.collate_mode = collate_mode 25 | self.transform = transform 26 | self.data_root = data_root 27 | self.annt_root = annt_root 28 | self.phase = phase 29 | self.year = year 30 | self.image_only = image_only 31 | annt_file = os.path.join( 32 | annt_root, "annotations", f"coco_karpathy_{phase}.json" 33 | ) 34 | self.annts = json.load(open(annt_file, "r")) 35 | self.annt_file = annt_file 36 | if self.image_only: 37 | self.dedeup_image() 38 | if total_length is not None: 39 | self.annts = self.annts[:total_length] 40 | self.add_eos = add_eos 41 | self.use_1st_sentence_only = use_1st_sentence_only 42 | self.rerank_by_clip = rerank_by_clip 43 | print(f"length of the dataset is {len(self.annts)}") 44 | 45 | def dedeup_image(self): 46 | annts = {} 47 | for annt in self.annts: 48 | image_idx = annt["image"].split("_")[-1][ 49 | :-4 50 | ] # 'val2014/COCO_val2014_000000391895.jpg' 51 | if image_idx in annts: 52 | continue 53 | annts[image_idx] = annt 54 | self.annts = list(annts.values()) 55 | 56 | def image_id_to_path(self, image_id): 57 | phase = "val" if self.phase == "test" else self.phase 58 | # coco-2014 59 | image_idx = str(image_id).zfill(12) 60 | image_name = f"COCO_{phase}{self.year}_{image_idx}.jpg" 61 | image_path = os.path.join( 62 | self.data_root, f"{phase}{self.year}", image_name 63 | ) 64 | return image_path 65 | 66 | def __repr__(self) -> str: 67 | return ( 68 | f"MSCOCO-Caption Karpathy Dataset year={self.year} phase={self.phase}\n" 69 | f"annotation_root={self.annt_root} data_root={self.data_root}\n" 70 | f"transform={self.transform}" 71 | ) 72 | 73 | def __len__(self): 74 | return len(self.annts) 75 | 76 | def 
__getitem__(self, index): 77 | item = self.annts[index] 78 | caption = item["caption"] 79 | if isinstance(caption, list): 80 | caption = random.choice(caption) 81 | caption = caption.lower() 82 | if self.add_eos is not None: 83 | caption = caption + self.add_eos 84 | image_idx_int = int(item["image"].split("_")[-1][:-4]) 85 | image_name = item["image"] 86 | image_path = os.path.join(self.data_root, f"{image_name}") 87 | 88 | try: 89 | image = self.loader(image_path).convert("RGB") 90 | 91 | image = self.transform(image) 92 | except: 93 | print(image_path) 94 | index = random.randint(0, len(self) - 1) 95 | return self.__getitem__(index) 96 | 97 | return image, caption, image_idx_int 98 | -------------------------------------------------------------------------------- /mm_interleaved/custom_datasets/sft_datasets.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | import torch 6 | from torch.utils.data import ConcatDataset, WeightedRandomSampler 7 | 8 | from .loader import BaseDataset 9 | 10 | 11 | class LLaVADataset(BaseDataset): 12 | def __init__( 13 | self, 14 | annt_root=[], 15 | data_root=[], 16 | transform=None, 17 | ): 18 | super().__init__() 19 | self.ann_path = [annt_root] if isinstance(annt_root, str) else annt_root 20 | self.data_root = [data_root] if isinstance(data_root, str) else data_root 21 | self.transform = transform 22 | 23 | self.ann = [] 24 | print("Formatting inputs...Skip in lazy mode") 25 | for index, p in enumerate(self.ann_path): 26 | if p.endswith('json'): 27 | with open(p, 'r') as file: 28 | data = json.load(file) 29 | for item in data: 30 | try: 31 | item['image'] = os.path.join(self.data_root[index], item['image']) 32 | self.ann.append(item) 33 | except: 34 | pass 35 | elif p.endswith('.jsonl'): 36 | for line in open(p, 'r'): 37 | data = json.loads(line) 38 | try: 39 | data['image'] = os.path.join(self.data_root[index], data['image']) 40 | self.ann.append(data) 41 | except: 42 | pass 43 | 44 | # split multi-round dialogues to single-round dialogue 45 | max_conv_num = 2 # 1 round 46 | print(f"data length before split: {len(self.ann)}") 47 | new_ann = [] 48 | for item in self.ann: 49 | conversations = item["conversations"] 50 | conversations = [conversations[i:i + max_conv_num] for i in range(0, len(conversations), max_conv_num)] 51 | for conv in conversations: 52 | new_item = item.copy() 53 | if "<image>" not in conv[0]['value']: 54 | conv[0]['value'] = "<image>\n" + conv[0]['value'] 55 | new_item["conversations"] = conv 56 | new_ann.append(new_item) 57 | self.ann = new_ann 58 | print(f"data length after split: {len(self.ann)}") 59 | 60 | def __getitem__(self, index): 61 | while True: 62 | try: 63 | data = self.ann[index] 64 | 65 | assert len(data['conversations']) == 2 66 | 67 | query = data['conversations'][0]['value'].replace('<image>\n', '') 68 | query = query.replace('\n<image>', '') 69 | query = query.replace('<image>', '') 70 | 71 | image_id = data['id'] 72 | image = self.loader(data['image']).convert('RGB') 73 | label = data['conversations'][1]['value'] 74 | break 75 | except Exception as e: 76 | print(e) 77 | print('Error loading data:', data['image']) 78 | index = random.randint(0, len(self.ann) - 1) 79 | 80 | return self.transform(image), query, label, image_id 81 | 82 | def __len__(self): 83 | return len(self.ann) 84 | 85 | 86 | class WeightedConcatDataset(ConcatDataset): 87 | def __init__(self, datasets, weights): 88 | super().__init__(datasets) 89 | self.weights = torch.DoubleTensor(weights) 90 | 
self.total_size = sum(len(d) for d in datasets) 91 | self.sampler = WeightedRandomSampler(weights=self.weights, num_samples=self.total_size, replacement=True) 92 | 93 | def __iter__(self): 94 | return iter(self.sampler) 95 | 96 | def __len__(self): 97 | return self.total_size 98 | -------------------------------------------------------------------------------- /mm_interleaved/custom_datasets/visdial_dense.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import json 3 | import random 4 | 5 | from .loader import BaseDataset 6 | from .wds_utils import init_tokenizer 7 | 8 | 9 | class VisDialDenseDataset(BaseDataset): 10 | def __init__( 11 | self, 12 | data_root, 13 | annt_root, 14 | transform, 15 | tokenizer_path, 16 | total_length=None, 17 | num_img_token=32, 18 | collate_mode='generate_scores', 19 | phase="val", 20 | ) -> None: 21 | ''' 22 | VisDial dataset only for NDCG evaluation 23 | ''' 24 | super().__init__() 25 | 26 | assert phase == 'val' 27 | 28 | self.phase = phase 29 | self.transform = transform 30 | self.data_root = data_root 31 | self.annt_root = annt_root 32 | self.tokenizer = init_tokenizer(tokenizer_path) 33 | self.num_img_token = num_img_token 34 | self.collate_mode = collate_mode 35 | 36 | dialog_json_path = osp.join(self.annt_root, 'visdial_1.0_val.json') 37 | with open(dialog_json_path, 'r') as rf: 38 | data = json.load(rf)["data"] 39 | 40 | self.dialogs = data["dialogs"] 41 | self.questions = data["questions"] 42 | self.answers = data["answers"] 43 | 44 | dense_annt_path = osp.join(self.annt_root, 'visdial_1.0_val_dense_annotations.json') 45 | with open(dense_annt_path, 'r') as rf: 46 | data_dense = json.load(rf) 47 | self.dense_annt = {d["image_id"]:d for d in data_dense} 48 | 49 | if total_length is not None: 50 | self.dialogs = self.dialogs[:total_length] 51 | print(f"length of the dataset is {len(self.dialogs)}") 52 | 53 | def __repr__(self) -> str: 54 | return ( 55 | f"VisDial Dataset phase={self.phase}\n" 56 | f"annotation_root={self.annt_root} data_root={self.data_root}\n" 57 | f"transform={self.transform}" 58 | ) 59 | 60 | def __len__(self): 61 | return len(self.dialogs) 62 | 63 | def __getitem__(self, index): 64 | 65 | item = self.dialogs[index] 66 | 67 | image_id = item["image_id"] 68 | image_path = osp.join(self.data_root, "VisualDialog_val2018", f"VisualDialog_val2018_{image_id:012d}.jpg") 69 | 70 | try: 71 | image = self.loader(image_path).convert("RGB") 72 | 73 | image = self.transform(image) 74 | except: 75 | print(image_path) 76 | index = random.randint(0, len(self) - 1) 77 | return self.__getitem__(index) 78 | 79 | image_prompt = "<|beginofimage|>" + "<|image|>" * self.num_img_token 80 | text = f"{image_prompt} caption: {item['caption']}. " 81 | dense_annt = self.dense_annt[image_id] 82 | round_idx = dense_annt["round_id"] - 1 83 | dialog = item["dialog"] 84 | for rnd in range(round_idx-1): 85 | question = self.questions[dialog[rnd]["question"]] 86 | answer = self.answers[dialog[rnd]["answer"]] 87 | text += f"question: {question}? answer: {answer}. " 88 | 89 | question = self.questions[dialog[round_idx]["question"]] 90 | text += f"question: {question}? 
answer:" 91 | 92 | options = dialog[round_idx]["answer_options"] 93 | options = [self.answers[i] for i in options] 94 | # gt_relevance = dense_annt["gt_relevance"] 95 | 96 | # assert len(gt_relevance) == len(options) 97 | 98 | text_tensor = self.tokenizer( 99 | [text], 100 | truncation=False, 101 | padding=False, 102 | return_tensors="pt", 103 | return_attention_mask=True, 104 | ) 105 | text_ids = text_tensor.input_ids[0] 106 | attn_mask = text_tensor.attention_mask[0] 107 | 108 | options_tensor = self.tokenizer( 109 | options, 110 | truncation=False, 111 | padding=True, 112 | return_tensors="pt", 113 | return_attention_mask=True, 114 | ) 115 | options_ids = options_tensor.input_ids 116 | options_attn_mask = options_tensor.attention_mask 117 | 118 | return dict( 119 | image_id=image_id, 120 | image_tensor=image, 121 | # context=text, 122 | # options=options, 123 | text_ids=text_ids, 124 | attn_mask=attn_mask, 125 | options_ids=options_ids[:,1:], # no 126 | options_attn_mask=options_attn_mask[:,1:], 127 | # gt_relevance=gt_relevance, 128 | ) 129 | -------------------------------------------------------------------------------- /mm_interleaved/custom_datasets/vist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | 5 | from .loader import BaseDataset 6 | 7 | 8 | class VISTDataset(BaseDataset): 9 | def __init__( 10 | self, 11 | data_root, 12 | annt_root, 13 | transform, 14 | total_length=None, 15 | phase="train", 16 | 17 | collate_mode="generate_texts", 18 | add_eos="", 19 | num_img_token=32, 20 | img_first_prob=0.0, 21 | add_soi_token=True, 22 | round_range="last", 23 | context_type="current", 24 | ): 25 | super().__init__() 26 | 27 | self.transform = transform 28 | self.data_root = data_root 29 | self.annt_root = annt_root 30 | 31 | assert phase in ["train", "val", "test"] 32 | self.phase = phase 33 | 34 | assert collate_mode in ["train", "generate_texts", "generate_images"] 35 | self.collate_mode = collate_mode 36 | self.add_eos = add_eos 37 | 38 | assert round_range in ["last", "all"] 39 | self.round_range = round_range 40 | 41 | assert context_type in [ 42 | "multi_modal", 43 | "image_only", 44 | "text_only", 45 | "current", 46 | ] 47 | self.context_type = context_type 48 | 49 | self.num_img_token = num_img_token 50 | self.img_first_prob = img_first_prob 51 | self.add_soi_token = add_soi_token 52 | 53 | self.image_subseq = "<|image|>" * self.num_img_token 54 | if self.add_soi_token: 55 | self.image_subseq = "<|beginofimage|>" + self.image_subseq 56 | 57 | annt_file = os.path.join( 58 | annt_root, "annotations", f"{phase}_formatted_filtered.json" 59 | ) 60 | self.annt_file = annt_file 61 | self.load_database() 62 | 63 | if total_length is not None: 64 | self.annts = self.annts[:total_length] 65 | 66 | print(f"length of the dataset is {len(self.annts)}") 67 | 68 | def load_database(self): 69 | with open(self.annt_file, "r") as rf: 70 | annts = json.load(rf)["annotations"] 71 | 72 | data = [] 73 | for k, v in annts.items(): 74 | v.sort(key=lambda x: x["sequence_index"]) 75 | data.append(dict(story_id=k, story=v)) 76 | data.sort(key=lambda x: x["story_id"]) 77 | 78 | if self.round_range == "all": 79 | assert self.phase != "train" 80 | data_n = [] 81 | for d in data: 82 | for i in range(1, len(d["story"])): 83 | d_n = dict(story_id=f"{d['story_id']}_{i}", story=d["story"][:i]) 84 | data_n.append(d_n) 85 | data = data_n 86 | 87 | self.annts = data 88 | 89 | def __repr__(self) -> str: 90 | return ( 91 | 
f"VIST Dataset phase={self.phase}\n" 92 | f"annotation_root={self.annt_root} data_root={self.data_root}\n" 93 | f"transform={self.transform}" 94 | ) 95 | 96 | def __len__(self): 97 | return len(self.annts) 98 | 99 | def _get_image(self, image_id, return_image_path=False): 100 | try: 101 | image_path = os.path.join( 102 | self.data_root, "images", f"{self.phase}_images", f"{image_id}.png" 103 | ) 104 | image = self.loader(image_path).convert("RGB") 105 | image = self.transform(image) 106 | except Exception as e: 107 | print(e) 108 | print(image_path) 109 | image = None 110 | 111 | if return_image_path: 112 | return image, image_path 113 | return image 114 | 115 | def __getitem__(self, index): 116 | item = self.annts[index]["story"] 117 | meta = [self.annts[index]["story_id"]] 118 | 119 | images_tensor = [] 120 | text = "" 121 | if self.collate_mode == "train": 122 | assert self.phase == "train" 123 | 124 | # no target image / text 125 | for i in range(len(item)): 126 | turn = item[i] 127 | image = self._get_image(turn["image_id"]) 128 | if np.random.random() < self.img_first_prob: 129 | _text = f"{self.image_subseq}{turn['caption']} " 130 | else: 131 | _text = f"{turn['caption']}{self.image_subseq} " 132 | 133 | text += _text 134 | images_tensor.append(image) 135 | 136 | else: 137 | assert self.phase != "train" 138 | 139 | # prepare history context 140 | if self.context_type == "multi_modal": 141 | for i in range(len(item) - 1): 142 | turn = item[i] 143 | image = self._get_image(turn["image_id"]) 144 | if np.random.random() < self.img_first_prob: 145 | _text = f"{self.image_subseq}{turn['caption']} " 146 | else: 147 | _text = f"{turn['caption']}{self.image_subseq} " 148 | 149 | text += _text 150 | images_tensor.append(image) 151 | 152 | elif self.context_type == "image_only": 153 | for i in range(len(item) - 1): 154 | turn = item[i] 155 | image = self._get_image(turn["image_id"]) 156 | text += self.image_subseq 157 | images_tensor.append(image) 158 | 159 | elif self.context_type == "text_only": 160 | for i in range(len(item) - 1): 161 | turn = item[i] 162 | text += f"{turn['caption']} " 163 | 164 | # prepare target 165 | if self.collate_mode == "generate_texts": 166 | turn = item[-1] 167 | 168 | if self.context_type != "text_only": 169 | image = self._get_image(turn["image_id"]) 170 | text += self.image_subseq 171 | images_tensor.append(image) 172 | 173 | meta.append(turn["caption"]) 174 | 175 | elif self.collate_mode == "generate_images": 176 | turn = item[-1] 177 | if self.context_type != "image_only": 178 | text += turn["caption"] 179 | 180 | image, image_path = self._get_image( 181 | turn["image_id"], return_image_path=True 182 | ) 183 | text += self.image_subseq 184 | images_tensor.append(image) 185 | 186 | meta.append(image_path) 187 | 188 | text = text.strip() 189 | if self.add_eos: 190 | text += self.add_eos 191 | 192 | return dict(text=text, images_tensor=images_tensor, meta=meta) 193 | 194 | @property 195 | def task_prefix(self): 196 | return f"_{self.context_type}_{self.round_range}" 197 | -------------------------------------------------------------------------------- /mm_interleaved/custom_datasets/vqa_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from .loader import BaseDataset 4 | 5 | 6 | class VQABaseDataset(BaseDataset): 7 | def __init__( 8 | self, 9 | data_root, 10 | annt_file, 11 | transform=None, 12 | total_length=None, 13 | phase='train', 14 | collate_mode='generate_vqa', 15 | 
add_eos=None, 16 | ): 17 | super().__init__() 18 | self.collate_mode = collate_mode 19 | self.transform = transform 20 | self.data_root = data_root 21 | self.annt_file = annt_file 22 | self.phase = phase 23 | if total_length is not None: 24 | self.annts = self.annts[:total_length] 25 | self.add_eos = add_eos 26 | self.ann = self.load_annotations() 27 | print(f"length of the {self.__class__.__name__} is {len(self.ann)}") 28 | 29 | def load_annotations(self): 30 | raise NotImplementedError 31 | 32 | def __getitem__(self, index): 33 | ann = self.ann[index] 34 | image = self.loader(os.path.join(self.data_root, ann['file_name'])).convert('RGB') 35 | image = self.transform(image) if self.transform is not None else image 36 | question = ann['question'] 37 | answer = ann['answer'] 38 | question_id = ann.get('question_id', -1) 39 | 40 | return image, question, answer, question_id 41 | 42 | def __len__(self): 43 | return len(self.ann) 44 | 45 | @property 46 | def data_shape(self): 47 | return 4, 32, 32 48 | 49 | 50 | class VQAV2Dataset(VQABaseDataset): 51 | def __init__( 52 | self, 53 | data_root='./assets/coco/images', 54 | annt_root='./assets/VQAv2', 55 | phase='train', 56 | ann_name_format='v2_mscoco_{split}2014_annotations.json', 57 | question_name_format='v2_OpenEnded_mscoco_{split}2014_questions.json', 58 | **kwargs, 59 | ): 60 | self.question_file = os.path.join(annt_root, question_name_format.format(split=phase)) 61 | 62 | data_root = os.path.join(data_root, f'{phase}2014') 63 | annt_file = os.path.join(annt_root, ann_name_format.format(split=phase)) 64 | super().__init__(data_root=data_root, annt_file=annt_file, phase=phase, **kwargs) 65 | 66 | def load_annotations(self): 67 | answers_info = json.load(open(self.annt_file))['annotations'] 68 | questions_info = json.load(open(self.question_file))['questions'] 69 | 70 | annotations = {} 71 | for info in answers_info: 72 | image_id = info['image_id'] 73 | question_id = info['question_id'] 74 | answer = info['multiple_choice_answer'] if 'multiple_choice_answer' in info else info['answers'][0]['answer'] 75 | 76 | assert question_id not in annotations 77 | annotations[question_id] = { 78 | 'image_id': image_id, 79 | 'question_id': question_id, 80 | 'answer': answer, 81 | 'file_name': f'COCO_{self.phase}2014_{image_id:012d}.jpg', 82 | } 83 | 84 | for info in questions_info: 85 | image_id = info['image_id'] 86 | question_id = info['question_id'] 87 | question = info['question'] 88 | 89 | assert annotations[question_id]['image_id'] == image_id 90 | annotations[question_id]['question'] = question 91 | 92 | return list(annotations.values()) 93 | 94 | 95 | class OKVQADataset(VQAV2Dataset): 96 | def __init__( 97 | self, 98 | annt_root='./assets/OK-VQA', 99 | ann_name_format='mscoco_{split}2014_annotations.json', 100 | question_name_format='OpenEnded_mscoco_{split}2014_questions.json', 101 | **kwargs, 102 | ): 103 | super().__init__(annt_root=annt_root, ann_name_format=ann_name_format, question_name_format=question_name_format, **kwargs) 104 | 105 | 106 | class VizWizVQADataset(VQABaseDataset): 107 | def __init__( 108 | self, 109 | data_root='./assets/VizWiz', 110 | annt_root='./assets/VizWiz-VQA', 111 | phase='train', 112 | batch_size=4, 113 | **kwargs, 114 | ): 115 | data_root = os.path.join(data_root, phase) 116 | annt_file = os.path.join(annt_root, f'{phase}.json') 117 | super().__init__(data_root=data_root, annt_file=annt_file, phase=phase, **kwargs) 118 | self.batch_size = batch_size 119 | 120 | def load_annotations(self): 121 | meta_info = 
json.load(open(self.annt_file)) 122 | 123 | annotations = [] 124 | for ann in meta_info: 125 | annotations.append({ 126 | 'question_id': int(ann['image'].split('_')[-1].split('.')[0]), 127 | 'file_name': ann['image'], 128 | 'question': ann['question'], 129 | 'answer': ann['answers'][0]['answer'], 130 | }) 131 | 132 | return annotations 133 | 134 | class TextVQADataset(VQABaseDataset): 135 | def __init__( 136 | self, 137 | data_root='./assets/TextVQA/train_images', 138 | annt_root='./assets/TextVQA', 139 | phase='train', 140 | ann_name_format='textvqa_{split}_annotations.json', 141 | question_name_format='textvqa_{split}_questions.json', 142 | **kwargs, 143 | ): 144 | self.question_file = os.path.join(annt_root, question_name_format.format(split=phase)) 145 | 146 | annt_file = os.path.join(annt_root, ann_name_format.format(split=phase)) 147 | super().__init__(data_root=data_root, annt_file=annt_file, phase=phase, **kwargs) 148 | 149 | def load_annotations(self): 150 | answers_info = json.load(open(self.annt_file))['annotations'] 151 | questions_info = json.load(open(self.question_file))['questions'] 152 | 153 | annotations = {} 154 | for info in answers_info: 155 | image_id = info['image_id'] 156 | question_id = info['question_id'] 157 | answer = info['multiple_choice_answer'] if 'multiple_choice_answer' in info else info['answers'][0]['answer'] 158 | 159 | assert question_id not in annotations 160 | annotations[question_id] = { 161 | 'image_id': image_id, 162 | 'question_id': question_id, 163 | 'answer': answer, 164 | } 165 | 166 | for info in questions_info: 167 | image = info['image'] 168 | image_id = info['image_id'] 169 | question_id = info['question_id'] 170 | question = info['question'] 171 | 172 | assert annotations[question_id]['image_id'] == image_id 173 | annotations[question_id]['question'] = question 174 | annotations[question_id]['file_name'] = image 175 | 176 | return list(annotations.values()) 177 | -------------------------------------------------------------------------------- /mm_interleaved/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mm_interleaved import MMInterleaved 2 | -------------------------------------------------------------------------------- /mm_interleaved/models/decoders/decoder_image.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | 5 | from .perceiver import PerceiverResampler 6 | from .sd import StableDiffusion 7 | 8 | 9 | class ImageDecoder(nn.Module): 10 | def __init__( 11 | self, 12 | pretrained_model_name_or_path="", 13 | uncond_prob=0.1, 14 | seq_len=77, 15 | embed_dim=1024, 16 | image_size=512, 17 | mmfs_input_channel=1024, 18 | mmfs_feat_levels=4, 19 | vae_encode_mini_bs=32, 20 | sd_base_seed=0, 21 | sd_use_random_seed=False, 22 | sd_use_vae_gradient_checkpointing=True, 23 | sd_use_unet_gradient_checkpointing=True, 24 | perceiver_config=None, 25 | ): 26 | super().__init__() 27 | self.uncond_prob = uncond_prob 28 | 29 | self.perceiver_resampler = PerceiverResampler(**perceiver_config) 30 | self.decoder = StableDiffusion( 31 | pretrained_model_name_or_path, 32 | image_size=image_size, 33 | use_vae_gradient_checkpointing=sd_use_vae_gradient_checkpointing, 34 | use_unet_gradient_checkpointing=sd_use_unet_gradient_checkpointing, 35 | vae_encode_mini_bs=vae_encode_mini_bs, 36 | base_seed=sd_base_seed, 37 | use_random_seed=sd_use_random_seed, 38 | 
mmfs_input_channel=mmfs_input_channel, 39 | mmfs_feat_levels=mmfs_feat_levels, 40 | ) 41 | 42 | if self.uncond_prob > 0: 43 | self.neg_prompt_embeds = nn.Parameter( 44 | torch.zeros(1, seq_len, embed_dim) 45 | ) 46 | nn.init.normal_(self.neg_prompt_embeds, std=0.02) 47 | assert self.neg_prompt_embeds.shape[1] == seq_len 48 | neg_prompt_embeds = self.decoder.get_negative_prompt_embeds( 49 | uncond_tokens=[""], 50 | device="cuda", 51 | dtype=self.neg_prompt_embeds.dtype, 52 | ) 53 | neg_prompt_embeds = neg_prompt_embeds.to( 54 | device=self.neg_prompt_embeds.device 55 | ) 56 | self.neg_prompt_embeds.data.copy_(neg_prompt_embeds) 57 | 58 | def print_parameters_stats(self, prefix=""): 59 | for name, module in self.named_children(): 60 | print( 61 | f"# {prefix}{name} Total parameters: {sum(p.numel() for p in module.parameters()) / 1e6:.2f}M" 62 | ) 63 | print( 64 | f"# {prefix}{name} Trainable parameters: {sum(p.numel() for p in module.parameters() if p.requires_grad) / 1e6:.2f}M" 65 | ) 66 | if hasattr(module, "print_parameters_stats"): 67 | module.print_parameters_stats(prefix=f"{prefix}{name}.") 68 | 69 | def forward( 70 | self, 71 | image_tensors, 72 | context_features, 73 | context_attention_mask=None, 74 | image_loss_mask=None, 75 | mmfs_features=None, 76 | mmfs_mask=None, 77 | **kwargs, 78 | ): 79 | """ 80 | image_tensors: [B_I, 3, H, W] 81 | context_features: [B_I, L, D] 82 | """ 83 | assert image_tensors.shape[0] == context_features.shape[0] 84 | if context_attention_mask is not None: 85 | assert torch.all( 86 | context_attention_mask.sum(dim=1) > 0 87 | ), f"{context_attention_mask.sum(dim=1)=}" 88 | 89 | context_features = self.perceiver_resampler( 90 | encoder_hidden_states=context_features, 91 | encoder_attention_mask=context_attention_mask, 92 | return_dict=False, 93 | )[0] 94 | 95 | if self.uncond_prob > 0.0: 96 | uncond_mask = ( 97 | torch.rand_like(context_features[:, :1, :1]) < self.uncond_prob 98 | ) 99 | neg_prompt_embeds = self.neg_prompt_embeds 100 | context_features = torch.where( 101 | uncond_mask, neg_prompt_embeds, context_features 102 | ) 103 | 104 | sd_loss = self.decoder( 105 | image_tensors, 106 | context_features, 107 | mmfs_features=mmfs_features, 108 | mmfs_mask=mmfs_mask, 109 | **kwargs, 110 | ) 111 | assert context_attention_mask is not None 112 | is_cond_image = context_attention_mask.sum(dim=1) > 2 # [, ] 113 | is_cond_image = rearrange(is_cond_image, "b -> b 1 1 1") 114 | sd_loss = sd_loss * is_cond_image 115 | if image_loss_mask is not None: 116 | image_loss_mask = rearrange(image_loss_mask, "b -> b 1 1 1") 117 | sd_loss = sd_loss * image_loss_mask 118 | sd_loss = sd_loss.mean() 119 | 120 | return sd_loss 121 | 122 | @torch.no_grad() 123 | def generate_images( 124 | self, 125 | context_features, 126 | context_attention_mask=None, 127 | mmfs_features=None, 128 | mmfs_mask=None, 129 | **kwargs, 130 | ): 131 | output = {} 132 | 133 | context_features = self.perceiver_resampler( 134 | encoder_hidden_states=context_features, 135 | encoder_attention_mask=context_attention_mask, 136 | return_dict=False, 137 | )[0] 138 | num_inference_steps = kwargs.pop("num_inference_steps", 30) 139 | guidance_scale = kwargs.pop("guidance_scale", 7.5) 140 | num_validation_images = kwargs.pop("num_validation_images", 1) 141 | 142 | negative_prompt_embeds = self.neg_prompt_embeds.expand_as( 143 | context_features 144 | ) 145 | images = self.decoder.generate_images( 146 | text_embeds=context_features, 147 | negative_prompt_embeds=negative_prompt_embeds, 148 | 
num_validation_images=num_validation_images, 149 | num_inference_steps=num_inference_steps, 150 | guidance_scale=guidance_scale, 151 | mmfs_features=mmfs_features, 152 | mmfs_mask=mmfs_mask, 153 | ) 154 | 155 | output["image"] = images 156 | return output 157 | -------------------------------------------------------------------------------- /mm_interleaved/models/decoders/decoder_text.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | from dataclasses import dataclass 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.checkpoint 7 | 8 | from transformers import LlamaForCausalLM 9 | from transformers.models.llama.modeling_llama import ( 10 | _make_causal_mask, 11 | _expand_mask, 12 | ) 13 | from transformers.utils import logging, ModelOutput 14 | 15 | 16 | logger = logging.get_logger(__name__) 17 | 18 | 19 | @dataclass 20 | class TextDecoderOutputWithPast(ModelOutput): 21 | logits: torch.FloatTensor = None 22 | last_hidden_state: torch.FloatTensor = None 23 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 24 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 25 | attentions: Optional[Tuple[torch.FloatTensor]] = None 26 | 27 | 28 | class TextDecoder(nn.Module): 29 | def __init__( 30 | self, 31 | config=None, 32 | txt_vocab_size=32002, 33 | orig_txt_vocab_size=-1, 34 | is_freeze=True, 35 | gradient_checkpointing=True, 36 | ): 37 | super().__init__() 38 | self.config = config 39 | self.is_freeze = is_freeze 40 | self.orig_txt_vocab_size = orig_txt_vocab_size 41 | assert orig_txt_vocab_size > 0 and orig_txt_vocab_size < txt_vocab_size 42 | 43 | self.head = nn.Linear( 44 | config.hidden_size, txt_vocab_size, bias=True 45 | ) 46 | self.head_new = nn.Linear(config.hidden_size, txt_vocab_size - orig_txt_vocab_size, bias=True) 47 | 48 | self.gradient_checkpointing = gradient_checkpointing 49 | 50 | self.requires_grad_(not is_freeze) 51 | self.head_new.requires_grad_(True) 52 | 53 | def init_from_llm(self, llm_model: LlamaForCausalLM, orig_txt_vocab_size=-1): 54 | # initialize nn.Linear and nn.LayerNorm 55 | self.apply(self._init_weights) 56 | 57 | # init head weight from llm_model 58 | self.head.weight.data.copy_( 59 | llm_model.lm_head.weight.data[: self.head.weight.data.shape[0]] 60 | ) 61 | if orig_txt_vocab_size > 0: 62 | if self.is_freeze: 63 | nn.init.constant_(self.head.weight[orig_txt_vocab_size:], 0.0) 64 | else: 65 | mean = llm_model.lm_head.weight[:orig_txt_vocab_size].mean() 66 | std = llm_model.lm_head.weight[:orig_txt_vocab_size].std() 67 | nn.init.trunc_normal_( 68 | self.head.weight[orig_txt_vocab_size:], mean=mean, std=std 69 | ) 70 | 71 | # init head bias from llm_model 72 | if llm_model.lm_head.bias is not None: 73 | self.head.bias.data.copy_(llm_model.lm_head.bias.data) 74 | if orig_txt_vocab_size > 0: 75 | if self.is_freeze: 76 | nn.init.constant_(self.head.bias[orig_txt_vocab_size:], -100.0) 77 | else: 78 | mean = llm_model.lm_head.bias[:orig_txt_vocab_size].mean() 79 | std = llm_model.lm_head.bias[:orig_txt_vocab_size].std() 80 | nn.init.trunc_normal_( 81 | self.head.bias[orig_txt_vocab_size:], mean=mean, std=std 82 | ) 83 | elif self.head.bias is not None: 84 | nn.init.constant_(self.head.bias, 0) 85 | if self.is_freeze: 86 | nn.init.constant_(self.head.bias[orig_txt_vocab_size:], -100.0) 87 | self.head.bias.requires_grad_(False) 88 | 89 | nn.init.constant_(self.head_new.weight.data, 0.0) 90 | bias_min = -5. 
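# Note (added comment, not in the original file): when the base head is frozen, its bias for
# the newly added vocab entries is pinned at -100.0 above, so initializing head_new's bias to
# 100.0 + bias_min gives those new tokens an effective starting bias of bias_min (-5.0):
# very unlikely at first, but learnable through head_new while the base head stays frozen.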
91 | nn.init.constant_(self.head_new.bias, 100.0 + bias_min) 92 | 93 | def _init_weights(self, m): 94 | if isinstance(m, nn.Linear): 95 | # we use xavier_uniform following official JAX ViT: 96 | torch.nn.init.xavier_uniform_(m.weight) 97 | if isinstance(m, nn.Linear) and m.bias is not None: 98 | nn.init.constant_(m.bias, 0) 99 | elif isinstance(m, nn.LayerNorm): 100 | nn.init.constant_(m.bias, 0) 101 | nn.init.constant_(m.weight, 1.0) 102 | 103 | def print_parameters_stats(self, prefix=""): 104 | for name, module in self.named_children(): 105 | print( 106 | f"# {prefix}{name} Total parameters: {sum(p.numel() for p in module.parameters()) / 1e6:.2f}M" 107 | ) 108 | print( 109 | f"# {prefix}{name} Trainable parameters: {sum(p.numel() for p in module.parameters() if p.requires_grad) / 1e6:.2f}M" 110 | ) 111 | 112 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask 113 | def _prepare_decoder_attention_mask( 114 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 115 | ): 116 | # create causal mask 117 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 118 | combined_attention_mask = None 119 | if input_shape[-1] > 1: 120 | combined_attention_mask = _make_causal_mask( 121 | input_shape, 122 | inputs_embeds.dtype, 123 | device=inputs_embeds.device, 124 | past_key_values_length=past_key_values_length, 125 | ) 126 | 127 | if attention_mask is not None: 128 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 129 | expanded_attn_mask = _expand_mask( 130 | attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] 131 | ).to(inputs_embeds.device) 132 | combined_attention_mask = ( 133 | expanded_attn_mask 134 | if combined_attention_mask is None 135 | else expanded_attn_mask + combined_attention_mask 136 | ) 137 | 138 | return combined_attention_mask 139 | 140 | def forward( 141 | self, 142 | inputs_embeds: Optional[torch.FloatTensor], 143 | attention_mask: Optional[torch.Tensor] = None, 144 | position_ids: Optional[torch.LongTensor] = None, 145 | past_key_values: Optional[List[torch.FloatTensor]] = None, 146 | use_cache: Optional[bool] = None, 147 | output_attentions: Optional[bool] = None, 148 | output_hidden_states: Optional[bool] = None, 149 | return_dict: Optional[bool] = None, 150 | **kwargs, 151 | ) -> Union[Tuple, TextDecoderOutputWithPast]: 152 | hidden_states = inputs_embeds 153 | outputs = TextDecoderOutputWithPast() if return_dict else () 154 | 155 | logits = self.head(hidden_states) 156 | logits_new = self.head_new(hidden_states) 157 | logits[..., self.orig_txt_vocab_size:] = logits[..., self.orig_txt_vocab_size:] + logits_new 158 | 159 | if not return_dict: 160 | return (logits, *outputs) 161 | 162 | outputs.logits = logits 163 | return outputs 164 | -------------------------------------------------------------------------------- /mm_interleaved/models/decoders/perceiver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import Blip2QFormerModel, Blip2QFormerConfig 5 | 6 | 7 | class PerceiverResampler(nn.Module): 8 | def __init__( 9 | self, 10 | num_queries=32, 11 | hidden_size=768, 12 | qk_normalization=False, 13 | gradient_checkpointing=True, 14 | **kwargs 15 | ) -> None: 16 | super().__init__() 17 | 18 | config = Blip2QFormerConfig(hidden_size=hidden_size, **kwargs) 19 | config.qk_normalization = qk_normalization 20 | self.blip2qformer = Blip2QFormerModel(config) 21 | 22 | self.queries = 
nn.Parameter(torch.zeros(1, num_queries, hidden_size)) 23 | self.queries.data.normal_(0, config.initializer_range) 24 | if gradient_checkpointing: 25 | self.blip2qformer.gradient_checkpointing_enable() 26 | 27 | def forward(self, **kwargs): 28 | query_embeds = kwargs.pop("query_embeds", self.queries) 29 | 30 | return self.blip2qformer(query_embeds=query_embeds, **kwargs) 31 | -------------------------------------------------------------------------------- /mm_interleaved/models/encoders/visual_tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from einops import rearrange 5 | 6 | from .vit_adapter import clip_vit_adapter_hf 7 | from ..decoders.perceiver import PerceiverResampler 8 | from ..utils.pos_embed import get_abs_pos, get_2d_sincos_pos_embed 9 | 10 | 11 | class VisualTokenizer(nn.Module): 12 | def __init__( 13 | self, 14 | encoder_model_path="./assets/openai/clip-vit-large-patch14", 15 | perceiver_config=None, 16 | llm_hidden_size=5120, 17 | clip_normalize=True, 18 | grid_size=16, 19 | ) -> None: 20 | super().__init__() 21 | self.clip_normalize = clip_normalize 22 | self.encoder = clip_vit_adapter_hf(model_path=encoder_model_path) 23 | encoder_hidden_size = perceiver_config.encoder_hidden_size 24 | 25 | self.pos_proj = nn.Linear(encoder_hidden_size, encoder_hidden_size) 26 | self.pos_ln = nn.LayerNorm(encoder_hidden_size, eps=1e-6) 27 | self.pos_embed = nn.Parameter( 28 | torch.from_numpy( 29 | get_2d_sincos_pos_embed(encoder_hidden_size, grid_size, cls_token=True) 30 | ).float() 31 | ).requires_grad_(False) 32 | 33 | self.perceiver_resampler = PerceiverResampler(**perceiver_config) 34 | self.length = perceiver_config.num_queries 35 | self.post_ln = nn.LayerNorm(encoder_hidden_size, eps=1e-6) 36 | self.proj = nn.Linear(perceiver_config.hidden_size, llm_hidden_size) 37 | 38 | self.initialize_weights() 39 | 40 | if self.clip_normalize: 41 | # normalize image 42 | CLIP_MEAN, CLIP_STD = [0.48145466, 0.4578275, 0.40821073], [ 43 | 0.26862954, 44 | 0.26130258, 45 | 0.27577711, 46 | ] 47 | mean, std = torch.tensor(CLIP_MEAN), torch.tensor(CLIP_STD) 48 | mean, std = rearrange(mean, "c -> 1 c 1 1"), rearrange(std, "c -> 1 c 1 1") 49 | self.register_buffer("clip_mean", mean) 50 | self.register_buffer("clip_std", std) 51 | 52 | def print_parameters_stats(self, prefix=""): 53 | for name, module in self.named_children(): 54 | print( 55 | f"# {prefix}{name} Total parameters: {sum(p.numel() for p in module.parameters()) / 1e6:.2f}M" 56 | ) 57 | print( 58 | f"# {prefix}{name} Trainable parameters: {sum(p.numel() for p in module.parameters() if p.requires_grad) / 1e6:.2f}M" 59 | ) 60 | 61 | def initialize_weights(self): 62 | nn.init.normal_(self.proj.weight, std=1.0e-3) 63 | nn.init.constant_(self.proj.bias, 0.0) 64 | 65 | def forward(self, image): 66 | if self.clip_normalize: 67 | # normalize image 68 | image = (image - self.clip_mean) / self.clip_std 69 | 70 | model_output = self.encoder(image) 71 | image_embed = model_output.last_hidden_state 72 | multiscale_features = model_output.hidden_states 73 | 74 | multiscale_features_n = [] 75 | for ms_feat in multiscale_features: 76 | pos_embed = get_abs_pos( 77 | self.pos_embed[1:], ms_feat.size(2) * ms_feat.size(3) 78 | ) 79 | pos_embed = rearrange(pos_embed, "(h w) c -> c h w", h=ms_feat.size(2)) 80 | ms_feat = ms_feat + pos_embed 81 | multiscale_features_n.append(ms_feat) 82 | multiscale_features = multiscale_features_n 83 | 84 | pos_embed = 
get_abs_pos(self.pos_embed, image_embed.size(1)) 85 | qformer_inputs = self.pos_ln(self.pos_proj(image_embed)) 86 | qformer_inputs = qformer_inputs + pos_embed 87 | image_embed = image_embed + pos_embed 88 | 89 | qformer_inputs = self.post_ln(qformer_inputs) 90 | vis_embed = self.perceiver_resampler( 91 | encoder_hidden_states=qformer_inputs, 92 | encoder_attention_mask=None, 93 | return_dict=False, 94 | )[0] 95 | vis_embed = self.proj(vis_embed) 96 | 97 | output = dict(vis_embed=vis_embed) 98 | output["image_embeds"] = image_embed[:, 1:, :] # remove cls token 99 | output["multiscale_features"] = multiscale_features 100 | 101 | return output 102 | -------------------------------------------------------------------------------- /mm_interleaved/models/encoders/vit_adapter/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip_vit_hf import CLIPVisionTransformer, CLIPVisionModel 2 | from .vit_adapter_hf import CLIPVisionTransformerAdapter, CLIPVisionAdapterModel 3 | from .vit_adapter_hf import clip_vit_adapter_hf 4 | 5 | __all__ = ["CLIPVisionTransformer", "CLIPVisionModel", 'clip_vit_adapter_hf', 6 | "CLIPVisionTransformerAdapter", "CLIPVisionAdapterModel"] 7 | -------------------------------------------------------------------------------- /mm_interleaved/models/encoders/vit_adapter/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 10 | -------------------------------------------------------------------------------- /mm_interleaved/models/encoders/vit_adapter/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | from torch.cuda.amp import custom_bwd, custom_fwd 18 | 19 | try: 20 | import MultiScaleDeformableAttention as MSDA 21 | except: 22 | print("MultiScaleDeformableAttention is not installed") 23 | 24 | 25 | class MSDeformAttnFunction(Function): 26 | @staticmethod 27 | @custom_fwd 28 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, 29 | im2col_step): 30 | ctx.im2col_step = im2col_step 31 | output = MSDA.ms_deform_attn_forward( 32 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, 33 | ctx.im2col_step) 34 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, 35 | attention_weights) 36 | return output 37 | 38 | @staticmethod 39 | @once_differentiable 40 | @custom_bwd 41 | def backward(ctx, grad_output): 42 | grad_output = grad_output.contiguous() 43 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 44 | grad_value, grad_sampling_loc, grad_attn_weight = \ 45 | MSDA.ms_deform_attn_backward( 46 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, 47 | grad_output, ctx.im2col_step) 48 | 49 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 50 | 51 | 52 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 53 | # for debug and test only, 54 | # need to use cuda version instead 55 | N_, S_, M_, D_ = value.shape 56 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 57 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 58 | sampling_grids = 2 * sampling_locations - 1 59 | sampling_value_list = [] 60 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 61 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 62 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_) 63 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 64 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 65 | # N_*M_, D_, Lq_, P_ 66 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 67 | mode='bilinear', padding_mode='zeros', align_corners=False) 68 | sampling_value_list.append(sampling_value_l_) 69 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 70 | attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_) 71 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_ * D_, Lq_) 72 | return output.transpose(1, 2).contiguous() 73 | -------------------------------------------------------------------------------- /mm_interleaved/models/encoders/vit_adapter/ops/modules/__init__.py: 
-------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn import * 10 | -------------------------------------------------------------------------------- /mm_interleaved/models/encoders/vit_adapter/ops/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import, division, print_function 10 | 11 | import math 12 | import warnings 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from torch import nn 17 | from torch.nn.init import constant_, xavier_uniform_ 18 | 19 | from ..functions import MSDeformAttnFunction 20 | 21 | 22 | def _is_power_of_2(n): 23 | if (not isinstance(n, int)) or (n < 0): 24 | raise ValueError('invalid input for _is_power_of_2: {} (type: {})'.format(n, type(n))) 25 | return (n & (n - 1) == 0) and n != 0 26 | 27 | 28 | class MSDeformAttn(nn.Module): 29 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, ratio=1.0): 30 | """Multi-Scale Deformable Attention Module. 
31 | 32 | :param d_model hidden dimension 33 | :param n_levels number of feature levels 34 | :param n_heads number of attention heads 35 | :param n_points number of sampling points per attention head per feature level 36 | """ 37 | super().__init__() 38 | if d_model % n_heads != 0: 39 | raise ValueError('d_model must be divisible by n_heads, ' 40 | 'but got {} and {}'.format(d_model, n_heads)) 41 | _d_per_head = d_model // n_heads 42 | # you'd better set _d_per_head to a power of 2 43 | # which is more efficient in our CUDA implementation 44 | if not _is_power_of_2(_d_per_head): 45 | warnings.warn( 46 | "You'd better set d_model in MSDeformAttn to make " 47 | 'the dimension of each attention head a power of 2 ' 48 | 'which is more efficient in our CUDA implementation.') 49 | 50 | self.im2col_step = 1 51 | 52 | self.d_model = d_model 53 | self.n_levels = n_levels 54 | self.n_heads = n_heads 55 | self.n_points = n_points 56 | self.ratio = ratio 57 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 58 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 59 | self.value_proj = nn.Linear(d_model, int(d_model * ratio)) 60 | self.output_proj = nn.Linear(int(d_model * ratio), d_model) 61 | 62 | self._reset_parameters() 63 | 64 | def _reset_parameters(self): 65 | constant_(self.sampling_offsets.weight.data, 0.) 66 | thetas = torch.arange( 67 | self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 68 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 69 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view( 70 | self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) 71 | for i in range(self.n_points): 72 | grid_init[:, :, i, :] *= i + 1 73 | 74 | with torch.no_grad(): 75 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 76 | constant_(self.attention_weights.weight.data, 0.) 77 | constant_(self.attention_weights.bias.data, 0.) 78 | xavier_uniform_(self.value_proj.weight.data) 79 | constant_(self.value_proj.bias.data, 0.) 80 | xavier_uniform_(self.output_proj.weight.data) 81 | constant_(self.output_proj.bias.data, 0.) 
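# Note (added comments, not in the original file): _reset_parameters above gives each head a
# distinct initial sampling pattern. sampling_offsets.weight is zeroed, so the initial offsets
# equal the bias, which places sampling point k of head h at (k+1) times the direction
# (cos, sin) of angle 2*pi*h/n_heads (scaled so its larger coordinate is 1); the zeroed
# attention_weights start out as a uniform softmax over all levels and points.
#
# Illustrative call sketch (shapes follow the forward() docstring below; the variable names
# here are made up for the example):
#   attn = MSDeformAttn(d_model=256, n_levels=4, n_heads=8, n_points=4)
#   # query: (N, Len_q, 256); input_flatten: (N, sum_l H_l*W_l, 256)
#   # reference_points: (N, Len_q, 4, 2) with values in [0, 1]; spatial_shapes: (4, 2), int64
#   # level_start_index: (4,), cumulative start offset of each level inside input_flatten
#   out = attn(query, reference_points, input_flatten, spatial_shapes, level_start_index)
#   # out: (N, Len_q, 256)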
82 | 83 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, 84 | input_level_start_index, input_padding_mask=None): 85 | """ 86 | :param query (N, Length_{query}, C) 87 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 88 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 89 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 90 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 91 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 92 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 93 | 94 | :return output (N, Length_{query}, C) 95 | """ 96 | 97 | N, Len_q, _ = query.shape 98 | N, Len_in, _ = input_flatten.shape 99 | assert (input_spatial_shapes[:, 0] * 100 | input_spatial_shapes[:, 1]).sum() == Len_in 101 | 102 | value = self.value_proj(input_flatten) 103 | if input_padding_mask is not None: 104 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 105 | 106 | value = value.view(N, Len_in, self.n_heads, 107 | int(self.ratio * self.d_model) // self.n_heads) 108 | sampling_offsets = self.sampling_offsets(query).view( 109 | N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 110 | attention_weights = self.attention_weights(query).view( 111 | N, Len_q, self.n_heads, self.n_levels * self.n_points) 112 | attention_weights = F.softmax(attention_weights, -1). \ 113 | view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 114 | 115 | if reference_points.shape[-1] == 2: 116 | offset_normalizer = torch.stack( 117 | [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 118 | sampling_locations = reference_points[:, :, None, :, None, :] \ 119 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 120 | elif reference_points.shape[-1] == 4: 121 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 122 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 123 | else: 124 | raise ValueError( 125 | 'Last dim of reference_points must be 2 or 4, but get {} instead.' 
126 | .format(reference_points.shape[-1])) 127 | sampling_locations = sampling_locations.to(value.dtype) 128 | output = MSDeformAttnFunction.apply(value, input_spatial_shapes, input_level_start_index, 129 | sampling_locations, attention_weights, self.im2col_step) 130 | output = self.output_proj(output) 131 | return output 132 | -------------------------------------------------------------------------------- /mm_interleaved/models/encoders/vit_adapter/xattn.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | import torch.utils.checkpoint 5 | from torch import nn 6 | 7 | # the xformers lib allows less memory, faster training and inference 8 | try: 9 | import xformers 10 | import xformers.ops 11 | 12 | XFORMERS_IS_AVAILBLE = True 13 | # print('xformers enabled') 14 | except: 15 | XFORMERS_IS_AVAILBLE = False 16 | print("xformers disabled") 17 | 18 | from transformers import CLIPVisionModel, CLIPTextModel 19 | 20 | 21 | class CLIPXAttention(nn.Module): 22 | """Memory Efficient Attention layer for CLIP, support full & causal attn mask""" 23 | 24 | def __init__(self, config): 25 | super().__init__() 26 | self.config = config 27 | self.embed_dim = config.hidden_size 28 | self.num_heads = config.num_attention_heads 29 | self.head_dim = self.embed_dim // self.num_heads 30 | if self.head_dim * self.num_heads != self.embed_dim: 31 | raise ValueError( 32 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" 33 | f" {self.num_heads})." 34 | ) 35 | self.scale = self.head_dim**-0.5 36 | self.dropout = config.attention_dropout 37 | 38 | self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) 39 | self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) 40 | self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) 41 | self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) 42 | 43 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 44 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() 45 | # return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 46 | 47 | def forward( 48 | self, 49 | hidden_states: torch.Tensor, 50 | attention_mask: Optional[torch.Tensor] = None, 51 | causal_attention_mask: Optional[torch.Tensor] = None, 52 | output_attentions: Optional[bool] = False, 53 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 54 | """Input shape: Batch x Time x Channel""" 55 | bsz, tgt_len, embed_dim = hidden_states.size() 56 | 57 | query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz) 58 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 59 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 60 | 61 | # use xformers here 62 | assert (self.dropout == 0.0) and (attention_mask is None) 63 | attention_mask = ( 64 | xformers.ops.LowerTriangularMask() 65 | if causal_attention_mask is not None 66 | else None 67 | ) 68 | # q, k, v = query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1, 2) 69 | q, k, v = query_states, key_states, value_states 70 | attn_output = xformers.ops.memory_efficient_attention( 71 | q, k, v, attn_bias=attention_mask 72 | ) 73 | attn_weights_reshaped = None 74 | 75 | # # get query proj 76 | # query_states = self.q_proj(hidden_states) * self.scale 77 | # key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 78 | # value_states = self._shape(self.v_proj(hidden_states), 
-1, bsz) 79 | 80 | # proj_shape = (bsz * self.num_heads, -1, self.head_dim) 81 | # query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 82 | # key_states = key_states.view(*proj_shape) 83 | # value_states = value_states.view(*proj_shape) 84 | 85 | # src_len = key_states.size(1) 86 | # attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 87 | 88 | # if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): 89 | # raise ValueError( 90 | # f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" 91 | # f" {attn_weights.size()}" 92 | # ) 93 | 94 | # # apply the causal_attention_mask first 95 | # if causal_attention_mask is not None: 96 | # if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): 97 | # raise ValueError( 98 | # f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" 99 | # f" {causal_attention_mask.size()}" 100 | # ) 101 | # attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask 102 | # attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 103 | 104 | # if attention_mask is not None: 105 | # if attention_mask.size() != (bsz, 1, tgt_len, src_len): 106 | # raise ValueError( 107 | # f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 108 | # ) 109 | # attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 110 | # attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 111 | 112 | # attn_weights = nn.functional.softmax(attn_weights, dim=-1) 113 | 114 | # if output_attentions: 115 | # # this operation is a bit akward, but it's required to 116 | # # make sure that attn_weights keeps its gradient. 
117 | # # In order to do so, attn_weights have to reshaped 118 | # # twice and have to be reused in the following 119 | # attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 120 | # attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 121 | # else: 122 | # attn_weights_reshaped = None 123 | 124 | # attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 125 | 126 | # attn_output = torch.bmm(attn_probs, value_states) 127 | 128 | # if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): 129 | # raise ValueError( 130 | # f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" 131 | # f" {attn_output.size()}" 132 | # ) 133 | 134 | # attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 135 | # attn_output = attn_output.transpose(1, 2) 136 | 137 | attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) 138 | 139 | attn_output = self.out_proj(attn_output) 140 | 141 | return attn_output, attn_weights_reshaped 142 | 143 | 144 | def convert_clip_visual_attn(model: CLIPVisionModel): 145 | for layer in model.vision_model.encoder.layers: 146 | attn_o = layer.self_attn 147 | attn_x = CLIPXAttention(config=attn_o.config) 148 | for module_name in ["q_proj", "v_proj", "k_proj", "out_proj"]: 149 | module_o: nn.Linear = getattr(attn_o, module_name) 150 | module_x: nn.Linear = getattr(attn_x, module_name) 151 | module_x.weight.data.copy_(module_o.weight.data) 152 | module_x.bias.data.copy_(module_o.bias.data) 153 | layer.self_attn = attn_x 154 | del attn_o 155 | print("convert clip visual self_attn to memory efficient mode successfully") 156 | 157 | 158 | def convert_clip_text_attn(model: CLIPTextModel): 159 | for layer in model.text_model.encoder.layers: 160 | attn_o = layer.self_attn 161 | attn_x = CLIPXAttention(config=attn_o.config) 162 | for module_name in ["q_proj", "v_proj", "k_proj", "out_proj"]: 163 | module_o: nn.Linear = getattr(attn_o, module_name) 164 | module_x: nn.Linear = getattr(attn_x, module_name) 165 | module_x.weight.data.copy_(module_o.weight.data) 166 | module_x.bias.data.copy_(module_o.bias.data) 167 | layer.self_attn = attn_x 168 | del attn_o 169 | print("convert clip text self_attn to memory efficient mode successfully") 170 | -------------------------------------------------------------------------------- /mm_interleaved/models/utils/monkey_patch/__init__.py: -------------------------------------------------------------------------------- 1 | from .llama_flash_attn_train_monkey_patch import replace_llama_attn_with_flash_attn 2 | from .blip2_qknorm_monkey_patch import replace_blip2_attn_with_qknorm_attn 3 | from .beam_search_monkey_patch import replace_beam_search 4 | from .sd_pipeline_monkey_patch import replace_stable_diffusion_pipeline_call 5 | from .sd_unet_forward_monkey_patch import replace_stable_diffusion_unet_forward 6 | -------------------------------------------------------------------------------- /mm_interleaved/models/utils/monkey_patch/blip2_qknorm_monkey_patch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import transformers 3 | import torch 4 | import torch.utils.checkpoint 5 | from torch import nn 6 | 7 | 8 | class Blip2QFormerMultiHeadAttention(nn.Module): 9 | def __init__(self, config, is_cross_attention=False): 10 | super().__init__() 11 | self.config = config 12 | if config.hidden_size % config.num_attention_heads != 0 and not 
hasattr(config, "embedding_size"): 13 | raise ValueError( 14 | "The hidden size (%d) is not a multiple of the number of attention heads (%d)" 15 | % (config.hidden_size, config.num_attention_heads) 16 | ) 17 | 18 | self.num_attention_heads = config.num_attention_heads 19 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 20 | self.all_head_size = self.num_attention_heads * self.attention_head_size 21 | 22 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 23 | if is_cross_attention: 24 | self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size) 25 | self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size) 26 | else: 27 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 28 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 29 | 30 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 31 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 32 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 33 | self.max_position_embeddings = config.max_position_embeddings 34 | self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) 35 | self.save_attention = False 36 | 37 | ##################### add qk_norm ##################### 38 | dim = self.attention_head_size 39 | self.q_norm = nn.LayerNorm(dim, eps=config.layer_norm_eps) if config.qk_normalization else nn.Identity() 40 | self.k_norm = nn.LayerNorm(dim, eps=config.layer_norm_eps) if config.qk_normalization else nn.Identity() 41 | print('init Blip2QFormerMultiHeadAttention with qk_norm') 42 | ##################### add qk_norm ##################### 43 | 44 | def save_attn_gradients(self, attn_gradients): 45 | self.attn_gradients = attn_gradients 46 | 47 | def get_attn_gradients(self): 48 | return self.attn_gradients 49 | 50 | def save_attention_map(self, attention_map): 51 | self.attention_map = attention_map 52 | 53 | def get_attention_map(self): 54 | return self.attention_map 55 | 56 | def transpose_for_scores(self, x): 57 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 58 | x = x.view(*new_x_shape) 59 | return x.permute(0, 2, 1, 3) 60 | 61 | def forward( 62 | self, 63 | hidden_states, 64 | attention_mask=None, 65 | head_mask=None, 66 | encoder_hidden_states=None, 67 | encoder_attention_mask=None, 68 | past_key_value=None, 69 | output_attentions=False, 70 | ): 71 | # If this is instantiated as a cross-attention module, the keys 72 | # and values come from an encoder; the attention mask needs to be 73 | # such that the encoder's padding tokens are not attended to. 
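        # qk_norm patch: `self.q_norm` / `self.k_norm` (per-head LayerNorms, or
        # nn.Identity when `config.qk_normalization` is disabled) are applied to the
        # projected queries and keys in every branch below, before the dot product,
        # so that the attention logits stay at a stable scale; see the "add qk_norm"
        # markers for the intended differences from the stock HF implementation.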
74 | is_cross_attention = encoder_hidden_states is not None 75 | 76 | if is_cross_attention: 77 | ##################### add qk_norm ##################### 78 | key_layer = self.k_norm(self.transpose_for_scores(self.key(encoder_hidden_states))) 79 | ##################### add qk_norm ##################### 80 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) 81 | attention_mask = encoder_attention_mask 82 | elif past_key_value is not None: 83 | ##################### add qk_norm ##################### 84 | key_layer = self.k_norm(self.transpose_for_scores(self.key(hidden_states))) 85 | ##################### add qk_norm ##################### 86 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 87 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2) 88 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2) 89 | else: 90 | ##################### add qk_norm ##################### 91 | key_layer = self.k_norm(self.transpose_for_scores(self.key(hidden_states))) 92 | ##################### add qk_norm ##################### 93 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 94 | 95 | mixed_query_layer = self.query(hidden_states) 96 | 97 | ##################### add qk_norm ##################### 98 | query_layer = self.transpose_for_scores(mixed_query_layer) 99 | query_layer = self.q_norm(query_layer) 100 | ##################### add qk_norm ##################### 101 | 102 | past_key_value = (key_layer, value_layer) 103 | 104 | # Take the dot product between "query" and "key" to get the raw attention scores. 105 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 106 | 107 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 108 | seq_length = hidden_states.size()[1] 109 | position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) 110 | position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) 111 | distance = position_ids_l - position_ids_r 112 | positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) 113 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility 114 | 115 | if self.position_embedding_type == "relative_key": 116 | relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 117 | attention_scores = attention_scores + relative_position_scores 118 | elif self.position_embedding_type == "relative_key_query": 119 | relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 120 | relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) 121 | attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key 122 | 123 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 124 | 125 | if attention_mask is not None: 126 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 127 | attention_scores = attention_scores + attention_mask 128 | 129 | # Normalize the attention scores to probabilities. 
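        # Softmax over the key dimension turns the scaled (and optionally
        # qk-normalized) scores into per-query attention probabilities.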
130 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 131 | 132 | if is_cross_attention and self.save_attention: 133 | self.save_attention_map(attention_probs) 134 | attention_probs.register_hook(self.save_attn_gradients) 135 | 136 | # This is actually dropping out entire tokens to attend to, which might 137 | # seem a bit unusual, but is taken from the original Transformer paper. 138 | attention_probs_dropped = self.dropout(attention_probs) 139 | 140 | # Mask heads if we want to 141 | if head_mask is not None: 142 | attention_probs_dropped = attention_probs_dropped * head_mask 143 | 144 | context_layer = torch.matmul(attention_probs_dropped, value_layer) 145 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 146 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 147 | context_layer = context_layer.view(*new_context_layer_shape) 148 | 149 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 150 | 151 | outputs = outputs + (past_key_value,) 152 | return outputs 153 | 154 | 155 | def replace_blip2_attn_with_qknorm_attn(): 156 | transformers.models.blip_2.modeling_blip_2.Blip2QFormerMultiHeadAttention = Blip2QFormerMultiHeadAttention 157 | print('replace Blip2QFormerMultiHeadAttention to support qk_norm') 158 | -------------------------------------------------------------------------------- /mm_interleaved/models/utils/monkey_patch/llama_flash_attn_train_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | from einops import rearrange 7 | from flash_attn.flash_attn_interface import flash_attn_func 8 | 9 | # ADAPTED from https://github.com/allenai/open-instruct/blob/main/open_instruct/llama_flash_attn_monkey_patch.py 10 | # AND https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py 11 | # AND https://github.com/LAION-AI/Open-Assistant/blob/04fa9a24b2a58c8885b8aa6a2eb02b18de6b4961/model/model_training/models/patching_llama.py 12 | # AND Sourabh https://github.com/huggingface/transformers/commit/ee81bf5aee0d65f005d157c013777e3d27d8d6bf 13 | 14 | 15 | def _rotate_half_train(x): 16 | """Rotates half the hidden dims of the input.""" 17 | x1 = x[..., : x.shape[-1] // 2] 18 | x2 = x[..., x.shape[-1] // 2 :] 19 | return torch.cat((-x2, x1), dim=-1) 20 | 21 | 22 | # @torch.compile(mode="reduce-overhead") 23 | def _apply_rotary_pos_emb_train(q, k, cos, sin, position_ids): 24 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
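    # Note: unlike the stock HF helper, which rotates tensors laid out as
    # [bsz, nh, q_len, hd], this variant keeps q/k in the [bsz, q_len, nh, hd]
    # layout expected by flash_attn_func, so cos/sin are gathered per position
    # and broadcast over the head dimension (unsqueeze at dim 2) below.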
25 | # q,k : [bsz, q_len, nh, hd] 26 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] 27 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] 28 | cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] 29 | sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] 30 | q_embed = (q * cos) + (_rotate_half_train(q) * sin) 31 | k_embed = (k * cos) + (_rotate_half_train(k) * sin) 32 | return q_embed, k_embed 33 | 34 | 35 | def _forward_train( 36 | self, 37 | hidden_states: torch.Tensor, 38 | attention_mask: Optional[torch.Tensor] = None, 39 | position_ids: Optional[torch.Tensor] = None, 40 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 41 | output_attentions: bool = False, 42 | use_cache: bool = False, 43 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 44 | """Input shape: Batch x Time x Channel 45 | 46 | attention_mask: [bsz, q_len] 47 | """ 48 | if output_attentions: 49 | warnings.warn( 50 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 51 | ) 52 | 53 | bsz, q_len, _ = hidden_states.size() 54 | 55 | query_states = self.q_proj(hidden_states).view( 56 | bsz, q_len, self.num_heads, self.head_dim 57 | ) 58 | key_states = self.k_proj(hidden_states).view( 59 | bsz, q_len, self.num_heads, self.head_dim 60 | ) 61 | value_states = self.v_proj(hidden_states).view( 62 | bsz, q_len, self.num_heads, self.head_dim 63 | ) 64 | # [bsz, q_len, nh, hd] -> [bsz, nh, q_len, hd] 65 | 66 | cos, sin = self.rotary_emb(value_states, seq_len=q_len) 67 | query_states, key_states = _apply_rotary_pos_emb_train( 68 | query_states, key_states, cos, sin, position_ids 69 | ) 70 | 71 | # Flash attention codes from 72 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 73 | # only work for training, not using key padding 74 | # q: (batch_size, seqlen, nheads, headdim) 75 | # k: (batch_size, seqlen, nheads_k, headdim) 76 | # v: (batch_size, seqlen, nheads_k, headdim) 77 | # out: (batch_size, seqlen, nheads, headdim) 78 | output = flash_attn_func( 79 | query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True 80 | ) 81 | 82 | return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None 83 | 84 | 85 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 86 | # requires the attention mask to be the same as the key_padding_mask 87 | def _prepare_decoder_attention_mask_train( 88 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 89 | ): 90 | # [bsz, seq_len] 91 | return attention_mask 92 | 93 | 94 | def replace_old_func_with_new_func_only_for_train(old_func, new_func): 95 | def combined_func( 96 | self, 97 | *args, 98 | **kwargs, 99 | ): 100 | if self.training: 101 | return new_func(self, *args, **kwargs) 102 | else: 103 | return old_func(self, *args, **kwargs) 104 | 105 | return combined_func 106 | 107 | 108 | def replace_llama_attn_with_flash_attn(): 109 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 110 | if cuda_major < 8: 111 | warnings.warn( 112 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
113 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 114 | ) 115 | # for original llama in transformers 116 | import transformers.models.llama.modeling_llama as llama 117 | llama.LlamaModel._prepare_decoder_attention_mask = ( 118 | replace_old_func_with_new_func_only_for_train( 119 | llama.LlamaModel._prepare_decoder_attention_mask, 120 | _prepare_decoder_attention_mask_train, 121 | ) 122 | ) 123 | llama.LlamaAttention.forward = replace_old_func_with_new_func_only_for_train( 124 | llama.LlamaAttention.forward, _forward_train 125 | ) 126 | # for our text decoder 127 | import mm_interleaved.models.decoders.decoder_text as decoder_text 128 | decoder_text.TextDecoder._prepare_decoder_attention_mask = ( 129 | replace_old_func_with_new_func_only_for_train( 130 | decoder_text.TextDecoder._prepare_decoder_attention_mask, 131 | _prepare_decoder_attention_mask_train, 132 | ) 133 | ) 134 | import mm_interleaved.models.decoders.modeling_llama_mmfs as llama 135 | llama.LlamaModel._prepare_decoder_attention_mask = ( 136 | replace_old_func_with_new_func_only_for_train( 137 | llama.LlamaModel._prepare_decoder_attention_mask, 138 | _prepare_decoder_attention_mask_train, 139 | ) 140 | ) 141 | llama.LlamaAttention.forward = replace_old_func_with_new_func_only_for_train( 142 | llama.LlamaAttention.forward, _forward_train 143 | ) 144 | -------------------------------------------------------------------------------- /mm_interleaved/models/utils/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 10 | 11 | -------------------------------------------------------------------------------- /mm_interleaved/models/utils/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | from torch.cuda.amp import custom_bwd, custom_fwd 18 | try: 19 | import MultiScaleDeformableAttention as MSDA 20 | except: 21 | print("MultiScaleDeformableAttention is not installed") 22 | 23 | 24 | class MSDeformAttnFunction(Function): 25 | @staticmethod 26 | @custom_fwd 27 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 28 | ctx.im2col_step = im2col_step 29 | output = MSDA.ms_deform_attn_forward( 30 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 31 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 32 | return output 33 | 34 | @staticmethod 35 | @once_differentiable 36 | @custom_bwd 37 | def backward(ctx, grad_output): 38 | grad_output = grad_output.contiguous() 39 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 40 | grad_value, grad_sampling_loc, grad_attn_weight = \ 41 | MSDA.ms_deform_attn_backward( 42 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 43 | 44 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 45 | 46 | 47 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 48 | # for debug and test only, 49 | # need to use cuda version instead 50 | N_, S_, M_, D_ = value.shape 51 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 52 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 53 | sampling_grids = 2 * sampling_locations - 1 54 | sampling_value_list = [] 55 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 56 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 57 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 58 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 59 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 60 | # N_*M_, D_, Lq_, P_ 61 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 62 | mode='bilinear', padding_mode='zeros', align_corners=False) 63 | sampling_value_list.append(sampling_value_l_) 64 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 65 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 66 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 67 | return output.transpose(1, 2).contiguous() 68 | -------------------------------------------------------------------------------- /mm_interleaved/models/utils/ops/modules/__init__.py: 
-------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Multi-Image Multi-Scale Feature Synchronizer 3 | # Modified from Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | from .mmfs import MMFS 11 | 12 | -------------------------------------------------------------------------------- /mm_interleaved/models/utils/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | import os 10 | import glob 11 | 12 | import torch 13 | 14 | from torch.utils.cpp_extension import CUDA_HOME 15 | from torch.utils.cpp_extension import CppExtension 16 | from torch.utils.cpp_extension import CUDAExtension 17 | 18 | from setuptools import find_packages 19 | from setuptools import setup 20 | 21 | requirements = ["torch", "torchvision"] 22 | 23 | def get_extensions(): 24 | this_dir = os.path.dirname(os.path.abspath(__file__)) 25 | extensions_dir = os.path.join(this_dir, "src") 26 | 27 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 28 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 29 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 30 | 31 | sources = main_file + source_cpu 32 | extension = CppExtension 33 | extra_compile_args = {"cxx": []} 34 | define_macros = [] 35 | 36 | if torch.cuda.is_available() and CUDA_HOME is not None: 37 | extension = CUDAExtension 38 | sources += source_cuda 39 | define_macros += [("WITH_CUDA", None)] 40 | extra_compile_args["nvcc"] = [ 41 | # "-DCUDA_HAS_FP16=1", 42 | # "-D__CUDA_NO_HALF_OPERATORS__", 43 | # "-D__CUDA_NO_HALF_CONVERSIONS__", 44 | # "-D__CUDA_NO_HALF2_OPERATORS__", 45 | ] 46 | else: 47 | raise NotImplementedError('CUDA is not available') 48 | 49 | sources = [os.path.join(extensions_dir, s) for s in sources] 50 | include_dirs = [extensions_dir] 51 | ext_modules = [ 52 | extension( 53 | "MultiScaleDeformableAttention", 54 | sources, 55 | include_dirs=include_dirs, 56 | define_macros=define_macros, 57 | extra_compile_args=extra_compile_args, 58 | ) 59 | ] 60 | return ext_modules 61 | 62 | setup( 63 | name="MultiScaleDeformableAttention", 64 | version="1.0", 65 | author="Weijie Su", 66 | url="https://github.com/fundamentalvision/Deformable-DETR", 67 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 68 | packages=find_packages(exclude=("configs", "tests",)), 69 | ext_modules=get_extensions(), 70 | cmdclass={"build_ext": 
torch.utils.cpp_extension.BuildExtension}, 71 | ) 72 | -------------------------------------------------------------------------------- /mm_interleaved/models/utils/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | 17 | at::Tensor 18 | ms_deform_attn_cpu_forward( 19 | const at::Tensor &value, 20 | const at::Tensor &spatial_shapes, 21 | const at::Tensor &level_start_index, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const int im2col_step) 25 | { 26 | AT_ERROR("Not implement on cpu"); 27 | } 28 | 29 | std::vector 30 | ms_deform_attn_cpu_backward( 31 | const at::Tensor &value, 32 | const at::Tensor &spatial_shapes, 33 | const at::Tensor &level_start_index, 34 | const at::Tensor &sampling_loc, 35 | const at::Tensor &attn_weight, 36 | const at::Tensor &grad_output, 37 | const int im2col_step) 38 | { 39 | AT_ERROR("Not implement on cpu"); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /mm_interleaved/models/utils/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor 15 | ms_deform_attn_cpu_forward( 16 | const at::Tensor &value, 17 | const at::Tensor &spatial_shapes, 18 | const at::Tensor &level_start_index, 19 | const at::Tensor &sampling_loc, 20 | const at::Tensor &attn_weight, 21 | const int im2col_step); 22 | 23 | std::vector 24 | ms_deform_attn_cpu_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | 34 | -------------------------------------------------------------------------------- /mm_interleaved/models/utils/ops/src/cuda/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | #include "cuda/ms_deform_im2col_cuda.cuh" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | 21 | at::Tensor ms_deform_attn_cuda_forward( 22 | const at::Tensor &value, 23 | const at::Tensor &spatial_shapes, 24 | const at::Tensor &level_start_index, 25 | const at::Tensor &sampling_loc, 26 | const at::Tensor &attn_weight, 27 | const int im2col_step) 28 | { 29 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 30 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 31 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 32 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 33 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 34 | 35 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 36 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 37 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 38 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 39 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 40 | 41 | const int batch = value.size(0); 42 | const int spatial_size = value.size(1); 43 | const int num_heads = value.size(2); 44 | const int channels = value.size(3); // batch, spatial_size, num_heads, channels = value.shape 45 | 46 | const int num_levels = spatial_shapes.size(0); // [num_levels, 2] 47 | 48 | const int num_query = sampling_loc.size(1); 49 | const int num_point = sampling_loc.size(4); 50 | 51 | const int im2col_step_ = std::min(batch, im2col_step); 52 | 53 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 54 | 55 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); // 初始化一个output 56 | 57 | const int batch_n = im2col_step_; 58 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 59 | auto per_value_size = spatial_size * num_heads * channels; 60 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 61 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 62 | for (int n = 0; n < batch/im2col_step_; ++n) 63 | { 64 | auto columns = output_n.select(0, n); 65 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_forward_cuda", ([&] { 66 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), 67 | value.data() + n * im2col_step_ * per_value_size, 68 | spatial_shapes.data(), 69 | level_start_index.data(), 70 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 71 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 72 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 73 | columns.data()); 74 | 75 | })); 76 | } 77 | 78 | output = output.view({batch, num_query, num_heads*channels}); 79 | 80 | return output; 81 | } 82 | 83 | 84 | std::vector ms_deform_attn_cuda_backward( 85 
| const at::Tensor &value, 86 | const at::Tensor &spatial_shapes, 87 | const at::Tensor &level_start_index, 88 | const at::Tensor &sampling_loc, 89 | const at::Tensor &attn_weight, 90 | const at::Tensor &grad_output, 91 | const int im2col_step) 92 | { 93 | 94 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 95 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 96 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 97 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 98 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 99 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); 100 | 101 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 102 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 103 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 104 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 105 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 106 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); 107 | 108 | const int batch = value.size(0); 109 | const int spatial_size = value.size(1); 110 | const int num_heads = value.size(2); 111 | const int channels = value.size(3); // batch, spatial_size, num_heads, channels = value.shape 112 | 113 | const int num_levels = spatial_shapes.size(0); 114 | 115 | const int num_query = sampling_loc.size(1); 116 | const int num_point = sampling_loc.size(4); 117 | 118 | const int im2col_step_ = std::min(batch, im2col_step); 119 | 120 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 121 | 122 | auto dtype = value.dtype(); 123 | if(dtype == at::kHalf){ 124 | dtype = at::kFloat; 125 | } 126 | 127 | auto grad_value = at::zeros_like(value, dtype); // value的梯度 128 | auto grad_sampling_loc = at::zeros_like(sampling_loc, dtype); // sampling loc的梯度 129 | auto grad_attn_weight = at::zeros_like(attn_weight, dtype); // attn weight的梯度 130 | 131 | const int batch_n = im2col_step_; 132 | auto per_value_size = spatial_size * num_heads * channels; // 每个value的大小 133 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; //每个sample loc的大小 134 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; // 每个attn weight的大小 135 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 136 | 137 | for (int n = 0; n < batch/im2col_step_; ++n) // col2im 138 | { 139 | auto grad_output_g = grad_output_n.select(0, n); 140 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_backward_cuda", ([&] { 141 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), 142 | grad_output_g.data(), 143 | value.data() + n * im2col_step_ * per_value_size, 144 | spatial_shapes.data(), 145 | level_start_index.data(), 146 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 147 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 148 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 149 | grad_value.data() + n * im2col_step_ * per_value_size, 150 | grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 151 | grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); 152 
| 153 | })); 154 | } 155 | 156 | if(value.dtype() == torch::kHalf){ 157 | return { 158 | grad_value.to(torch::kHalf), grad_sampling_loc.to(torch::kHalf), grad_attn_weight.to(torch::kHalf) 159 | }; 160 | } 161 | else{ 162 | return { 163 | grad_value, grad_sampling_loc, grad_attn_weight 164 | }; 165 | } 166 | } -------------------------------------------------------------------------------- /mm_interleaved/models/utils/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor ms_deform_attn_cuda_forward( 15 | const at::Tensor &value, 16 | const at::Tensor &spatial_shapes, 17 | const at::Tensor &level_start_index, 18 | const at::Tensor &sampling_loc, 19 | const at::Tensor &attn_weight, 20 | const int im2col_step); 21 | 22 | std::vector ms_deform_attn_cuda_backward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const at::Tensor &grad_output, 29 | const int im2col_step); 30 | 31 | -------------------------------------------------------------------------------- /mm_interleaved/models/utils/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "cpu/ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "cuda/ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | 20 | at::Tensor 21 | ms_deform_attn_forward( 22 | const at::Tensor &value, 23 | const at::Tensor &spatial_shapes, 24 | const at::Tensor &level_start_index, 25 | const at::Tensor &sampling_loc, 26 | const at::Tensor &attn_weight, 27 | const int im2col_step) 28 | { 29 | if (value.type().is_cuda()) 30 | { 31 | #ifdef WITH_CUDA 32 | return ms_deform_attn_cuda_forward( 33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector 42 | ms_deform_attn_backward( 43 | const at::Tensor &value, 44 | const at::Tensor &spatial_shapes, 45 | const at::Tensor &level_start_index, 46 | const at::Tensor &sampling_loc, 47 | const at::Tensor &attn_weight, 48 | const at::Tensor &grad_output, 49 | const int im2col_step) 50 | { 51 | if (value.type().is_cuda()) 52 | { 53 | #ifdef WITH_CUDA 54 | return ms_deform_attn_cuda_backward( 55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 56 | #else 57 | AT_ERROR("Not compiled with GPU support"); 58 | #endif 59 | } 60 | AT_ERROR("Not implemented on the CPU"); 61 | } 62 | 63 | -------------------------------------------------------------------------------- /mm_interleaved/models/utils/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include "ms_deform_attn.h" 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 16 | } 17 | -------------------------------------------------------------------------------- /mm_interleaved/models/utils/ops/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/MM-Interleaved/ac539d00ead0c438328ac1788849d560703a6b15/mm_interleaved/models/utils/ops/tests/__init__.py -------------------------------------------------------------------------------- /mm_interleaved/models/utils/ops/tests/compare_with_data.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import gradcheck 6 | import torch.nn.functional as F 7 | 8 | import copy 9 | import pandas as pd 10 | from collections import OrderedDict 11 | import matplotlib.pyplot as plt 12 | 13 | from functions.ms_deform_attn_func import MSDeformAttnFunction 14 | 15 | torch.manual_seed(0) 16 | 17 | 18 | def test_op(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step): 19 | return MSDeformAttnFunction.apply( 20 | value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step) 21 | 22 | def to_dtype(inputs, dtype): 23 | new_inputs = list(copy.deepcopy(inputs)) 24 | for i, x in enumerate(new_inputs): 25 | if isinstance(x, torch.Tensor) and torch.is_floating_point(x): 26 | new_inputs[i] = x.to(dtype).detach() 27 | new_inputs[i].requires_grad = x.requires_grad 28 | return tuple(new_inputs) 29 | 30 | 31 | def compare_tensor_dict(tensors1, tensors2, skip_t2_zero=False): 32 | result = OrderedDict() 33 | for name in tensors1.keys(): 34 | t1 = tensors1[name] 35 | t2 = tensors2[name] 36 | if skip_t2_zero: 37 | rel_mask = torch.abs(t2) > 1e-6 38 | else: 39 | rel_mask = torch.ones(t2) > 1 40 | abs_err, rel_err, mean_rel_err = calc_err( 41 | t1.double(), 42 | t2.double(), 43 | rel_mask=rel_mask) 44 | result[f'{name}_max_ae'] = abs_err 45 | result[f'{name}_max_re'] = rel_err 46 | result[f'{name}_avg_re'] = mean_rel_err 47 | return result 48 | 49 | 50 | def calc_statistic(x): 51 | return { 52 | 'avg': torch.mean(x).item(), 53 | 'std': torch.std(x, unbiased=False).item(), 54 | 'min': torch.min(x).item(), 55 | 'max': torch.max(x).item() 56 | } 57 | 58 | def calc_err(x1, x2, rel_mask = None): 59 | err = (x1 - x2).abs() 60 | max_abs_err = err.max() 61 | max_rel_err = (err[rel_mask] / x2.abs()[rel_mask]).max() 62 | mean_rel_err = (err[rel_mask] / x2.abs()[rel_mask]).mean() 63 | return max_abs_err.item(), max_rel_err.item(), mean_rel_err.item() 64 | 65 | def print_table(results): 66 | df = pd.DataFrame(None) 67 | for i, ret in enumerate(results): 68 | if i == 0: 69 | for k in ret.keys(): 70 | df[k] = [] 71 | new_row = pd.DataFrame(ret, index=[0]) 72 | df = pd.concat([df, new_row], ignore_index=True) 73 | 74 | with pd.option_context( 75 | 
'display.max_rows', None, 'display.max_columns', None, 76 | 'display.max_colwidth', None, 'display.width', None, 77 | 'display.precision', 2, 78 | ): 79 | print(df) 80 | 81 | 82 | def internally_fp32(): 83 | data_file = 'data/fp64_data.pkl' 84 | test_data_list = torch.load(data_file) 85 | 86 | results = [] 87 | for test_data in test_data_list: 88 | inputs_fp64 = test_data['inputs'] 89 | grads_fp64 = test_data['grads'] 90 | grads_fp64 = [x for x in grads_fp64 if x is not None] 91 | outs_fp64 = test_data['outs'] 92 | 93 | inputs_fp32 = to_dtype(inputs_fp64, torch.float32) 94 | outs_fp32 = test_op(*inputs_fp32) 95 | sum_outs_fp32 = torch.sum(outs_fp32) 96 | sum_outs_fp32.backward() 97 | grads_fp32 = [x.grad.detach().clone() for x in inputs_fp32 if hasattr(x, 'grad') and x.grad is not None] 98 | grads_fp16 = to_dtype(grads_fp32, torch.float16) 99 | outs_fp16 = outs_fp32.to(torch.float16) 100 | 101 | tensor_fp16 = OrderedDict( 102 | {'query_grad': grads_fp16[0], 'offset_grad': grads_fp16[1], 'attn_grad': grads_fp16[2], 'out': outs_fp16}) 103 | tensor_fp64 = OrderedDict( 104 | {'query_grad': grads_fp64[0], 'offset_grad': grads_fp64[1], 'attn_grad': grads_fp64[2], 'out': outs_fp64}) 105 | 106 | ret = compare_tensor_dict(tensor_fp16, tensor_fp64, skip_t2_zero=True) 107 | results.append(ret) 108 | 109 | print_table(results) 110 | 111 | 112 | def compare_with_data(): 113 | data_file = 'data/fp64_data.pkl' 114 | test_data_list = torch.load(data_file) 115 | 116 | results = [] 117 | for test_data in test_data_list: 118 | inputs_fp64 = test_data['inputs'] 119 | grads_fp64 = test_data['grads'] 120 | grads_fp64 = [x for x in grads_fp64 if x is not None] 121 | outs_fp64 = test_data['outs'] 122 | 123 | inputs_fp16 = to_dtype(inputs_fp64, torch.float16) 124 | outs_fp16 = test_op(*inputs_fp16) 125 | sum_outs_fp16 = torch.sum(outs_fp16) 126 | sum_outs_fp16.backward() 127 | grads_fp16 = [x.grad.detach().clone() for x in inputs_fp16 if hasattr(x, 'grad') and x.grad is not None] 128 | 129 | tensor_fp16 = OrderedDict( 130 | {'query_grad': grads_fp16[0], 'offset_grad': grads_fp16[1], 'attn_grad': grads_fp16[2], 'out': outs_fp16}) 131 | tensor_fp64 = OrderedDict( 132 | {'query_grad': grads_fp64[0], 'offset_grad': grads_fp64[1], 'attn_grad': grads_fp64[2], 'out': outs_fp64}) 133 | 134 | ret = compare_tensor_dict(tensor_fp16, tensor_fp64, skip_t2_zero=True) 135 | results.append(ret) 136 | 137 | print_table(results) 138 | 139 | 140 | 141 | if __name__ == '__main__': 142 | compare_with_data() 143 | # internally_fp32() -------------------------------------------------------------------------------- /mm_interleaved/models/utils/ops/tests/create_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from collections import OrderedDict 6 | import pickle as pkl 7 | import copy 8 | 9 | from functions.ms_deform_attn_func import MSDeformAttnFunction 10 | 11 | def generate_inputs(dtype, bs=1, n_levels=2, shapes=((6,4), (3,2)), n_query=2, n_points=2, n_heads=2, head_dim=4): 12 | assert len(shapes) == n_levels 13 | shapes = torch.as_tensor(list(shapes), dtype=torch.long) 14 | assert shapes.shape[1] == 2 15 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 16 | spatial_size = sum([(H*W).item() for H, W in shapes]) 17 | 18 | value = torch.rand(bs, spatial_size, n_heads, head_dim) 19 | sampling_locations = torch.rand(bs, n_query, n_heads, n_levels, n_points, 2) 20 | 
attention_weights = torch.rand(bs, n_query, n_heads, n_levels, n_points).cuda() + 1e-5 21 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 22 | im2col_step = 2 23 | return ( 24 | value.cuda().to(dtype).requires_grad_(True), 25 | shapes.cuda(), 26 | level_start_index.cuda(), 27 | sampling_locations.cuda().to(dtype).requires_grad_(True), 28 | attention_weights.cuda().to(dtype).requires_grad_(True), 29 | im2col_step 30 | ) 31 | 32 | def test_op(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step): 33 | return MSDeformAttnFunction.apply( 34 | value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step) 35 | 36 | def to_dtype(inputs, dtype): 37 | new_inputs = list(copy.deepcopy(inputs)) 38 | for i, x in enumerate(new_inputs): 39 | if isinstance(x, torch.Tensor) and torch.is_floating_point(x): 40 | new_inputs[i] = x.to(dtype).detach() 41 | new_inputs[i].requires_grad = x.requires_grad 42 | return tuple(new_inputs) 43 | 44 | 45 | if __name__ == '__main__': 46 | out_file = 'data/fp64_data.pkl' 47 | data_list = [] 48 | for i in range(20): 49 | inputs = generate_inputs(torch.float16, 64) 50 | inputs = to_dtype(inputs, torch.float64) 51 | inputs[0].requires_grad = True 52 | inputs[3].requires_grad = True 53 | inputs[4].requires_grad = True 54 | 55 | outs = test_op(*inputs) 56 | outs.sum().backward() 57 | 58 | grads = [x.grad if hasattr(x, 'grad') else None for x in inputs] 59 | 60 | data_list.append(OrderedDict({ 61 | 'inputs': inputs, 62 | 'grads': grads, 63 | 'outs': outs 64 | })) 65 | 66 | torch.save(data_list, out_file) 67 | 68 | 69 | -------------------------------------------------------------------------------- /mm_interleaved/models/utils/ops/tests/speed_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | import time 4 | import torch 5 | 6 | import copy 7 | 8 | from functions.ms_deform_attn_func import MSDeformAttnFunction 9 | 10 | from tests.create_data import generate_inputs 11 | from easydict import EasyDict as edict 12 | 13 | torch.manual_seed(0) 14 | 15 | def test_op(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step): 16 | return MSDeformAttnFunction.apply( 17 | value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step 18 | ) 19 | 20 | 21 | def to_dtype(inputs, dtype): 22 | new_inputs = list(copy.deepcopy(inputs)) 23 | for i, x in enumerate(new_inputs): 24 | if isinstance(x, torch.Tensor) and torch.is_floating_point(x): 25 | new_inputs[i] = x.to(dtype).detach() 26 | new_inputs[i].requires_grad = x.requires_grad 27 | return tuple(new_inputs) 28 | 29 | 30 | def run(module, args, name='Unknown'): 31 | inputs = generate_inputs(args.dtype, **args.data_args) 32 | 33 | # cudnn warmup 34 | for _ in range(50): 35 | if args.backward: 36 | module(*inputs).sum().backward() 37 | else: 38 | module(*inputs) 39 | 40 | torch.cuda.synchronize() 41 | t0 = time.time() 42 | 43 | for _ in range(args.num_iter): 44 | if args.backward: 45 | module(*inputs).sum().backward() 46 | else: 47 | module(*inputs) 48 | 49 | torch.cuda.synchronize() 50 | t1 = time.time() 51 | 52 | avg_time = (t1 - t0) * 1000 / args.num_iter 53 | print( 54 | f'>>> {name} finished {args.num_iter} running, avg_time: {avg_time:.6f} ms') 55 | return avg_time 56 | 57 | def info_memory(msg=None): 58 | if msg: 59 | print(msg) 60 | print(f"MA {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024),2 )} GB \ 61 
| Max_MA {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),2)} GB \ 62 | CA {round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024),2)} GB \ 63 | Max_CA {round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024))} GB ") 64 | 65 | 66 | 67 | if __name__ == '__main__': 68 | 69 | data_args = edict() 70 | data_args.bs = 32 71 | data_args.n_levels = 2 72 | data_args.shapes=[(16,16), (8,8)] 73 | data_args.n_query = 128 74 | data_args.n_points = 64 75 | data_args.n_heads = 8 76 | data_args.head_dim = 128 77 | 78 | args = edict() 79 | args.num_iter = 200 80 | args.backward = True 81 | args.dtype = torch.float16 82 | args.data_args = data_args 83 | 84 | run(test_op, args, name='fp16') 85 | info_memory() 86 | args.dtype = torch.float32 87 | run(test_op, args, name='fp32') 88 | info_memory() -------------------------------------------------------------------------------- /mm_interleaved/scripts/download_hf_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from transformers import CLIPModel, CLIPProcessor 5 | from transformers import LlamaTokenizer, LlamaForCausalLM 6 | from diffusers import StableDiffusionPipeline 7 | 8 | version = 'lmsys/vicuna-13b-v1.3' 9 | path = os.path.join('./assets', version) 10 | os.makedirs(path, exist_ok=True) 11 | llm_tokenizer:LlamaTokenizer = LlamaTokenizer.from_pretrained(version) 12 | llm_tokenizer.save_pretrained(path) 13 | llm_model = LlamaForCausalLM.from_pretrained(version, force_download=True, resume_download=False) 14 | llm_model.save_pretrained(path) 15 | 16 | version = "openai/clip-vit-large-patch14" 17 | clip_model = CLIPModel.from_pretrained(version) 18 | clip_processor = CLIPProcessor.from_pretrained(version) 19 | path = os.path.join('./assets', version) 20 | os.makedirs(path, exist_ok=True) 21 | clip_model.save_pretrained(path) 22 | clip_processor.save_pretrained(path) 23 | 24 | version = 'stabilityai/stable-diffusion-2-base' 25 | path = os.path.join('./assets', version) 26 | os.makedirs(path, exist_ok=True) 27 | pipe = StableDiffusionPipeline.from_pretrained(version, torch_dtype=torch.float32) 28 | pipe.save_pretrained(path) 29 | -------------------------------------------------------------------------------- /mm_interleaved/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .parse_args import ArgumentParser, TrainingArguments 2 | from .misc import init_distributed_mode, load_model_weights 3 | from .caption_collect import collect_caption_result 4 | from .vqa_collect import collect_vqa_result 5 | -------------------------------------------------------------------------------- /mm_interleaved/utils/caption_collect.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from .misc import barrier, get_rank, get_world_size 5 | 6 | 7 | def collect_caption_result(result, result_dir, filename, remove_duplicate=''): 8 | result_file = os.path.join(result_dir, '%s_rank%d.json' % (filename, get_rank())) 9 | final_result_file = os.path.join(result_dir, '%s.json' % filename) 10 | 11 | json.dump(result, open(result_file, 'w')) 12 | 13 | barrier() 14 | 15 | if get_rank() == 0: 16 | # combine results from all processes 17 | result = [] 18 | 19 | for rank in range(get_world_size()): 20 | result_file = os.path.join(result_dir, '%s_rank%d.json' % (filename, rank)) 21 | res = json.load(open(result_file, 'r')) 22 | result += res 23 | os.remove(result_file) 
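        # All per-rank partial files have now been merged and deleted; the block below
        # optionally de-duplicates the merged predictions by the `remove_duplicate` key
        # (e.g. an image or question id), keeping the first occurrence, before writing
        # the combined result to a single JSON file.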
24 | 25 | if remove_duplicate: 26 | result_new = [] 27 | id_list = set() 28 | for res in result: 29 | if res[remove_duplicate] not in id_list: 30 | id_list.add(res[remove_duplicate]) 31 | result_new.append(res) 32 | result = result_new 33 | 34 | json.dump(result, open(final_result_file, 'w')) 35 | print('result file saved to %s' % final_result_file) 36 | 37 | return final_result_file -------------------------------------------------------------------------------- /mm_interleaved/utils/clip_sim_score.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | from transformers import CLIPModel 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader, DistributedSampler 7 | 8 | from ..utils.misc import MetricLogger, barrier, get_world_size, get_rank 9 | from ..custom_datasets.clip_itp import CLIPImagePairDataset 10 | 11 | 12 | def tensor_to_pil(images: torch.Tensor): 13 | pil_images = images.mul(255).add_(0.5).clamp_(0, 255) 14 | pil_images = [ 15 | Image.fromarray(img.permute(1, 2, 0).to("cpu", torch.uint8).numpy()).convert("RGB") 16 | for img in pil_images 17 | ] 18 | return pil_images 19 | 20 | 21 | @torch.no_grad() 22 | def calculate_clip_sim_i2i( 23 | image_list, 24 | model_name="./assets/openai/clip-vit-large-patch14", 25 | device="cuda", 26 | batch_size=2048, 27 | ): 28 | if isinstance(image_list, str): 29 | image_list = json.load(open(image_list, "r")) 30 | clip_model = CLIPModel.from_pretrained(model_name) 31 | clip_model.to(device) 32 | clip_model.eval() 33 | 34 | clip_dataset = CLIPImagePairDataset(image_list, model_name) 35 | print(clip_dataset) 36 | num_tasks = get_world_size() 37 | global_rank = get_rank() 38 | sampler = DistributedSampler( 39 | clip_dataset, 40 | num_replicas=num_tasks, 41 | rank=global_rank, 42 | shuffle=False, 43 | ) 44 | mini_batch_size = batch_size // num_tasks 45 | data_loader = DataLoader( 46 | clip_dataset, 47 | sampler=sampler, 48 | batch_size=mini_batch_size, 49 | drop_last=False, 50 | num_workers=10, 51 | pin_memory=True, 52 | ) 53 | 54 | metric_logger = MetricLogger(delimiter=" ") 55 | header = "Eval CLIP similarity i2i: " 56 | print_freq = 5 57 | 58 | for batch_idx, data in enumerate( 59 | metric_logger.log_every(data_loader, print_freq, header) 60 | ): 61 | image, image_gt, image_idx = data 62 | image = image.to(device, non_blocking=True) 63 | image_gt = image_gt.to(device, non_blocking=True) 64 | 65 | image_feat = clip_model.get_image_features(pixel_values=image) 66 | image_gt_feat = clip_model.get_image_features(pixel_values=image_gt) 67 | 68 | # Compute cosine similarity.
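# Both feature batches are L2-normalized below, so the per-pair dot product equals the cosine similarity between generated and ground-truth image embeddings.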
69 | image_feat = F.normalize(image_feat, dim=-1) 70 | image_gt_feat = F.normalize(image_gt_feat, dim=-1) 71 | scores = (image_feat * image_gt_feat).sum(dim=-1) 72 | metric_logger.meters["clip_sim_i2i"].update(scores.mean(), n=scores.shape[0]) 73 | 74 | barrier() 75 | # gather the stats from all processes 76 | metric_logger.synchronize_between_processes() 77 | print("Averaged stats:", metric_logger) 78 | score = metric_logger.meters["clip_sim_i2i"].global_avg 79 | print("CLIP similarity:", score) 80 | return score 81 | 82 | 83 | @torch.no_grad() 84 | def clip_rerank_generated_images( 85 | images: torch.Tensor, 86 | captions, 87 | clip_model, 88 | clip_processor, 89 | device="cuda", 90 | ): 91 | _images = tensor_to_pil(images) 92 | images = _images 93 | 94 | bs = len(captions) 95 | num_candidates = len(images) // len(captions) 96 | 97 | data = clip_processor( 98 | images=_images, 99 | text=captions, 100 | return_tensors="pt", 101 | padding="max_length", 102 | max_length=77, 103 | ) 104 | 105 | image_tensors = data.pixel_values.to(device=device) 106 | text_ids = data.input_ids.to(device=device) 107 | 108 | image_feat = clip_model.get_image_features(pixel_values=image_tensors) 109 | image_feat = F.normalize(image_feat, dim=-1) 110 | text_feat = clip_model.get_text_features(input_ids=text_ids) 111 | text_feat = F.normalize(text_feat, dim=-1) 112 | text_feat = text_feat.repeat(num_candidates, 1) 113 | 114 | scores = (image_feat * text_feat).sum(dim=-1) 115 | scores = scores.view(num_candidates, -1).transpose(0, 1) 116 | 117 | best_image_idxs = scores.argmax(dim=1) 118 | best_images = [images[idx * bs + i] for i,idx in enumerate(best_image_idxs)] 119 | 120 | return best_images 121 | -------------------------------------------------------------------------------- /mm_interleaved/utils/coco_cap_score.py: -------------------------------------------------------------------------------- 1 | from pycocotools.coco import COCO 2 | from pycocoevalcap.eval import COCOEvalCap 3 | import os 4 | import json 5 | 6 | 7 | def coco_caption_eval( 8 | annotation_file, 9 | results_file, 10 | phase="test", 11 | use_1st_sentence_only=False, 12 | ): 13 | 14 | # we use the test dataset as the evaluation 15 | annotation_file = annotation_file.replace( 16 | f"coco_karpathy_{phase}.json", f"coco_karpathy_{phase}_gt.json" 17 | ) 18 | # create coco object and coco_result object 19 | coco = COCO(annotation_file) 20 | 21 | with open(results_file) as f: 22 | anns = json.load(f) 23 | if use_1st_sentence_only: 24 | for ann in anns: 25 | ann["caption"] = ann["caption"].split(".")[0] 26 | coco_result = coco.loadRes(anns) 27 | 28 | # create coco_eval object by taking coco and coco_result 29 | coco_eval = COCOEvalCap(coco, coco_result) 30 | 31 | # evaluate on a subset of images by setting 32 | # coco_eval.params['image_id'] = coco_result.getImgIds() 33 | # please remove this line when evaluating the full validation set 34 | coco_eval.params["image_id"] = coco_result.getImgIds() 35 | 36 | try: 37 | # evaluate results 38 | # SPICE will take a few minutes the first time, but speeds up due to caching 39 | coco_eval.evaluate() 40 | except Exception as exp: 41 | print(exp) 42 | return {} 43 | 44 | # print output evaluation scores 45 | return coco_eval.eval 46 | -------------------------------------------------------------------------------- /mm_interleaved/utils/grounding_score.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import torch 4 | from torchvision.ops.boxes 
import box_area 5 | 6 | def box_iou(boxes1, boxes2): 7 | area1 = box_area(boxes1) 8 | area2 = box_area(boxes2) 9 | 10 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 11 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 12 | 13 | wh = (rb - lt).clamp(min=0) # [N,M,2] 14 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 15 | 16 | union = area1[:, None] + area2 - inter 17 | 18 | iou = inter / union 19 | return iou, union 20 | 21 | def parse_box(box_str): 22 | PATTERN = re.compile(r'\((.*?)\)\((.*?)\)') 23 | predict_bbox = re.findall(PATTERN, box_str) 24 | 25 | try: 26 | if ',' not in predict_bbox[0][0] or ',' not in predict_bbox[0][1]: 27 | predict_bbox = (0., 0., 0., 0.) 28 | else: 29 | x1, y1 = [ 30 | float(tmp) for tmp in predict_bbox[0][0].split(',') 31 | ] 32 | x2, y2 = [ 33 | float(tmp) for tmp in predict_bbox[0][1].split(',') 34 | ] 35 | predict_bbox = (x1, y1, x2, y2) 36 | except: 37 | predict_bbox = (0., 0., 0., 0.) 38 | 39 | return predict_bbox 40 | 41 | def grounding_eval(results_file): 42 | results = json.load(open(results_file)) 43 | 44 | total_cnt = 0 45 | correct = 0 46 | for item in results: 47 | gt_box = item['gt_box'] 48 | pred_box = item['pred_box'] 49 | h = item['height'] 50 | w = item['width'] 51 | 52 | pred_box = parse_box(pred_box) 53 | pred_box = torch.tensor(pred_box, dtype=torch.float32).view(-1, 4) / 999 54 | pred_box[:, 0::2] *= w 55 | pred_box[:, 1::2] *= h 56 | 57 | gt_box = torch.tensor(gt_box, dtype=torch.float32).view(-1, 4) / 999 58 | gt_box[:, 0::2] *= w 59 | gt_box[:, 1::2] *= h 60 | 61 | iou, _ = box_iou(pred_box, gt_box) 62 | iou = iou.item() 63 | total_cnt += 1 64 | if iou >= 0.5: 65 | correct += 1 66 | 67 | return {'accuracy': correct / total_cnt} 68 | -------------------------------------------------------------------------------- /mm_interleaved/utils/parse_args.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, List, Optional, Tuple, Union 3 | from dataclasses import dataclass, field, fields 4 | 5 | from mmcv import Config 6 | import transformers 7 | from transformers.hf_argparser import HfArgumentParser, DataClass 8 | 9 | from .misc import is_main_process 10 | 11 | 12 | @dataclass 13 | class TrainingArguments(transformers.TrainingArguments): 14 | config_file: Optional[str] = field(default="./configs/debug.yaml") 15 | resume: Optional[bool] = field(default=True) 16 | 17 | output_dir: Optional[str] = field(default="./OUTPUT/debug") 18 | remove_unused_columns: Optional[bool] = field( 19 | default=False, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."} 20 | ) 21 | 22 | lr_for_random_params: Optional[float] = field(default=1e-3) 23 | random_params: Optional[str] = field(default=None) 24 | lr_for_random_params_list: Optional[List[str]]= field(default_factory=lambda: None) 25 | wd_for_random_params_list: Optional[List[str]]= field(default_factory=lambda: None) 26 | random_params_list: Optional[List[str]] = field(default_factory=lambda: None) 27 | 28 | generate_mode: Optional[str] = field(default="generate_texts") 29 | use_1st_sentence_only: Optional[bool] = field(default=False) 30 | 31 | 32 | class ArgumentParser(HfArgumentParser): 33 | def parse_args_with_config_file_into_dataclasses( 34 | self, 35 | args=None, 36 | return_remaining_strings=False, 37 | ) -> Tuple[DataClass, ...]: 38 | """ 39 | 1. parse system arguments 40 | 2. load yaml config file 41 | 3. merge arguments from 2. 
into 1., 42 | note that if there exists same arguments in both 2. and 1., 43 | then the arguments in 1. will be overwritten by that in 2. 44 | 4. split into different dataclasses 45 | """ 46 | namespace, remaining_args = self.parse_known_args(args=args) 47 | config_file = getattr(namespace, "config_file", "./configs/debug.yaml") 48 | config_args = Config.fromfile(config_file) 49 | namespace.__dict__.update(config_args) 50 | if is_main_process(): 51 | Config.dump(Config(namespace.__dict__), file=os.path.join(namespace.output_dir, "config.yaml")) 52 | 53 | outputs = [] 54 | for dtype in self.dataclass_types: 55 | keys = {f.name for f in fields(dtype) if f.init} 56 | inputs = {k: v for k, v in vars(namespace).items() if k in keys} 57 | for k in keys: 58 | delattr(namespace, k) 59 | obj = dtype(**inputs) 60 | outputs.append(obj) 61 | if len(namespace.__dict__) > 0: 62 | # additional namespace. 63 | outputs.append(namespace) 64 | if return_remaining_strings: 65 | return (*outputs, remaining_args) 66 | else: 67 | if remaining_args: 68 | raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}") 69 | 70 | return (*outputs,) 71 | -------------------------------------------------------------------------------- /mm_interleaved/utils/segm_eval.py: -------------------------------------------------------------------------------- 1 | from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation 2 | from PIL import Image 3 | import numpy as np 4 | 5 | processor = None # OneFormerProcessor.from_pretrained("./assets/shi-labs/oneformer_ade20k_dinat_large") 6 | model = None # OneFormerForUniversalSegmentation.from_pretrained("./assets/shi-labs/oneformer_ade20k_dinat_large") 7 | 8 | 9 | def calculate_segm(image, gt_img): 10 | global processor 11 | global model 12 | if processor is None: 13 | processor = OneFormerProcessor.from_pretrained("./assets/shi-labs/oneformer_ade20k_dinat_large") 14 | if model is None: 15 | model = OneFormerForUniversalSegmentation.from_pretrained("./assets/shi-labs/oneformer_ade20k_dinat_large") 16 | 17 | semantic_inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt") 18 | semantic_outputs = model(**semantic_inputs) 19 | # pass through image_processor for postprocessing 20 | predicted_semantic_map = processor.post_process_semantic_segmentation(semantic_outputs, target_sizes=[gt_img.size[::-1]])[0] 21 | 22 | return predicted_semantic_map 23 | 24 | def intersectionAndUnion(imPred, imLab, numClass): 25 | imPred = np.asarray(imPred).copy() 26 | imLab = np.asarray(imLab).copy() 27 | 28 | # imPred += 1 29 | # imLab += 1 30 | # Remove classes from unlabeled pixels in gt image. 31 | # We should not penalize detections in unlabeled portions of the image. 
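# Zeroing predictions on unlabeled pixels (label 0) keeps them out of the class histograms below, whose bins cover classes 1..numClass.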
32 | imPred = imPred * (imLab > 0) 33 | 34 | # Compute area intersection: 35 | intersection = imPred * (imPred == imLab) 36 | (area_intersection, _) = np.histogram( 37 | intersection, bins=numClass, range=(1, numClass)) 38 | 39 | # Compute area union: 40 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) 41 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) 42 | area_union = area_pred + area_lab - area_intersection 43 | 44 | return (area_intersection, area_union) 45 | 46 | 47 | def calculate_miou_given_paths(paths, num_classes=150): 48 | 49 | all_intersection = None 50 | all_union = None 51 | 52 | for path1, path2 in zip(*paths): 53 | seg_label = np.array(Image.open(path1)) 54 | pred = np.array(Image.open(path2)) + 1 55 | 56 | intersection, union = intersectionAndUnion(pred, seg_label, num_classes) 57 | all_intersection = intersection if all_intersection is None else all_intersection + intersection 58 | all_union = union if all_union is None else all_union + union 59 | 60 | iou = all_intersection / (all_union + 1e-10) 61 | 62 | miou = iou.mean() 63 | 64 | return miou 65 | 66 | -------------------------------------------------------------------------------- /mm_interleaved/utils/visdial_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | A Metric observes the output of a certain model, for example in the form of logits or 3 | scores, and accumulates a particular metric with reference to some provided 4 | targets. In the context of VisDial, we use Recall (@ 1, 5, 10), Mean Rank, Mean 5 | Reciprocal Rank (MRR) and Normalized Discounted Cumulative Gain (NDCG). 6 | 7 | Each ``Metric`` must at least implement three methods: 8 | - ``observe``, update accumulated metric with currently observed outputs 9 | and targets. 10 | - ``retrieve`` to return the accumulated metric, and optionally reset the 11 | internally accumulated metric (this is commonly done between two epochs 12 | after validation). 13 | - ``reset`` to explicitly reset the internally accumulated metric. 14 | 15 | Caveat: if you wish to implement your own class of Metric, make sure you call 16 | ``detach`` on output tensors (like logits), else it will cause memory leaks. 17 | """ 18 | import torch 19 | 20 | 21 | def scores_to_ranks(scores: torch.Tensor): 22 | """Convert model output scores into ranks.""" 23 | batch_size, num_rounds, num_options = scores.size() 24 | scores = scores.view(-1, num_options) 25 | 26 | # sort in descending order - largest score gets highest rank 27 | sorted_ranks, ranked_idx = scores.sort(1, descending=True) 28 | 29 | # i-th position in ranked_idx specifies which score shall take this 30 | # position, but we want the i-th position to hold the rank of the score 31 | # at that position, so do this conversion 32 | ranks = ranked_idx.clone().fill_(0) 33 | for i in range(ranked_idx.size(0)): 34 | for j in range(num_options): 35 | ranks[i][ranked_idx[i][j]] = j 36 | # convert from 0-99 ranks to 1-100 ranks 37 | ranks += 1 38 | ranks = ranks.view(batch_size, num_rounds, num_options) 39 | return ranks 40 | 41 | 42 | class SparseGTMetrics(object): 43 | """ 44 | A class to accumulate all metrics with sparse ground truth annotations. 45 | These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank.
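Illustrative example: ground-truth ranks [1, 3, 12] give r@1 = 1/3, r@5 = r@10 = 2/3, mean rank = 16/3 and MRR = (1 + 1/3 + 1/12) / 3 ≈ 0.47.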
46 | """ 47 | 48 | def __init__(self): 49 | self._rank_list = [] 50 | 51 | def observe(self, predicted_scores: torch.Tensor, target_ranks: torch.Tensor): 52 | predicted_scores = predicted_scores.detach() 53 | 54 | # shape: (batch_size, num_rounds, num_options) 55 | predicted_ranks = scores_to_ranks(predicted_scores) 56 | batch_size, num_rounds, num_options = predicted_ranks.size() 57 | 58 | # collapse batch dimension 59 | predicted_ranks = predicted_ranks.view(batch_size * num_rounds, num_options) 60 | 61 | # shape: (batch_size * num_rounds, ) 62 | target_ranks = target_ranks.view(batch_size * num_rounds).long() 63 | 64 | # shape: (batch_size * num_rounds, ) 65 | predicted_gt_ranks = predicted_ranks[ 66 | torch.arange(batch_size * num_rounds), target_ranks 67 | ] 68 | self._rank_list.extend(list(predicted_gt_ranks.cpu().numpy())) 69 | 70 | def retrieve(self, reset: bool = True): 71 | num_examples = len(self._rank_list) 72 | if num_examples > 0: 73 | # convert to a float tensor for easy calculation. 74 | __rank_list = torch.tensor(self._rank_list).float() 75 | metrics = { 76 | "r@1": torch.mean((__rank_list <= 1).float()).item(), 77 | "r@5": torch.mean((__rank_list <= 5).float()).item(), 78 | "r@10": torch.mean((__rank_list <= 10).float()).item(), 79 | "mean": torch.mean(__rank_list).item(), 80 | "mrr": torch.mean(__rank_list.reciprocal()).item(), 81 | } 82 | else: 83 | metrics = {} 84 | 85 | if reset: 86 | self.reset() 87 | return metrics 88 | 89 | def reset(self): 90 | self._rank_list = [] 91 | 92 | 93 | class NDCG(object): 94 | def __init__(self): 95 | self._ndcg_numerator = 0.0 96 | self._ndcg_denominator = 0.0 97 | 98 | def observe( 99 | self, 100 | target_relevance: torch.Tensor, 101 | predicted_scores: torch.Tensor = None, 102 | predicted_ranks: torch.Tensor = None, 103 | ): 104 | """ 105 | Observe model output scores and target ground truth relevance and 106 | accumulate NDCG metric. 107 | 108 | Parameters 109 | ---------- 110 | predicted_scores: torch.Tensor 111 | A tensor of shape (batch_size, num_options), because dense 112 | annotations are available for 1 randomly picked round out of 10. 113 | target_relevance: torch.Tensor 114 | A tensor of the same shape as predicted_scores, indicating ground truth 115 | relevance of each answer option for a particular round. 116 | """ 117 | if predicted_ranks is None: 118 | predicted_scores = predicted_scores.detach() 119 | 120 | # shape: (batch_size, 1, num_options) 121 | predicted_scores = predicted_scores.unsqueeze(1) 122 | predicted_ranks = scores_to_ranks(predicted_scores) 123 | 124 | # shape: (batch_size, num_options) 125 | predicted_ranks = predicted_ranks.squeeze() 126 | batch_size, num_options = predicted_ranks.size() 127 | 128 | k = torch.sum(target_relevance != 0, dim=-1) 129 | 130 | # shape: (batch_size, num_options) 131 | _, rankings = torch.sort(predicted_ranks, dim=-1) 132 | # Sort relevance in descending order so highest relevance gets top rank.
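# For reference, _dcg below implements DCG@k = sum_{i=1..k} rel(i) / log2(i + 1) over the top-k ranked options, and observe() accumulates NDCG = DCG / IDCG, where IDCG uses the ideal (relevance-sorted) ordering computed next.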
133 | _, best_rankings = torch.sort(target_relevance, dim=-1, descending=True) 134 | 135 | # shape: (batch_size, ) 136 | batch_ndcg = [] 137 | for batch_index in range(batch_size): 138 | num_relevant = k[batch_index] 139 | dcg = self._dcg( 140 | rankings[batch_index][:num_relevant], 141 | target_relevance[batch_index], 142 | ) 143 | best_dcg = self._dcg( 144 | best_rankings[batch_index][:num_relevant], 145 | target_relevance[batch_index], 146 | ) 147 | batch_ndcg.append(dcg / best_dcg) 148 | 149 | self._ndcg_denominator += batch_size 150 | self._ndcg_numerator += sum(batch_ndcg) 151 | 152 | def _dcg(self, rankings: torch.Tensor, relevance: torch.Tensor): 153 | sorted_relevance = relevance[rankings].cpu().float() 154 | discounts = torch.log2(torch.arange(len(rankings)).float() + 2) 155 | return torch.sum(sorted_relevance / discounts, dim=-1) 156 | 157 | def retrieve(self, reset: bool = True): 158 | if self._ndcg_denominator > 0: 159 | metrics = {"ndcg": float(self._ndcg_numerator / self._ndcg_denominator)} 160 | else: 161 | metrics = {} 162 | 163 | if reset: 164 | self.reset() 165 | return metrics 166 | 167 | def reset(self): 168 | self._ndcg_numerator = 0.0 169 | self._ndcg_denominator = 0.0 170 | -------------------------------------------------------------------------------- /mm_interleaved/utils/vizwiz_metrics_src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/MM-Interleaved/ac539d00ead0c438328ac1788849d560703a6b15/mm_interleaved/utils/vizwiz_metrics_src/__init__.py -------------------------------------------------------------------------------- /mm_interleaved/utils/vizwiz_metrics_src/vqa.py: -------------------------------------------------------------------------------- 1 | __author__ = 'QingLi' 2 | __version__ = '1.0' 3 | 4 | # Interface for accessing the VQA dataset. 5 | 6 | # This code is based on the code written by Qing Li for VizWiz Python API available at the following link: 7 | # (https://github.com/xxx) 8 | 9 | # The following functions are defined: 10 | # VQA - VQA class that loads VQA annotation file and prepares data structures. 11 | # getQuesIds - Get question ids that satisfy given filter conditions. 12 | # getImgIds - Get image ids that satisfy given filter conditions. 13 | # loadQA - Load questions and answers with the specified question ids. 14 | # showQA - Display the specified questions and answers. 15 | # loadRes - Load result file and create result object. 16 | 17 | # Help on each function can be accessed by: "help(COCO.function)" 18 | 19 | import json 20 | import datetime 21 | import copy 22 | 23 | 24 | class VQA: 25 | def __init__(self, annotation_file=None, annotation=None): 26 | """ 27 | Constructor of VQA helper class for reading and visualizing questions and answers. 28 | :param annotation_file (str): location of VQA annotation file 29 | :return: 30 | """ 31 | # load dataset 32 | self.dataset = {} 33 | self.imgToQA = {} 34 | if annotation is not None or annotation_file is not None: 35 | print('loading dataset into memory...') 36 | time_t = datetime.datetime.utcnow() 37 | dataset = json.load(open(annotation_file, 'r')) if annotation is None else annotation 38 | print(datetime.datetime.utcnow() - time_t) 39 | self.dataset = dataset 40 | self.imgToQA = {x['image']: x for x in dataset} 41 | 42 | def getImgs(self): 43 | return list(self.imgToQA.keys()) 44 | 45 | def getAnns(self, imgs=[], ansTypes=[]): 46 | """ 47 | Get annotations that satisfy given filter conditions. 
default skips that filter 48 | :param imgs (str array): get annotations for given image names 49 | ansTypes (str array) : get annotations for given answer types 50 | :return: annotations (dict array) : dict array of annotations 51 | """ 52 | anns = self.dataset 53 | 54 | imgs = imgs if type(imgs) == list else [imgs] 55 | if len(imgs) != 0: 56 | anns = [self.imgToQA[img] for img in imgs] 57 | 58 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 59 | if len(ansTypes) != 0: 60 | anns = [ann for ann in anns if ann['answer_type'] in ansTypes] 61 | return anns 62 | 63 | def showQA(self, anns): 64 | """ 65 | Display the specified annotations. 66 | :param anns (array of object): annotations to display 67 | :return: None 68 | """ 69 | if len(anns) == 0: 70 | return 0 71 | for ann in anns: 72 | print("Question: %s" % ann['question']) 73 | print("Answer: ") 74 | print('\n'.join([x['answer'] for x in ann['answers']])) 75 | -------------------------------------------------------------------------------- /mm_interleaved/utils/vqa_collect.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from .misc import barrier, get_rank, get_world_size 5 | 6 | 7 | def collect_vqa_result(result, result_dir, filename, is_vizwiz=False): 8 | result_file = os.path.join(result_dir, '%s_rank%d.json' % (filename, get_rank())) 9 | final_result_file = os.path.join(result_dir, '%s.json' % filename) 10 | 11 | for item in result: 12 | image_id = item.pop("image_id") 13 | answer = item.pop("caption") 14 | 15 | if is_vizwiz: 16 | item['image'] = f'VizWiz_val_{image_id:08d}.jpg' 17 | else: 18 | item['question_id'] = image_id 19 | item['answer'] = answer 20 | 21 | json.dump(result, open(result_file, 'w')) 22 | 23 | barrier() 24 | if get_rank() == 0: 25 | # combine results from all processes 26 | result = [] 27 | 28 | for rank in range(get_world_size()): 29 | result_file = os.path.join(result_dir, '%s_rank%d.json' % (filename, rank)) 30 | res = json.load(open(result_file, 'r')) 31 | result += res 32 | os.remove(result_file) 33 | 34 | json.dump(result, open(final_result_file, 'w')) 35 | print('result file saved to %s' % final_result_file) 36 | 37 | return final_result_file -------------------------------------------------------------------------------- /mm_interleaved/utils/vqa_score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from .vqav2_metrics_src.vqa import VQA as VQAV2_VQA 5 | from .vqav2_metrics_src.vqaEval import VQAEval as VQAV2_VQAEval 6 | from .vizwiz_metrics_src.vqa import VQA as Vizwiz_VQA 7 | from .vizwiz_metrics_src.vqaEval import VQAEval as Vizwiz_VQAEval 8 | 9 | def extract_answer(response): 10 | response = response.replace('\"', '') 11 | # response = response.strip().split('.')[0].split(',')[0].split('!')[0].lower() 12 | response = response.strip().split('\n')[0].split('.')[0].split(',')[0].split('!')[0].lower() 13 | 14 | if 'is ' in response: 15 | response = response.split('is ')[1] 16 | if 'are ' in response: 17 | response = response.split('are ')[1] 18 | if 'a ' in response: 19 | response = response.split('a ')[1] 20 | if 'an ' in response: 21 | response = response.split('an ')[1] 22 | if 'the ' in response: 23 | response = response.split('the ')[1] 24 | if ' of' in response: 25 | response = response.split(' of')[0] 26 | 27 | if ' or ' in response: 28 | response = response.split(' or ')[0] 29 | if ' and ' in response: 30 | response = response.split(' and 
')[0] 31 | 32 | return response.strip() 33 | 34 | def vqa_eval( 35 | question_file, 36 | annotation_file, 37 | results_file, 38 | use_extract_answer=True, 39 | ): 40 | answers = json.load(open(results_file)) 41 | for item in answers: 42 | answer = item['answer'] 43 | 44 | if use_extract_answer: 45 | answer = extract_answer(answer) 46 | 47 | item['answer'] = answer 48 | 49 | if use_extract_answer: 50 | with open(results_file.replace('.json', '_processed.json'), 'w') as file: 51 | json.dump(answers, file) 52 | 53 | annotation_file = annotation_file 54 | question_file = question_file 55 | vqa = VQAV2_VQA(annotation_file, question_file) 56 | vqaRes = vqa.loadRes(answers, question_file) 57 | vqaEval = VQAV2_VQAEval(vqa, vqaRes, n=2) # n is precision of accuracy (number of places after decimal), default is 2 58 | vqaEval.evaluate() 59 | 60 | return {'overall_accuracy': vqaEval.accuracy['overall']} 61 | 62 | def vizwiz_vqa_eval( 63 | annotation_file, 64 | results_file, 65 | use_extract_answer=True, 66 | ): 67 | answers = json.load(open(results_file)) 68 | for item in answers: 69 | answer = item['answer'] 70 | 71 | if use_extract_answer: 72 | answer = extract_answer(answer) 73 | 74 | item['answer'] = answer 75 | 76 | if use_extract_answer: 77 | with open(results_file.replace('.json', '_processed.json'), 'w') as file: 78 | json.dump(answers, file) 79 | 80 | vqa = Vizwiz_VQA(annotation_file) 81 | vqaRes = Vizwiz_VQA(annotation=answers) 82 | vqaEval = Vizwiz_VQAEval(vqa, vqaRes, n=2) # n is precision of accuracy (number of places after decimal), default is 2 83 | vqaEval.evaluate() 84 | 85 | res = {'overall_accuracy': vqaEval.accuracy['overall']} 86 | res.update(vqaEval.caption_metric.items()) 87 | return res 88 | -------------------------------------------------------------------------------- /mm_interleaved/utils/vqav2_metrics_src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/MM-Interleaved/ac539d00ead0c438328ac1788849d560703a6b15/mm_interleaved/utils/vqav2_metrics_src/__init__.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu117 2 | absl-py==1.4.0 3 | accelerate==0.21.0 4 | albumentations==1.3.1 5 | braceexpand==0.1.7 6 | datasets==2.14.0 7 | deepspeed==0.10.0 8 | diffusers==0.20.0 9 | einops==0.6.1 10 | fairscale==0.4.13 11 | flash-attn==2.0.4 12 | mmcv-full==1.7.0 13 | ninja==1.11.1 14 | nltk==3.8.1 15 | Pillow==10.0.0 16 | pyarrow==12.0.1 17 | pycocoevalcap==1.2 18 | pycocotools==2.0.6 19 | scikit-learn==1.3.1 20 | scipy==1.11.1 21 | sentencepiece==0.1.99 22 | timm==0.9.2 23 | tokenizers==0.13.3 24 | torch==2.0.1+cu117 25 | transformers==4.31.0 26 | triton==2.0.0 27 | webdataset==0.2.48 28 | xformers==0.0.20 29 | omegaconf==2.3.0 30 | peft==0.3.0 31 | -------------------------------------------------------------------------------- /slurm_run.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | 4 | GPUS=${1} 5 | GPUS_PER_NODE=${2} 6 | JOB_NAME=${3} 7 | QUOTATYPE=${4} 8 | PARTITION=${5} 9 | 10 | # GPUS=${GPUS:-8} 11 | # GPUS_PER_NODE=${GPUS_PER_NODE:-8} 12 | CPUS_PER_TASK=${CPUS_PER_TASK:-10} 13 | 14 | if [ $GPUS -lt 8 ]; then 15 | NODE=1 16 | else 17 | NODE=$[GPUS/GPUS_PER_NODE] 18 | fi 19 | 20 | SCRIPT=${6} 21 | CONFIG=${7} 22 | 23 | CFGNAME=`basename ${CONFIG} 
.yaml` 24 | SCRIPTNAME=`basename ${SCRIPT} .py` 25 | DIR=./OUTPUT/${CFGNAME} 26 | mkdir -p ${DIR} 27 | SUFFIX=`date '+%Y%m%d%H%M'` 28 | 29 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 30 | 31 | SRUN_ARGS=${SRUN_ARGS:-""} 32 | PY_ARGS=${PY_ARGS:-""} 33 | 34 | export DISABLE_ADDMM_CUDA_LT=1 35 | export MASTER_PORT=22115 36 | export TORCH_CUDNN_USE_HEURISTIC_MODE_B=1 37 | srun -p ${PARTITION} \ 38 | --quotatype=${QUOTATYPE} \ 39 | --job-name=${JOB_NAME} \ 40 | --gres=gpu:${GPUS_PER_NODE} \ 41 | --ntasks=${GPUS} \ 42 | --ntasks-per-node=${GPUS_PER_NODE} \ 43 | --cpus-per-task=${CPUS_PER_TASK} \ 44 | --kill-on-bad-exit=1 \ 45 | ${SRUN_ARGS} \ 46 | python -u ${SCRIPT} --config_file=${CONFIG} --output_dir=${DIR} --run_name ${CFGNAME} \ 47 | ${@:8} ${PY_ARGS} 2>&1 | tee -a ${DIR}/${SCRIPTNAME}_${SUFFIX}.log 48 | #done -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from mm_interleaved.models.utils.monkey_patch import ( 4 | replace_llama_attn_with_flash_attn, 5 | replace_blip2_attn_with_qknorm_attn, 6 | replace_beam_search, 7 | replace_stable_diffusion_pipeline_call, 8 | replace_stable_diffusion_unet_forward, 9 | ) 10 | 11 | replace_beam_search() 12 | replace_blip2_attn_with_qknorm_attn() 13 | replace_stable_diffusion_unet_forward() 14 | replace_stable_diffusion_pipeline_call() 15 | IS_TRAIN = True 16 | if IS_TRAIN: 17 | replace_llama_attn_with_flash_attn() 18 | 19 | 20 | from transformers.trainer_utils import get_last_checkpoint 21 | 22 | from mm_interleaved.models import MMInterleaved 23 | from mm_interleaved.custom_datasets.utils import build_dataset 24 | from mm_interleaved.engine.lmm_trainer import LMMTrainer 25 | from mm_interleaved.utils import ArgumentParser, TrainingArguments, init_distributed_mode, load_model_weights 26 | 27 | 28 | def main(): 29 | parser = ArgumentParser(TrainingArguments) 30 | init_distributed_mode() 31 | args = parser.parse_args_with_config_file_into_dataclasses() 32 | train_args, config = args 33 | print(train_args) 34 | print(config) 35 | 36 | print("Data Loading Start") 37 | train_dataset = build_dataset(config.data.train) 38 | print(train_dataset) 39 | eval_dataset = build_dataset(config.data.val) 40 | print(eval_dataset) 41 | 42 | print("Model Init Start") 43 | model = MMInterleaved(**config.model) 44 | print(model) 45 | 46 | print("Trainer Init Start") 47 | trainer = LMMTrainer( 48 | model=model, 49 | tokenizer=train_dataset.tokenizer, 50 | config=config, 51 | args=train_args, 52 | train_dataset=train_dataset, 53 | data_collator=train_dataset.collator, 54 | eval_dataset=eval_dataset, 55 | eval_collator=None, 56 | ) 57 | 58 | if getattr(config, "load_from", None): 59 | load_model_weights(trainer.model, config.load_from) 60 | 61 | print("Training Start") 62 | trainer.train( 63 | resume_from_checkpoint=get_last_checkpoint(train_args.output_dir) 64 | if train_args.resume 65 | else None 66 | ) 67 | trainer.save_state() 68 | trainer.save_model(output_dir=os.path.join(train_args.output_dir, "training_end")) 69 | print("All Finished") 70 | 71 | 72 | if __name__ == "__main__": 73 | main() 74 | --------------------------------------------------------------------------------
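# Example launch of the training entry point via the SLURM wrapper above (illustrative; GPU counts, <quotatype> and <partition> are placeholders for your cluster): bash slurm_run.sh 16 8 mm_pretrain <quotatype> <partition> train.py mm_interleaved/configs/release/mm_pretrain.yaml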