├── .gitignore ├── README.md ├── build.sh ├── config ├── datasets │ ├── eval │ │ ├── AMBER_desc.yaml │ │ ├── CLEVR_val1k_type_opt.yaml │ │ ├── EmbSpa_test_opt.yaml │ │ ├── GQA.yaml │ │ ├── MMBench_DEV_opt.yaml │ │ ├── POPE_adversarial.yaml │ │ ├── SEED_opt.yaml │ │ ├── VSR_TEST.yaml │ │ ├── VStar.yaml │ │ ├── Wino.yaml │ │ ├── clevr_ref.yaml │ │ ├── refcocog_test-u.yaml │ │ └── refcocog_val-u.yaml │ ├── stage1_alignment.yaml │ ├── stage2_interleaved.yaml │ └── stage3_instruct.yaml ├── deepspeed │ ├── config_zero2.json │ ├── config_zero2_offload.json │ ├── config_zero3.json │ └── config_zero3_offload.json └── experiments │ ├── stage1_alignment.yaml │ ├── stage2_interleaved.yaml │ └── stage3_instruct.yaml ├── constants.py ├── conversation.py ├── eval ├── Evaluation.md ├── __init__.py ├── commands │ ├── run_metric.sh │ └── test_all_benchmark.sh ├── eval_tools │ ├── clevr.py │ ├── convert_res_to_amber.py │ ├── convert_res_to_gqa.py │ ├── convert_res_to_gqa_llava.py │ ├── convert_res_to_mmbench.py │ ├── gqa.py │ ├── m4c_evaluator.py │ ├── mmbench.py │ ├── mp3d.py │ ├── pope.py │ ├── refcoco.py │ ├── seed.py │ ├── vsr.py │ └── vstar.py ├── evaluate_benchmark.py └── merge_benchmark.py ├── figs ├── model_arch.png └── sample_input.jpg ├── locals └── datasets │ ├── __init__.py │ ├── dataloader.py │ ├── eval │ ├── qa.py │ └── short_qa.py │ ├── image_caption │ ├── cc3m.py │ ├── coco.py │ ├── flicker30k.py │ ├── fusecap.py │ ├── grit.py │ ├── lvis_gpt4v.py │ └── mmc4.py │ ├── image_edit │ ├── edit_dataset │ │ └── zipped_datasets.py │ ├── low_level │ │ ├── clwd.py │ │ ├── gopro.py │ │ ├── reds.py │ │ └── sidd.py │ ├── prompt │ │ ├── color_list_train_small.txt │ │ ├── prompt_deblur.txt │ │ ├── prompt_denoise.txt │ │ ├── prompt_dewatermark.txt │ │ ├── prompt_pose.txt │ │ └── prompt_seg.txt │ └── seg │ │ ├── coco_stuff.py │ │ ├── grefcoco.py │ │ ├── grefcoco_seg.py │ │ ├── refcoco.py │ │ └── refcoco_seg.py │ ├── multimodal_tasks │ ├── cot_qa.py │ ├── llava_R.py │ ├── llava_academic.py │ ├── lvis_instruct4v.py │ ├── m3it.py │ ├── object_detect.py │ ├── refcoco.py │ ├── sharegpt4v.py │ ├── single_image_base.py │ ├── spatial.py │ └── svit.py │ ├── preprocessor.py │ ├── prompts │ ├── prompt_captioning_long.txt │ ├── prompt_captioning_short.txt │ ├── prompt_kosmosg.txt │ ├── prompt_ref_i2t.txt │ ├── prompt_ref_t2i.txt │ └── prompt_txt2img.txt │ ├── text │ ├── sharegpt.py │ ├── text_data_base.py │ ├── txt_cot.py │ └── ultrachat.py │ ├── text2image │ ├── kosmosg.py │ └── midjourney.py │ └── utils │ ├── box_utils.py │ └── zip_manager.py ├── model ├── front_projector │ ├── Qformer.py │ └── builder.py ├── language_model │ ├── volcano_base.py │ ├── volcano_llama.py │ └── volcano_mistral.py ├── load_model.py └── vision_encoder │ ├── builder.py │ ├── clip_encoder.py │ ├── eva_vit.py │ └── eva_vit_emu.py ├── requirements.txt ├── train_volcano.py ├── utils ├── __init__.py ├── __init__.pyc ├── count_line.py ├── eval_util.py ├── format_utils.py ├── gqa_inpaint.py ├── llava_flash_attn.py ├── logger.py ├── logger.pyc ├── time_check.py └── util.py └── vocot_trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__ 2 | .vscode -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🌋VoCoT: Unleashing Visually Grounded Multi-Step Reasoning in Large Multi-Modal Models [[Paper](https://arxiv.org/abs/2405.16919)] 2 |

5 | 6 | Exploring visually grounded CoT reasoning in large multi-modal models with VoCoT! 7 | 8 | [Zejun Li*](https://github.com/Junction4Nako), [Ruipu Luo*](https://github.com/RupertLuo), Jiwen Zhang, Minghui Qiu, Zhongyu Wei (*Equal Contribution) 9 | 10 | We propose Visually grounded object-centric Chain-of-Thoughts (VoCoT) to support effective and reliable multi-step reasoning in large multi-modal models. For more details, please refer to our paper. 11 | 12 | In this repository, we will release: 13 | - The constructed VoCoT-Instruct data that can be used to train your multi-modal models to reason in the VoCoT format. 14 | - The VolCano model with the ability to perform multi-step reasoning in VoCoT. 15 | - The training scripts utilized to train VolCano. 16 | - The evaluation datasets and scripts used in our paper. 17 | 18 | ## Contents 19 | - [Getting Started](#getting-started) 20 | - [Data](#data) 21 | - [Model Weights](#model-weights) 22 | - [Train](#train) 23 | - [Evaluation](#evaluation) 24 | 25 | ## Getting Started 26 | 27 | ### Install 28 | 29 | 1. Clone this repository 30 | ```bash 31 | git clone https://github.com/RupertLuo/VoCoT.git 32 | cd VoCoT 33 | ``` 34 | 35 | 2. Install packages 36 | ```bash 37 | conda create -n vocot python=3.9 -y 38 | conda activate vocot 39 | pip3 install --upgrade pip 40 | pip3 install -r requirements.txt 41 | pip3 install flash-attn --no-build-isolation 42 | sudo apt-get install python3-tk -y 43 | ``` 44 | 45 | ### Quick Start 46 | 47 | ```python 48 | from model.load_model import load_model, infer 49 | from PIL import Image 50 | 51 | # loading the model 52 | model_path = 'luoruipu1/Volcano-7b' 53 | model, preprocessor = load_model(model_path, precision='fp16') 54 | 55 | # perform reasoning, activate VoCoT by passing cot=True 56 | input_image = Image.open('figs/sample_input.jpg') 57 | response_1 = infer(model, preprocessor, input_image, 'Is there an event "the cat is below the bed" in this image?', cot=True) 58 | response_2 = infer(model, preprocessor, input_image, 'Why is the cat on the bed?', cot=True) 59 | response_3 = infer(model, preprocessor, input_image, 'Describe the image.', cot=True) 60 | print('response 1: ', response_1[0]) 61 | print('response 2: ', response_2[0]) 62 | print('response 3: ', response_3[0]) 63 | ``` 64 | Notice: in the default setting, the output coordinates refer to the image after it has been expanded (padded) to a square, not to the original image; a sketch for mapping such boxes back to the original image is provided at the end of the [Model Weights](#model-weights) section. 65 | 66 | ## Data 67 | 68 | For users who want to use VoCoT-Instruct to train their own models, we provide an integrated JSON file, [VoCoT-Instruct-80K](https://huggingface.co/datasets/luoruipu1/VoCoT/blob/main/VoCoT-80K_integrated.json), which follows the conversation format of LLaVA (note that all coordinates refer to the images after they are expanded to squares). 69 | 70 | If you would like to follow the training of VolCano in this paper, please use the separate JSON files in [raw_data](https://huggingface.co/datasets/luoruipu1/VoCoT/tree/main/raw_data) for efficient dataset management. 71 | 72 | For the corresponding images, please visit [GQA](https://cs.stanford.edu/people/dorarad/gqa/about.html), [COCO](https://cocodataset.org/), and [LVIS](https://www.lvisdataset.org/) to download the images. 73 | 74 | ## Model Weights 75 | 76 | The VolCano model is built on Mistral-Instruct-v0.2-7B and CLIP ViT-L/14; the connection modules and the trained weights are released [here](https://huggingface.co/luoruipu1/Volcano-7b/tree/main). The architecture is illustrated below and details are included in our paper. 77 | 78 |

79 | [Figure: figs/model_arch.png] 80 | The architecture of VolCano. 81 |
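Note on coordinates: as mentioned in the Quick Start and Data sections, VolCano's boxes are defined on the square-padded version of each image (the `expand2square` preprocessing enabled throughout the configs). The snippet below is a minimal sketch for mapping such a box back to the original image; it assumes centered padding (as in LLaVA-style `expand2square`) and boxes given as normalized `[x1, y1, x2, y2]` in `[0, 1]`, so please check `locals/datasets/preprocessor.py` for the exact convention before relying on it.

```python
def square_box_to_original(box, orig_w, orig_h):
    # box: [x1, y1, x2, y2], normalized to [0, 1] on the square-padded image
    side = max(orig_w, orig_h)
    pad_x = (side - orig_w) / 2  # assumed centered padding, as in LLaVA-style expand2square
    pad_y = (side - orig_h) / 2
    x1, y1, x2, y2 = [v * side for v in box]
    # shift out of the padded frame and clamp to the original image size
    x1, x2 = max(x1 - pad_x, 0), min(x2 - pad_x, orig_w)
    y1, y2 = max(y1 - pad_y, 0), min(y2 - pad_y, orig_h)
    return [x1, y1, x2, y2]

# a 640x480 image is padded to a 640x640 square; the un-padded region maps back to the full image
print(square_box_to_original([0.0, 0.125, 1.0, 0.875], 640, 480))  # -> [0.0, 0.0, 640.0, 480.0]
```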
82 | 83 | ## Train 84 | 85 | In each stage, prepare the data and the pre-trained checkpoints with the help of the following instructions: 86 | 87 | ### Prepare the Data 88 | 89 | The datasets are managed with yaml files in [config/datasets](./config/datasets/). In each yaml, make sure that all the paths containing "PATH/TO" are correctly prepared. 90 | 91 | For datasets that are introduced, modified, or filtered in this paper, e.g., VoCoT-Instruct and the subsets of GRIT and MMC4, we provide the metadata [here](https://huggingface.co/datasets/luoruipu1/VoCoT/tree/main/pretrain). 92 | 93 | For public datasets like RefCOCO, LLaVA, and ALLaVA, please refer to the corresponding websites to obtain the data. 94 | If it is not clear where to obtain a dataset, feel free to contact us. 95 | 96 | ### Conduct Training 97 | 98 | We manage the experiment settings with yaml files in [config/experiments](./config/experiments/). In each yaml, set the paths containing "PATH/TO" to your correct paths. 99 | 100 | Each stage can be launched with the following command (replace the yaml path with the config of the corresponding stage): 101 | 102 | ```bash 103 | # modify the torchrun config to fit your machine 104 | torchrun --nproc_per_node=8 train_volcano.py --conf config/experiments/stage1_alignment.yaml 105 | ``` 106 | 107 | ## Evaluation 108 | Please see [Evaluation](./eval/Evaluation.md) for details about the evaluation datasets and evaluation scripts. 109 | 110 | ## Acknowledgement 111 | We thank the following open-source projects, which we referenced during the development of VoCoT. 112 | 113 | - [LLaVA](https://github.com/haotian-liu/LLaVA): our codebase is built on LLaVA, and LLaVA-Instruct is adopted during training. 114 | - [Shikra](https://github.com/shikras/shikra): we follow Shikra for the construction of the GQA Type-1 VoCoT-Instruct data. 115 | - [InstructDiffusion](https://github.com/cientgu/InstructDiffusion): we referenced the InstructDiffusion codebase to manage multiple datasets.
116 | 117 | 118 | ## Citation 119 | 120 | If you find our code, data, or model useful during your work, please consider citing our work: 121 | ```bibtex 122 | @article{li2024vocot, 123 | title={VoCoT: Unleashing Visually Grounded Multi-Step Reasoning in Large Multi-Modal Models}, 124 | author={Li, Zejun and Luo, Ruipu and Zhang, Jiwen and Qiu, Minghui and Wei, Zhongyu}, 125 | journal={arXiv preprint arXiv:2405.16919}, 126 | year={2024} 127 | } 128 | ``` 129 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | pip3 install --upgrade pip 2 | pip3 install -r requirements.txt 3 | pip3 install flash-attn --no-build-isolation 4 | sudo apt-get install python3-tk -y -------------------------------------------------------------------------------- /config/datasets/eval/AMBER_desc.yaml: -------------------------------------------------------------------------------- 1 | - mmvet: 2 | target: locals.datasets.eval.short_qa.AMBERDataset 3 | params: 4 | path: PATH/TO/AMBER/AMBER/data/query/query_generative.json 5 | base_path: PATH/TO/AMBER/image 6 | describe: True -------------------------------------------------------------------------------- /config/datasets/eval/CLEVR_val1k_type_opt.yaml: -------------------------------------------------------------------------------- 1 | - mmvet: 2 | target: locals.datasets.eval.qa.CLEVRDataset 3 | params: 4 | path: PATH/TO/CLEVR/CLEVR_v1.0/clevr_val_4choices_1k_per_type.json 5 | image_dir: PATH/TO/CLEVR_v1.0/images/val/ -------------------------------------------------------------------------------- /config/datasets/eval/EmbSpa_test_opt.yaml: -------------------------------------------------------------------------------- 1 | - mmvet: 2 | target: locals.datasets.eval.qa.EmbSpatialDataset 3 | params: 4 | path: PATH/TO/embspatial_bench_v1.json 5 | image_dir: PATH/TO/emb_spatial/ -------------------------------------------------------------------------------- /config/datasets/eval/GQA.yaml: -------------------------------------------------------------------------------- 1 | - gqa: 2 | target: locals.datasets.eval.short_qa.GQADataset 3 | params: 4 | path: PATH/TO/GQA/testdev_balanced_questions.json 5 | base_path: PATH/TO/GQA/images/ -------------------------------------------------------------------------------- /config/datasets/eval/MMBench_DEV_opt.yaml: -------------------------------------------------------------------------------- 1 | - mmvet: 2 | target: locals.datasets.eval.qa.MMBenchOptDataset 3 | params: 4 | path: PATH/TO/mmbench/MMBench_DEV_EN_legacy.tsv 5 | single_pred_prompt: False -------------------------------------------------------------------------------- /config/datasets/eval/POPE_adversarial.yaml: -------------------------------------------------------------------------------- 1 | - pope_adv: 2 | target: locals.datasets.eval.short_qa.POPEDataset 3 | params: 4 | path: PATH/TO/coco_pope_adversarial.jsonl 5 | base_path: PATH/TO/COCO2015/images/ -------------------------------------------------------------------------------- /config/datasets/eval/SEED_opt.yaml: -------------------------------------------------------------------------------- 1 | - mmvet: 2 | target: locals.datasets.eval.qa.SEEDOptionDataset 3 | params: 4 | path: PATH/TO/SEED-Bench/SEED-Bench.json 5 | image_dir: PATH/TO/SEED-Bench/SEED-Bench-image -------------------------------------------------------------------------------- /config/datasets/eval/VSR_TEST.yaml: 
-------------------------------------------------------------------------------- 1 | - mme: 2 | target: locals.datasets.eval.short_qa.VSRDataset 3 | params: 4 | path: PATH/TO/VSR/zeroshot/test.jsonl 5 | base_path: PATH/TO/COCO2017/ -------------------------------------------------------------------------------- /config/datasets/eval/VStar.yaml: -------------------------------------------------------------------------------- 1 | - mmvet: 2 | target: locals.datasets.eval.qa.VStarDataset 3 | params: 4 | path: PATH/TO/VStar/meta.json 5 | image_dir: PATH/TO/VStar/vstar_hub_dl -------------------------------------------------------------------------------- /config/datasets/eval/Wino.yaml: -------------------------------------------------------------------------------- 1 | - mmvet: 2 | target: locals.datasets.eval.qa.WinoTextDataset 3 | params: 4 | option_in_context: False 5 | -------------------------------------------------------------------------------- /config/datasets/eval/clevr_ref.yaml: -------------------------------------------------------------------------------- 1 | - refcoco: 2 | target: locals.datasets.multimodal_tasks.refcoco.ClevrRefDataset 3 | params: 4 | path: PATH/TO/clevr_ref/meta_1item_2k.json 5 | image_path: PATH/TO/clevr_ref+_1.0/images/val -------------------------------------------------------------------------------- /config/datasets/eval/refcocog_test-u.yaml: -------------------------------------------------------------------------------- 1 | - refcoco: 2 | target: locals.datasets.multimodal_tasks.refcoco.RefCOCOEvalDataset 3 | params: 4 | path: PATH/TO/Refcoco 5 | dataset_name: refcocog 6 | split: test 7 | split_by: umd 8 | image_path: PATH/TO/COCO2015/images/ 9 | task_mode: t2i -------------------------------------------------------------------------------- /config/datasets/eval/refcocog_val-u.yaml: -------------------------------------------------------------------------------- 1 | - refcoco: 2 | target: locals.datasets.multimodal_tasks.refcoco.RefCOCOEvalDataset 3 | params: 4 | path: /mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/Refcoco 5 | dataset_name: refcocog 6 | split: val 7 | split_by: umd 8 | image_path: PATH/TO/COCO2015/images/ 9 | task_mode: t2i -------------------------------------------------------------------------------- /config/datasets/stage1_alignment.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | target: locals.datasets.DataModuleFromConfig 3 | params: 4 | batch_size: 64 5 | num_workers: 4 6 | wrap: True 7 | train: 8 | - llava_pretrain_i2t: 9 | target: locals.datasets.image_caption.cc3m.FilteredCC3MI2TDataset 10 | params: 11 | path: /PATH/TO/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json 12 | image_folder: /PATH/TO/LLaVA-Pretrain/images 13 | raw_image: True 14 | output_mode: text 15 | shuffle: False -------------------------------------------------------------------------------- /config/datasets/stage2_interleaved.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | target: locals.datasets.DataModuleFromConfig 3 | params: 4 | batch_size: 64 5 | num_workers: 4 6 | wrap: True 7 | train: 8 | - allava_laion: 9 | target: locals.datasets.multimodal_tasks.llava_academic.ALLaVACaptionDataset 10 | params: 11 | path: PATH/TO/allava_laion/ALLaVA-Caption-LAION-4V.json 12 | image_folder: PATH/TO/ALLaVA/images 13 | raw_image: True 14 | output_mode: text 15 | shuffle: False 16 | expand2square: True 17 | - allava_vflan: 18 | target: 
locals.datasets.multimodal_tasks.llava_academic.ALLaVACaptionDataset 19 | params: 20 | path: PATH/TO/allava_vflan/ALLaVA-Caption-VFLAN-4V.json 21 | image_folder: PATH/TO/ALLaVA/images 22 | raw_image: True 23 | output_mode: text 24 | shuffle: False 25 | expand2square: True 26 | - mmc4: 27 | target: locals.datasets.image_caption.mmc4.FilteredMMC4Dataset 28 | params: 29 | path: PATH/TO/filter_mmc4_meta_with_img_abs_path_890k.jsonl 30 | avoid_image_gen: True 31 | expand2square: True 32 | - grit: 33 | target: locals.datasets.image_caption.grit.GriTDataset 34 | params: 35 | path: PATH/TO/clip_filtered_756K_grit.json 36 | image_folder: PATH/TO/GRIT/images/ 37 | raw_image: True 38 | output_mode: text 39 | shuffle: True 40 | phrase_key: noun_chunks 41 | avoid_image_gen: True 42 | phrase_format: 'text' 43 | phrase_prec: 3 44 | object_format: 'representation' 45 | expand2square: True 46 | - flickr30kentities: 47 | target: locals.datasets.image_caption.flicker30k.FlickrDataset 48 | params: 49 | path: PATH/TO/CWB_flickr30k_train.jsonl 50 | image_folder: PATH/TO/Flicker30K/flickr30k-images/ 51 | avoid_image_gen: True 52 | phrase_format: 'text' 53 | phrase_prec: 3 54 | object_format: 'representation' 55 | expand2square: True -------------------------------------------------------------------------------- /config/datasets/stage3_instruct.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | target: locals.datasets.DataModuleFromConfig 3 | params: 4 | batch_size: 64 5 | num_workers: 4 6 | wrap: True 7 | train: 8 | - llava_academic: 9 | target: locals.datasets.multimodal_tasks.llava_academic.LlavaAcademicDataset 10 | params: 11 | path: PATH/TO/llava_v1_5_mix665k_norefcoco.json # removing refcoco from llava 665K 12 | image_folder: PATH/TO/LLaVA_images 13 | raw_image: True 14 | output_mode: conversation 15 | avoid_image_gen: True 16 | min_size: 50 17 | phrase_format: 'text' 18 | phrase_prec: 3 19 | expand2square: True 20 | object_format: 'representation' 21 | - refcoco: 22 | target: locals.datasets.multimodal_tasks.refcoco.RefCOCODataset 23 | params: 24 | path: PATH/TO/Refcoco 25 | dataset_name: refcoco 26 | split: train 27 | image_path: PATH/TO/COCO2015/images/ 28 | task_mode: both 29 | avoid_image_gen: True 30 | phrase_format: 'text' 31 | phrase_prec: 3 32 | expand2square: True 33 | object_format: 'representation' 34 | - refcoco+: 35 | target: locals.datasets.multimodal_tasks.refcoco.RefCOCODataset 36 | params: 37 | path: PATH/TO/Refcoco 38 | dataset_name: refcoco+ 39 | split: train 40 | image_path: PATH/TO/COCO2015/images/ 41 | task_mode: both 42 | avoid_image_gen: True 43 | phrase_format: 'text' 44 | phrase_prec: 3 45 | expand2square: True 46 | object_format: 'representation' 47 | - refcocog: 48 | target: locals.datasets.multimodal_tasks.refcoco.RefCOCODataset 49 | params: 50 | path: PATH/TO/Refcoco 51 | dataset_name: refcocog 52 | split: train 53 | split_by: umd 54 | image_path: PATH/TO/COCO2015/images/ 55 | task_mode: both 56 | avoid_image_gen: True 57 | phrase_format: 'text' 58 | phrase_prec: 3 59 | expand2square: True 60 | object_format: 'representation' 61 | - grefcoco: 62 | target: locals.datasets.multimodal_tasks.refcoco.GRefCOCODataset 63 | params: 64 | path: PATH/TO/Refcoco 65 | split: train 66 | image_path: PATH/TO/COCO2015/images/ 67 | task_mode: both 68 | avoid_image_gen: True 69 | phrase_format: 'text' 70 | phrase_prec: 3 71 | expand2square: True 72 | object_format: 'representation' 73 | - shikra_cot_gen: 74 | target: 
locals.datasets.multimodal_tasks.cot_qa.CoTQADataset 75 | params: 76 | path: PATH/TO/Shikra/GPT4GEN_BoxCoT_train.jsonl # See Shikra data 77 | image_path: PATH/TO/Flicker30K/flickr30k-images/ 78 | avoid_image_gen: True 79 | phrase_format: 'text' 80 | phrase_prec: 3 81 | expand2square: True 82 | object_format: 'representation' 83 | further_instruct: True 84 | - shikra_rd: 85 | target: locals.datasets.multimodal_tasks.cot_qa.CoTQADataset 86 | params: 87 | path: PATH/TO/Shikra/GPT4GEN_RD_BoxCoT_train.jsonl # See Shikra data 88 | image_path: PATH/TO/Flicker30K/flickr30k-images/ 89 | avoid_image_gen: True 90 | phrase_format: 'text' 91 | phrase_prec: 3 92 | expand2square: True 93 | object_format: 'representation' 94 | further_instruct: False 95 | - cot_gqa: 96 | target: locals.datasets.multimodal_tasks.cot_qa.GQACoTDataset 97 | params: 98 | path: PATH/TO/raw_data/type1_gqa_raw.jsonl 99 | image_path: PATH/TO/GQA/images 100 | avoid_image_gen: True 101 | phrase_format: 'text' 102 | phrase_prec: 3 103 | expand2square: True 104 | object_format: 'representation' 105 | further_instruct: True 106 | sample_weight: 0.75 107 | - llava_QA2T: 108 | target: locals.datasets.multimodal_tasks.llava_academic.LlavaQA2TDataset 109 | params: 110 | path: PATH/TO/raw_data/type2_iqa2t_raw.json 111 | raw_image: True 112 | output_mode: conversation 113 | avoid_image_gen: True 114 | min_size: 50 115 | phrase_format: 'text' 116 | phrase_prec: 3 117 | expand2square: True 118 | object_format: 'representation' 119 | - lvis_I2QTA: 120 | target: locals.datasets.multimodal_tasks.llava_academic.LlavaI2QTADataset 121 | params: 122 | path: PATH/TO/raw_data/type3_i2qta_raw.json 123 | raw_image: True 124 | output_mode: conversation 125 | avoid_image_gen: True 126 | min_size: 50 127 | phrase_format: 'text' 128 | phrase_prec: 3 129 | expand2square: True 130 | object_format: 'representation' 131 | block_invalid: True -------------------------------------------------------------------------------- /config/deepspeed/config_zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | 14 | "zero_optimization": { 15 | "stage": 2, 16 | "allgather_partitions": true, 17 | "allgather_bucket_size": 5e8, 18 | "overlap_comm": true, 19 | "reduce_scatter": true, 20 | "reduce_bucket_size": 5e8, 21 | "contiguous_gradients": true 22 | }, 23 | 24 | "gradient_accumulation_steps": "auto", 25 | "gradient_clipping": "auto", 26 | "steps_per_print": 2000, 27 | "train_batch_size": "auto", 28 | "train_micro_batch_size_per_gpu": "auto", 29 | "wall_clock_breakdown": false 30 | } -------------------------------------------------------------------------------- /config/deepspeed/config_zero2_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | 14 | "zero_optimization": { 15 | "stage": 2, 16 | "offload_optimizer": { 17 | "device": "cpu", 18 | "pin_memory": true 19 | }, 20 | "offload_param": { 21 | "device": "cpu", 22 | "pin_memory": true 23 | }, 24 | "allgather_partitions": true, 25 | "allgather_bucket_size": 5e8, 26 | "overlap_comm": true, 27 | 
"reduce_scatter": true, 28 | "reduce_bucket_size": 5e8, 29 | "contiguous_gradients": true 30 | }, 31 | 32 | "gradient_accumulation_steps": "auto", 33 | "gradient_clipping": "auto", 34 | "steps_per_print": 2000, 35 | "train_batch_size": "auto", 36 | "train_micro_batch_size_per_gpu": "auto", 37 | "wall_clock_breakdown": false 38 | } -------------------------------------------------------------------------------- /config/deepspeed/config_zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /config/deepspeed/config_zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } -------------------------------------------------------------------------------- /config/experiments/stage1_alignment.yaml: -------------------------------------------------------------------------------- 1 | project_name: volcano 2 | run_name: volcano_stage2 3 | # Whether to make the system prompt a mask in the label, and others do not mask 4 | only_mask_system: False 5 | # system prompt style 6 | conv_mode: v1 7 | # wether lora 8 | lora_enable: False 9 | # wether multimodal 10 | is_multimodal: True 11 | 12 | freeze_backbone: True 13 | 14 | # weight path 15 | model_path: mistralai/Mistral-7B-Instruct-v0.2 16 | vision_encoder: openai/clip-vit-large-patch14-336 17 | vision_encoder_path: 
openai/clip-vit-large-patch14-336 18 | skip_vision_encoder_load: False 19 | front_projector_type: mlp2x_gelu 20 | num_query_token: 32 21 | avoid_generator: True # do not use generator in this stage 22 | output_dir: /PATH/TO/STAGE1_CKPT 23 | behind_projector: linear 24 | flash_attn: True 25 | # dataset config 26 | data_config_path: config/datasets/stage1_aligment.yaml 27 | expand_to_square: True 28 | remove_unused_columns: False 29 | regression_weight: 1.0 30 | tokenizer_model_max_length: 2048 31 | model_max_length: 2048 32 | extend_loc_vocabulary: False 33 | use_mistral: True 34 | 35 | num_train_epochs: 1 36 | per_device_train_batch_size: 16 37 | save_strategy: 'steps' 38 | lora_save_strategy: steps # if do lora training, turn on this button, to only save lora weight. support ['steps','epochs','no'] 39 | save_steps: 6000 40 | learning_rate: 1e-5 41 | gradient_checkpointing: True 42 | # wether do fast epoch 43 | fast_epoch: False 44 | 45 | # whether to compute diffusion loss 46 | compute_diffusion_loss: False 47 | 48 | bf16: True 49 | fp16: False 50 | tf32: False 51 | per_device_eval_batch_size: 1 52 | gradient_accumulation_steps: 1 53 | evaluation_strategy: "no" 54 | save_total_limit: 3 55 | weight_decay: 0. 56 | warmup_ratio: 0.0 57 | lr_scheduler_type: cosine 58 | logging_steps: 1 59 | model_max_length: 2048 60 | adam_beta1: 0.9 61 | adam_beta2: 0.95 62 | deepspeed: config/deepspeed/config_zero2.json 63 | dataloader_num_workers: 4 64 | # report_to: wandb 65 | is_training: True -------------------------------------------------------------------------------- /config/experiments/stage2_interleaved.yaml: -------------------------------------------------------------------------------- 1 | project_name: volcano 2 | run_name: volcano_stage2 3 | # Whether to make the system prompt a mask in the label, and others do not mask 4 | only_mask_system: False 5 | # system prompt style 6 | conv_mode: v1 7 | # wether lora 8 | lora_enable: False 9 | # wether multimodal 10 | is_multimodal: True 11 | 12 | freeze_backbone: False 13 | 14 | # weight path 15 | model_path: mistralai/Mistral-7B-Instruct-v0.2 16 | vision_encoder: openai/clip-vit-large-patch14-336 17 | vision_encoder_path: openai/clip-vit-large-patch14-336 18 | skip_vision_encoder_load: False 19 | front_projector_type: mlp2x_gelu 20 | num_query_token: 32 21 | avoid_generator: True # do not use generator in this stage 22 | output_dir: PATH/TO/STAGE2_CKPT 23 | behind_projector: linear 24 | flash_attn: True 25 | # dataset config 26 | data_config_path: config/datasets/stage2_interleaved.yaml 27 | expand_to_square: True 28 | remove_unused_columns: False 29 | regression_weight: 1.0 30 | tokenizer_model_max_length: 2048 31 | model_max_length: 2048 32 | extend_loc_vocabulary: False 33 | use_mistral: True 34 | stage1_ckpt: PATH/TO/STAGE1_CKPT_PROJECTION # please only contain the projection layer 35 | 36 | num_train_epochs: 1 37 | per_device_train_batch_size: 16 38 | save_strategy: 'steps' 39 | lora_save_strategy: steps # if do lora training, turn on this button, to only save lora weight. support ['steps','epochs','no'] 40 | save_steps: 6000 41 | learning_rate: 1e-5 42 | gradient_checkpointing: True 43 | # wether do fast epoch 44 | fast_epoch: False 45 | 46 | # whether to compute diffusion loss 47 | compute_diffusion_loss: False 48 | 49 | bf16: True 50 | fp16: False 51 | tf32: False 52 | per_device_eval_batch_size: 1 53 | gradient_accumulation_steps: 1 54 | evaluation_strategy: "no" 55 | save_total_limit: 3 56 | weight_decay: 0. 
57 | warmup_ratio: 0.0 58 | lr_scheduler_type: cosine 59 | logging_steps: 1 60 | model_max_length: 2048 61 | adam_beta1: 0.9 62 | adam_beta2: 0.95 63 | deepspeed: config/deepspeed/config_zero2.json 64 | dataloader_num_workers: 4 65 | # report_to: wandb 66 | is_training: True -------------------------------------------------------------------------------- /config/experiments/stage3_instruct.yaml: -------------------------------------------------------------------------------- 1 | project_name: volcano 2 | run_name: volcano_stage3 3 | # Whether to make the system prompt a mask in the label, and others do not mask 4 | only_mask_system: False 5 | # system prompt style 6 | conv_mode: v1 7 | # wether lora 8 | lora_enable: False # no lora in the stage-1 9 | # wether multimodal 10 | is_multimodal: True 11 | 12 | freeze_backbone: False 13 | 14 | # weight path 15 | model_path: PATH/TO/STAGE2_CKPT 16 | vision_encoder: openai/clip-vit-large-patch14-336 17 | vision_encoder_path: openai/clip-vit-large-patch14-336 18 | skip_vision_encoder_load: False 19 | front_projector_type: mlp2x_gelu 20 | num_query_token: 32 21 | avoid_generator: True # do not use generator in this stage 22 | output_dir: PATH/TO/STAGE3_CKPT 23 | behind_projector: linear 24 | flash_attn: True 25 | # dataset config 26 | data_config_path: config/datasets/stage3_instruct.yaml 27 | expand_to_square: True 28 | remove_unused_columns: False 29 | regression_weight: 1.0 30 | tokenizer_model_max_length: 3072 31 | model_max_length: 3072 32 | extend_loc_vocabulary: False 33 | use_mistral: True 34 | 35 | num_train_epochs: 1 36 | per_device_train_batch_size: 16 37 | save_strategy: 'steps' 38 | lora_save_strategy: steps # if do lora training, turn on this button, to only save lora weight. support ['steps','epochs','no'] 39 | save_steps: 3000 40 | learning_rate: 1e-5 41 | gradient_checkpointing: True 42 | # wether do fast epoch 43 | fast_epoch: False 44 | 45 | # whether to compute diffusion loss 46 | compute_diffusion_loss: False 47 | 48 | bf16: True 49 | fp16: False 50 | tf32: False 51 | per_device_eval_batch_size: 1 52 | gradient_accumulation_steps: 1 53 | evaluation_strategy: "no" 54 | save_total_limit: 3 55 | weight_decay: 0. 
56 | warmup_ratio: 0.0 57 | lr_scheduler_type: cosine 58 | logging_steps: 1 59 | model_max_length: 3072 60 | adam_beta1: 0.9 61 | adam_beta2: 0.95 62 | deepspeed: config/deepspeed/config_zero2.json 63 | dataloader_num_workers: 4 64 | # report_to: wandb 65 | is_training: True -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | 5 | IMG_TOKEN_NUM = 8 6 | ALL_IMG_TOKENS = [f"[IMG{i}]" for i in range(IMG_TOKEN_NUM)] 7 | ALL_IMG_TOKENS_STR = '' # ""# "".join(ALL_IMG_TOKENS) 8 | 9 | # Location related tokens 10 | LOC_TOKEN_NUM = 256 11 | ALL_LOC_TOKENS = ["[LOC{}]".format(i+1) for i in range(LOC_TOKEN_NUM)] 12 | 13 | USE_PREFIX_TUNING = False 14 | USE_LORA = False 15 | USE_CFG = True 16 | IGNORE_TOKEN_ID = -100 17 | IMAGE_TOKEN_INDEX = -200 18 | 19 | 20 | PRECISION = torch.bfloat16 21 | TRAINABLE_PRECISION = torch.float32 22 | 23 | IGNORE_INDEX = -100 24 | DEFAULT_GRD_TOKEN = "" 25 | DEFAULT_BOP_TOKEN = "" # begin of phrase modified from 3/11 26 | DEFAULT_EOP_TOKEN = "" # end of phrase modified from 3/11 27 | DEFAULT_BOC_TOKEN = "" # begin of coordinates 28 | DEFAULT_EOC_TOKEN = "" # end of coordinates 29 | DEFAULT_SEP_TOKEN = " and"# "" modified from 3/11 30 | DEFAULT_PAD_TOKEN = "[PAD]" 31 | DEFAULT_EOS_TOKEN = "" 32 | DEFAULT_BOS_TOKEN = "" 33 | DEFAULT_UNK_TOKEN = "" 34 | 35 | # begin of image 36 | DEFAULT_BOI_TOKEN = "" 37 | # end of image 38 | DEFAULT_EOI_TOKEN = '' 39 | # default image token 40 | DEFAULT_IMG_TOKEN = '' 41 | COT_ACTIVATION = 'Answer the question and include the reasoning proess. Locate key objects and provide bounding boxes in your thoughts.' 42 | COT_ACTIVATION_TXT = 'Answer the question and include the reasoning proess.' -------------------------------------------------------------------------------- /eval/Evaluation.md: -------------------------------------------------------------------------------- 1 | # Evaluation of VolCano 2 | 3 | ## Data Prepare 4 | For evaluation, we also utilize the yaml file to manage the evaluation datasets. 5 | 6 | For CLEVR and CLEVR-ref, we provide the constructed evaluation meta data in [here](https://huggingface.co/datasets/luoruipu1/VoCoT/tree/main/eval). For other dataets, please refer to the original website for the data. 7 | 8 | With the prepared datasets, please set the correct paths in those [config files](../config/datasets/eval/). 9 | 10 | ## Run Evaluation 11 | 12 | We provide the evaluation scripts in [test_all_benchmark.sh](./commands/test_all_benchmark.sh), you can modify the model path and run the entire script or use a part of it. 13 | 14 | ## Metric Computation 15 | 16 | Similar to the evaluation, all the metrics can be computed offline with [run_metric.sh](./commands/run_metric.sh). For GQA and AMBER, the output is converted into appropriate format. You need to further compute the metric, please refer to [LLaVA_for_GQA](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md#gqa) and [AMBER](https://github.com/junyangwang0410/AMBER) for further instruction. 
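For reference, each entry in these dataset yaml files follows a `target`/`params` convention, and the metric scripts above build the dataset objects from it via `utils.util.instantiate_from_config`. Below is a minimal sketch of that convention, not the project's actual helper (the real `instantiate_from_config` in `utils/util.py` may differ in details, e.g. it also accepts a `reload` flag), and it assumes the `target` key sits at the top level of each list entry:

```python
import importlib
from omegaconf import OmegaConf

def build_from_config(entry):
    # "target" is a dotted import path, "params" holds the constructor keyword arguments
    module_name, cls_name = entry["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**entry.get("params", {}))

cfg = OmegaConf.load("config/datasets/eval/GQA.yaml")
dataset = build_from_config(cfg[0])  # instantiates GQADataset with the paths set in the yaml
```

The same convention is used by the training dataset configs in `config/datasets/`, which is why only the "PATH/TO" fields need to be edited.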
17 | -------------------------------------------------------------------------------- /eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RupertLuo/VoCoT/d8177ca2e8e303f7d35609a000e5eff7bec319cb/eval/__init__.py -------------------------------------------------------------------------------- /eval/commands/run_metric.sh: -------------------------------------------------------------------------------- 1 | ##################### total setting=2,4,6,7 2 | function run_all(){ 3 | model_name=/mnt/bn/yangmin-priv/luoruipu/checkpoints/LLaVA-clip336px-obj-represent-Mistral-1e-5-3072-instruct_llava+shikraCoT+GPTQTA-Block+lvis-cot/ 4 | 5 | 6 | store_model_name=llava_mistral_instruct_simplified_image_llava_shikraCoT75+GPTQTA-qa2t 7 | 8 | ## dataset setting 9 | function mmbench(){ 10 | ##### mci 11 | dataset_name=mmbench 12 | dataset_config=config/datasets/eval/MMBench_DEV_opt.yaml 13 | output_dir=output/${dataset_name}/${store_model_name}/cot_sum/ 14 | echo "========MMBench Result==========" 15 | python3 eval/eval_tools/mmbench.py --config ${dataset_config} --result ${output_dir}/MMBench_DEV_opt.json 16 | } 17 | 18 | function seed(){ 19 | ##### mci 20 | dataset_name=seed 21 | dataset_config=config/datasets/eval/SEED_opt.yaml 22 | output_dir=output/${dataset_name}/${store_model_name}/cot_sum/ 23 | echo "========SEED Result==========" 24 | python3 eval/eval_tools/seed.py --config ${dataset_config} --result ${output_dir}/SEED_opt.json 25 | } 26 | 27 | function clevr(){ 28 | ##### mci 29 | dataset_name=clevr 30 | dataset_config=config/datasets/eval/CLEVR_val1k_type_opt.yaml 31 | output_dir=output/${dataset_name}/${store_model_name}/cot_sum/ 32 | echo "========CLEVR Result==========" 33 | python3 eval/eval_tools/clevr.py --config config/datasets/eval/CLEVR_val1k_type_opt.yaml --result ${output_dir}/CLEVR_val1k_type_opt.json 34 | } 35 | 36 | 37 | function embspatial(){ 38 | ##### mci 39 | dataset_name=emb_spa 40 | dataset_config=config/datasets/eval/EmbSpa_test_opt.yaml 41 | output_dir=output/${dataset_name}/${store_model_name}/cot_sum/ 42 | echo "========EmbSpatial Result===========" 43 | python3 eval/eval_tools/mp3d.py --config config/datasets/eval/EmbSpa_test_opt.yaml --result ${output_dir}/EmbSpa_test_opt.json 44 | } 45 | 46 | function vsr(){ 47 | ##### mci 48 | dataset_name=vsr 49 | dataset_config=config/datasets/eval/VSR_TEST.yaml 50 | output_dir=output/${dataset_name}/${store_model_name}/cot/ 51 | echo "========VSR Result===========" 52 | python3 eval/eval_tools/vsr.py --config config/datasets/eval/VSR_TEST.yaml --data ${output_dir}/VSR_TEST.json 53 | } 54 | 55 | function pope(){ 56 | ##### mci 57 | dataset_name=pope 58 | dataset_config=config/datasets/eval/POPE_adversarial.yaml 59 | output_dir=output/${dataset_name}/${store_model_name}/cot/ 60 | echo "========POPE Result===========" 61 | python3 eval/eval_tools/pope.py --data ${output_dir}/POPE_adversarial.json 62 | } 63 | 64 | function vstar(){ 65 | ##### mci 66 | dataset_name=vstar 67 | dataset_config=config/datasets/eval/VStar.yaml 68 | output_dir=output/${dataset_name}/${store_model_name}/cot/ 69 | echo "========VSTAR Result===========" 70 | python3 eval/eval_tools/vstar.py --result ${output_dir}/VStar.json 71 | } 72 | 73 | function wino(){ 74 | ##### mci 75 | dataset_name=wino 76 | dataset_config=config/datasets/eval/Wino.yaml 77 | output_dir=output/${dataset_name}/${store_model_name}/cot_no_instruct/ 78 | echo "========WINO-txt Result===========" 79 | python3 
eval/eval_tools/vstar.py --result ${output_dir}/Wino.json 80 | } 81 | 82 | function amber(){ 83 | dataset_name=amber 84 | dataset_config=config/datasets/eval/AMBER_desc.yaml 85 | output_dir=output/${dataset_name}/${store_model_name}/cot_sum/ 86 | echo "========AMBER Result Need Further Evaluation===========" 87 | python3 eval/eval_tools/convert_res_to_amber.py \ 88 | --src ${output_dir}/AMBER_desc.json \ 89 | --tgt ${output_dir}/AMBER_desc_eval.json \ 90 | --desc 91 | } 92 | 93 | function gqa(){ 94 | dataset_name=gqa 95 | dataset_config=config/datasets/eval/GQA.yaml 96 | output_dir=output/${dataset_name}/${store_model_name}/cot/ 97 | echo "========GQA Result Need Further Evaluation===========" 98 | python3 eval/eval_tools/convert_res_to_gqa.py \ 99 | --src ${output_dir}/GQA.json \ 100 | --dst ${output_dir}/testdev_balanced_predictions.json 101 | } 102 | 103 | function refcocog_test(){ 104 | ##### mci 105 | dataset_name=refcoco 106 | dataset_config=config/datasets/eval/refcocog_test-u.yaml 107 | output_dir=output/${dataset_name}/${store_model_name}/llava_prompt/ 108 | echo "========RefCOCOg test Result===========" 109 | python3 eval/eval_tools/refcoco.py --mistral --path ${output_dir}/refcocog_test-u.json 110 | } 111 | function refcocog_val(){ 112 | ##### mci 113 | dataset_name=refcoco 114 | dataset_config=config/datasets/eval/refcocog_val-u.yaml 115 | output_dir=output/${dataset_name}/${store_model_name}/llava_prompt/ 116 | echo "========RefCOCOg val Result===========" 117 | python3 eval/eval_tools/refcoco.py --mistral --path ${output_dir}/refcocog_val-u.json 118 | } 119 | function clevr_ref(){ 120 | dataset_name=refcoco 121 | dataset_config=config/datasets/eval/clevr_ref.yaml 122 | output_dir=output/${dataset_name}/${store_model_name}/llava_prompt/ 123 | echo "========CLEVR REF Result===========" 124 | python3 eval/eval_tools/refcoco.py --mistral --path ${output_dir}/clevr_ref.json 125 | } 126 | 127 | mmbench 128 | seed 129 | embspatial 130 | clevr 131 | wino 132 | vsr 133 | pope 134 | vstar 135 | refcocog_val 136 | refcocog_test 137 | clevr_ref 138 | gqa 139 | amber 140 | } 141 | run_all -------------------------------------------------------------------------------- /eval/eval_tools/clevr.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from utils.util import instantiate_from_config 3 | import argparse, json 4 | from collections import defaultdict 5 | 6 | label2index = {'A':0, 'B': 1, 'C': 2, 'D': 3} 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--result', type=str, default=None) 11 | parser.add_argument('--config', type=str, default='/mnt/bn/yangmin-priv/luoruipu/code/Edit-GPT4/config/datasets/eval/CLEVR_val1k_opt_context.yaml') 12 | args = parser.parse_args() 13 | 14 | cfg = OmegaConf.load(args.config) 15 | ds = instantiate_from_config(cfg[0]) 16 | res = json.load(open(args.result)) 17 | class2res = defaultdict(list) 18 | if isinstance(res[0]['predict'], int): 19 | print('evaluating options') 20 | acc = 0 21 | for item in res: 22 | index = int(item['item_id'].split('_')[-1]) 23 | q_type = ds.question_type(index) 24 | if item['predict'] == ds.getlabel(index): 25 | c = 1 26 | acc += 1 27 | else: 28 | c = 0 29 | class2res[q_type].append(c) 30 | print('General accuracy: {:.4f} from {} samples'.format(acc/len(res), len(res))) 31 | for k,v in class2res.items(): 32 | print('{} Accuracy: {:.4f} from {} samples'.format(k, sum(v)/len(v), len(v))) 33 | return 34 | 35 | cfg = 
OmegaConf.load(args.config) 36 | ds = instantiate_from_config(cfg[0]) 37 | 38 | 39 | item2logit = defaultdict(list) 40 | item2answer = {} 41 | 42 | for item in res: 43 | item_id = int(item['item_id'].split('_')[-1]) 44 | question_id, option = ds.get_index(item_id) 45 | label = ds.getlabel(item_id) 46 | if question_id in item2answer: 47 | assert label == item2answer[question_id] 48 | else: 49 | item2answer[question_id] = label 50 | item2logit[question_id].append([item['logit'], option]) 51 | 52 | acc = 0 53 | for k in item2logit: 54 | preds = sorted(item2logit[k], key=lambda x:x[0])[0][1] 55 | print(preds, item2answer[k]) 56 | if preds == item2answer[k]: 57 | acc += 1 58 | print('accuracy: {}'.format(acc/len(item2logit))) 59 | 60 | if __name__=='__main__': 61 | main() -------------------------------------------------------------------------------- /eval/eval_tools/convert_res_to_amber.py: -------------------------------------------------------------------------------- 1 | import json,argparse 2 | from utils.eval_util import extract_box, remove_all_box_str 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--src', type=str) 6 | parser.add_argument('--tgt', type=str) 7 | parser.add_argument('--desc', action='store_true') 8 | args = parser.parse_args() 9 | 10 | res = json.load(open(args.src)) 11 | new_res = [] 12 | if args.desc: 13 | for item in res: 14 | key = 'prediction' if 'prediction' in item else 'predict' 15 | pred = remove_all_box_str(item[key], mistral=True).replace(' ', ' ').replace('', '').strip().replace(' .', '.') 16 | new_res.append({ 17 | 'id': item['label'], 18 | 'response': pred 19 | }) 20 | else: 21 | for item in res: 22 | tmp = item['prediction'].replace('', '').strip().lower() 23 | if 'yes' in tmp: 24 | pred = 'Yes' 25 | else: 26 | pred = 'No' 27 | new_res.append({ 28 | 'id': item['label'], 29 | 'response': pred 30 | }) 31 | with open(args.tgt, 'w') as wf: 32 | json.dump(new_res, wf) -------------------------------------------------------------------------------- /eval/eval_tools/convert_res_to_gqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from omegaconf import OmegaConf 5 | from utils.util import instantiate_from_config 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--src", type=str) 9 | parser.add_argument("--dst", type=str) 10 | args = parser.parse_args() 11 | cfg = OmegaConf.load('/mnt/bn/yangmin-priv/luoruipu/code/Edit-GPT4/config/datasets/eval/GQA.yaml') 12 | ds = instantiate_from_config(cfg[0]) 13 | 14 | all_answers = [] 15 | res = json.load(open(args.src)) 16 | for line in res: 17 | index = int(line['item_id'].split('_')[-1]) 18 | question_id = ds.keys[index] 19 | text = line['prediction'].replace('', '').rstrip('.').strip().lower() 20 | all_answers.append({"questionId": question_id, "prediction": text}) 21 | 22 | with open(args.dst, 'w') as f: 23 | json.dump(all_answers, f) -------------------------------------------------------------------------------- /eval/eval_tools/convert_res_to_gqa_llava.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--src", type=str) 7 | parser.add_argument("--dst", type=str) 8 | args = parser.parse_args() 9 | 10 | all_answers = [] 11 | for line_idx, line in enumerate(open(args.src)): 12 | res = json.loads(line) 13 | question_id = res['question_id'] 14 | 
text = res['text'].rstrip('.').lower() 15 | all_answers.append({"questionId": question_id, "prediction": text}) 16 | 17 | with open(args.dst, 'w') as f: 18 | json.dump(all_answers, f) -------------------------------------------------------------------------------- /eval/eval_tools/convert_res_to_mmbench.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import pandas as pd 5 | from omegaconf import OmegaConf 6 | 7 | def get_args(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--annotation-file", type=str, required=True) 10 | parser.add_argument("--src", type=str, required=True) 11 | parser.add_argument("--tgt", type=str, required=True) 12 | 13 | return parser.parse_args() 14 | 15 | if __name__ == "__main__": 16 | args = get_args() 17 | 18 | df = pd.read_table(args.annotation_file) 19 | 20 | all_answers = [] 21 | res = json.load(open(args.src)) 22 | cur_df = df.copy() 23 | cur_df = cur_df.drop(columns=['hint', 'category', 'source', 'image', 'comment', 'l2-category']) 24 | cur_df.insert(6, 'prediction', None) 25 | for pred in res: 26 | cur_df.loc[df['index'] == pred['dataset_id'], 'prediction'] = pred['prediction'].replace('', '') 27 | 28 | cur_df.to_excel(args.tgt, index=False, engine='openpyxl') -------------------------------------------------------------------------------- /eval/eval_tools/mmbench.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from utils.util import instantiate_from_config 3 | import argparse, json 4 | from collections import defaultdict 5 | import pandas as pd 6 | 7 | label2index = {'A':0, 'B': 1, 'C': 2, 'D': 3} 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--result', type=str, default=None) 12 | parser.add_argument('--config', type=str, default='/mnt/bn/yangmin-priv/luoruipu/code/Edit-GPT4/config/datasets/eval/MMBench_DEV_opt.yaml') 13 | args = parser.parse_args() 14 | 15 | cfg = OmegaConf.load(args.config) 16 | mmbench_meta = pd.read_table(cfg[0]['params']['path']) 17 | res = json.load(open(args.result)) 18 | for item in res: 19 | if 'predict' not in item: 20 | item['predict'] = item['prediction'] 21 | class2res = defaultdict(list) 22 | if isinstance(res[0]['predict'], int): 23 | acc = 0 24 | for item in res: 25 | index = int(item['item_id'].split('_')[-1]) 26 | q_type = mmbench_meta.iloc[index]['category'] 27 | if item['predict'] == '': 28 | item['predict'] = 0 29 | if item['predict'] == label2index[item['label']]: 30 | c = 1 31 | acc += 1 32 | else: 33 | c = 0 34 | class2res[q_type].append(c) 35 | print('General accuracy: {}'.format(acc/len(res))) 36 | for k,v in class2res.items(): 37 | print('{} Accuracy: {}'.format(k, sum(v)/len(v))) 38 | return 39 | 40 | # cfg = OmegaConf.load(args.config) 41 | # ds = instantiate_from_config(cfg[0]) 42 | 43 | 44 | # item2logit = defaultdict(list) 45 | # item2answer = {} 46 | 47 | # for item in res: 48 | # item_id = int(item['item_id'].split('_')[-1]) 49 | # question_id, option = ds.get_index(item_id) 50 | # label = ds.getlabel(item_id) 51 | # if question_id in item2answer: 52 | # assert label == item2answer[question_id] 53 | # else: 54 | # item2answer[question_id] = label 55 | # item2logit[question_id].append([item['logit'], option]) 56 | 57 | # acc = 0 58 | # for k in item2logit: 59 | # preds = sorted(item2logit[k], key=lambda x:x[0])[0][1] 60 | # print(preds, item2answer[k]) 61 | # if preds == 
item2answer[k]: 62 | # acc += 1 63 | # print('accuracy: {}'.format(acc/len(item2logit))) 64 | 65 | if __name__=='__main__': 66 | main() -------------------------------------------------------------------------------- /eval/eval_tools/mp3d.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from utils.util import instantiate_from_config 3 | import argparse, json 4 | from collections import defaultdict 5 | 6 | label2index = {'A':0, 'B': 1, 'C': 2, 'D': 3} 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--result', type=str, default=None) 11 | parser.add_argument('--config', type=str, default='/mnt/bn/yangmin-priv/luoruipu/code/Edit-GPT4/config/datasets/eval/EmbSpa_dev_opt.yaml') 12 | args = parser.parse_args() 13 | 14 | cfg = OmegaConf.load(args.config) 15 | meta = json.load(open(cfg[0]['params']['path'])) 16 | # seed_question = [item for item in seed_meta['questions'] if item['data_type']=='image'] 17 | # question_types = {v:k for k,v in seed_meta['question_type'].items()} 18 | res = json.load(open(args.result)) 19 | class2res = defaultdict(list) 20 | if isinstance(res[0]['predict'], int): 21 | print('evaluating options') 22 | acc = 0 23 | for item in res: 24 | index = int(item['item_id'].split('_')[-1]) 25 | item['label'] = meta[index]['answer'] 26 | # q_type = question_types[seed_question[index]['question_type_id']] 27 | if item['predict'] == item['label']: 28 | c = 1 29 | acc += 1 30 | else: 31 | c = 0 32 | # class2res[q_type].append(c) 33 | print('General accuracy: {:.4f}'.format(acc/len(res))) 34 | # for k,v in class2res.items(): 35 | # print('{} Accuracy: {:.4f}'.format(k, sum(v)/len(v))) 36 | return 37 | 38 | cfg = OmegaConf.load(args.config) 39 | ds = instantiate_from_config(cfg[0]) 40 | 41 | 42 | item2logit = defaultdict(list) 43 | item2answer = {} 44 | 45 | for item in res: 46 | item_id = int(item['item_id'].split('_')[-1]) 47 | question_id, option = ds.get_index(item_id) 48 | label = ds.getlabel(item_id) 49 | if question_id in item2answer: 50 | assert label == item2answer[question_id] 51 | else: 52 | item2answer[question_id] = label 53 | item2logit[question_id].append([item['logit'], option]) 54 | 55 | acc = 0 56 | for k in item2logit: 57 | preds = sorted(item2logit[k], key=lambda x:x[0])[0][1] 58 | print(preds, item2answer[k]) 59 | if preds == item2answer[k]: 60 | acc += 1 61 | print('accuracy: {}'.format(acc/len(item2logit))) 62 | 63 | if __name__=='__main__': 64 | main() -------------------------------------------------------------------------------- /eval/eval_tools/pope.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--data') 6 | args = parser.parse_args() 7 | res = json.load(open(args.data)) 8 | 9 | invalid = correct = 0 10 | for item in res: 11 | key = 'predict' if 'predict' in item else 'prediction' 12 | pred = item[key].replace('', '').strip().lower() 13 | if pred not in ['yes', 'no']: 14 | if pred.startswith('yes'): 15 | p = 'yes' 16 | elif pred.startswith('no'): 17 | p = 'no' 18 | else: 19 | invalid += 1 20 | p = 'no' 21 | # print(pred) 22 | # invalid += 1 23 | # correct += 1 if item['label'] == 'no' else 0 24 | else: 25 | p = pred 26 | correct += 1 if p==item['label'] else 0 27 | 28 | print('accuracy: {}, invalid rate: {}'.format(correct / len(res), invalid/len(res))) 
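# --- Optional sketch, not part of the original script ---
# POPE results are commonly reported with precision/recall/F1 on the "yes" class in
# addition to accuracy; this re-parses the predictions with the same yes/no heuristic
# as above and assumes the same 'predict'/'prediction' keys and 'yes'/'no' labels.
def _to_yes_no(item):
    pred = (item['predict'] if 'predict' in item else item['prediction']).strip().lower()
    return 'yes' if pred.startswith('yes') else 'no'

tp = sum(1 for it in res if _to_yes_no(it) == 'yes' and it['label'] == 'yes')
fp = sum(1 for it in res if _to_yes_no(it) == 'yes' and it['label'] == 'no')
fn = sum(1 for it in res if _to_yes_no(it) == 'no' and it['label'] == 'yes')
precision = tp / (tp + fp) if (tp + fp) else 0.0
recall = tp / (tp + fn) if (tp + fn) else 0.0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
print('precision: {:.4f}, recall: {:.4f}, f1: {:.4f}'.format(precision, recall, f1))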
-------------------------------------------------------------------------------- /eval/eval_tools/refcoco.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from utils.eval_util import * 4 | import tqdm 5 | from functools import partial 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--path', type=str, default=None) 10 | parser.add_argument('--method', type=str, default='str') 11 | parser.add_argument('--mistral', action='store_true') 12 | args = parser.parse_args() 13 | 14 | result = json.load(open(args.path)) 15 | if args.method == 'str': 16 | extract_fn = partial(extract_box_str, mistral=args.mistral) 17 | elif args.method == 'llava': 18 | extract_fn = extract_box_str_llava 19 | elif args.method == 'llava16': 20 | extract_fn = partial(extract_box_str_llava16, mistral=args.mistral) 21 | elif args.method == 'space': 22 | extract_fn = extract_box_str_space 23 | elif args.method == 'special_tokens': 24 | extract_fn = extract_box_str 25 | elif args.method == 'qwenvl': 26 | extract_fn = extract_box_str_qwenvl 27 | elif args.method == 'minigptv2': 28 | extract_fn = extract_box_str_minigptv2 29 | else: 30 | raise NotImplementedError 31 | samples = { 32 | 'fail': 0, 33 | 'wrong': 0, 34 | 'correct': 0 35 | } 36 | key = 'prediction' if 'prediction' in result[0] else 'predict' 37 | for item in result: 38 | # print(item) 39 | box = extract_fn(item[key][0] if isinstance(item[key], list) else item[key]) 40 | if box is None: 41 | print(item[key]) 42 | samples['fail'] += 1 43 | else: 44 | iou = cal_iou(box, item['label']) 45 | if iou >= 0.5: 46 | samples['correct'] += 1 47 | else: 48 | samples['wrong'] += 1 49 | print(json.dumps(samples)) 50 | print('accuracy: {}'.format(samples['correct'] / len(result))) 51 | 52 | if __name__=='__main__': 53 | main() -------------------------------------------------------------------------------- /eval/eval_tools/seed.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from utils.util import instantiate_from_config 3 | import argparse, json 4 | from collections import defaultdict 5 | 6 | label2index = {'A':0, 'B': 1, 'C': 2, 'D': 3} 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--result', type=str, default=None) 11 | parser.add_argument('--config', type=str, default='/mnt/bn/yangmin-priv/luoruipu/code/Edit-GPT4/config/datasets/eval/SEED_opt.yaml') 12 | parser.add_argument('--no_counting', action='store_true') 13 | args = parser.parse_args() 14 | 15 | cfg = OmegaConf.load(args.config) 16 | seed_meta = json.load(open(cfg[0]['params']['path'])) 17 | seed_question = [item for item in seed_meta['questions'] if item['data_type']=='image'] 18 | question_types = {v:k for k,v in seed_meta['question_type'].items()} 19 | res = json.load(open(args.result)) 20 | class2res = defaultdict(list) 21 | if isinstance(res[0]['predict'], int): 22 | print('evaluating options') 23 | acc = 0 24 | for item in res: 25 | index = int(item['item_id'].split('_')[-1]) 26 | q_type = question_types[seed_question[index]['question_type_id']] 27 | if item['predict'] == '': 28 | item['predict'] = 0 29 | if item['predict'] == label2index[item['label']]: 30 | c = 1 31 | acc += 1 32 | else: 33 | c = 0 34 | class2res[q_type].append(c) 35 | if args.no_counting: 36 | all_samples = 0 37 | corr_samples = 0 38 | for k,v in class2res.items(): 39 | if k!= 'Instances Counting': 40 | all_samples += 
len(v) 41 | corr_samples += sum(v) 42 | print('General accuracy w/o counting: {:.4f}'.format(corr_samples/all_samples)) 43 | print('General accuracy: {:.4f}'.format(acc/len(res))) 44 | else: 45 | print('General accuracy: {:.4f}'.format(acc/len(res))) 46 | for k,v in class2res.items(): 47 | print('{} Accuracy: {:.4f}'.format(k, sum(v)/len(v))) 48 | return 49 | 50 | cfg = OmegaConf.load(args.config) 51 | ds = instantiate_from_config(cfg[0]) 52 | 53 | 54 | item2logit = defaultdict(list) 55 | item2answer = {} 56 | 57 | for item in res: 58 | item_id = int(item['item_id'].split('_')[-1]) 59 | question_id, option = ds.get_index(item_id) 60 | label = ds.getlabel(item_id) 61 | if question_id in item2answer: 62 | assert label == item2answer[question_id] 63 | else: 64 | item2answer[question_id] = label 65 | item2logit[question_id].append([item['logit'], option]) 66 | 67 | acc = 0 68 | for k in item2logit: 69 | preds = sorted(item2logit[k], key=lambda x:x[0])[0][1] 70 | print(preds, item2answer[k]) 71 | if preds == item2answer[k]: 72 | acc += 1 73 | print('accuracy: {}'.format(acc/len(item2logit))) 74 | 75 | if __name__=='__main__': 76 | main() -------------------------------------------------------------------------------- /eval/eval_tools/vsr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | import json 4 | from omegaconf import OmegaConf 5 | from utils.util import instantiate_from_config 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--data') 9 | parser.add_argument('--config') 10 | 11 | args = parser.parse_args() 12 | res = json.load(open(args.data)) 13 | cfg = OmegaConf.load(args.config) 14 | ds = instantiate_from_config(cfg[0], reload=True) 15 | 16 | invalid = correct = 0 17 | relation2res = defaultdict(list) 18 | if 'predict' in res[0] and isinstance(res[0]['predict'], int): 19 | for item in res: 20 | item_id = int(item['item_id'].split('_')[-1]) 21 | relation = ds.meta[item_id]['relation'] 22 | pred = item['predict'] 23 | correct += 1 if pred==item['label'] else 0 24 | relation2res[relation].append(1 if pred==item['label'] else 0) 25 | else: 26 | for item in res: 27 | item_id = int(item['item_id'].split('_')[-1]) 28 | relation = ds.meta[item_id]['relation'] 29 | key = 'predict' if 'predict' in item else 'prediction' 30 | pred = item[key].replace('', '').strip().lower() 31 | if pred not in ['yes', 'no']: 32 | # p = -1 33 | if pred.startswith('yes'): 34 | p = 1 35 | elif pred.startswith('no'): 36 | p = 0 37 | else: 38 | invalid += 1 39 | p = 0 40 | else: 41 | if pred == 'yes': 42 | p = 1 43 | else: 44 | p = 0 45 | correct += 1 if p==item['label'] else 0 46 | relation2res[relation].append(1 if p==item['label'] else 0) 47 | 48 | print('accuracy: {}, invalid rate: {}'.format(correct / len(res), invalid/len(res))) 49 | print('====results in detail=====') 50 | for k in sorted(relation2res.keys()): 51 | v = relation2res[k] 52 | print('{}: {} in {} samples'.format(k, round(sum(v)/len(v),4), len(v))) -------------------------------------------------------------------------------- /eval/eval_tools/vstar.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from utils.util import instantiate_from_config 3 | import argparse, json 4 | from collections import defaultdict 5 | 6 | label2index = {'A':0, 'B': 1, 'C': 2, 'D': 3} 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--result', 
type=str, default=None) 11 | parser.add_argument('--config', type=str, default='/mnt/bn/yangmin-priv/luoruipu/code/Edit-GPT4/config/datasets/eval/SEED_opt.yaml') 12 | args = parser.parse_args() 13 | 14 | # cfg = OmegaConf.load(args.config) 15 | # seed_meta = json.load(open(cfg[0]['params']['path'])) 16 | # seed_question = [item for item in seed_meta['questions'] if item['data_type']=='image'] 17 | # question_types = {v:k for k,v in seed_meta['question_type'].items()} 18 | res = json.load(open(args.result)) 19 | class2res = defaultdict(list) 20 | if isinstance(res[0]['predict'], int) or res[0]['predict']=='': 21 | print('evaluating options') 22 | acc = 0 23 | for item in res: 24 | index = int(item['item_id'].split('_')[-1]) 25 | # q_type = question_types[seed_question[index]['question_type_id']] 26 | if item['predict'] == item['label']: 27 | c = 1 28 | acc += 1 29 | else: 30 | c = 0 31 | # class2res[q_type].append(c) 32 | print('General accuracy: {:.4f}'.format(acc/len(res))) 33 | # for k,v in class2res.items(): 34 | # print('{} Accuracy: {:.4f}'.format(k, sum(v)/len(v))) 35 | return 36 | 37 | cfg = OmegaConf.load(args.config) 38 | ds = instantiate_from_config(cfg[0]) 39 | 40 | 41 | item2logit = defaultdict(list) 42 | item2answer = {} 43 | 44 | for item in res: 45 | item_id = int(item['item_id'].split('_')[-1]) 46 | question_id, option = ds.get_index(item_id) 47 | label = ds.getlabel(item_id) 48 | if question_id in item2answer: 49 | assert label == item2answer[question_id] 50 | else: 51 | item2answer[question_id] = label 52 | item2logit[question_id].append([item['logit'], option]) 53 | 54 | acc = 0 55 | for k in item2logit: 56 | preds = sorted(item2logit[k], key=lambda x:x[0])[0][1] 57 | print(preds, item2answer[k]) 58 | if preds == item2answer[k]: 59 | acc += 1 60 | print('accuracy: {}'.format(acc/len(item2logit))) 61 | 62 | if __name__=='__main__': 63 | main() -------------------------------------------------------------------------------- /eval/merge_benchmark.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | from utils.logger import setup_logger 4 | import json 5 | import torch 6 | 7 | def rank0_print(args, res): 8 | if args.local_rank==0 or args.local_rank == -1: 9 | print(res) 10 | 11 | def get_output_name(args, mid_output=True): 12 | if mid_output: 13 | return os.path.join(args.output_dir, 14 | '{}_rank{}.json'.format(args.dataset_name, args.local_rank)) 15 | else: 16 | return os.path.join(args.output_dir, 17 | '{}.json'.format(args.dataset_name)) 18 | 19 | def get_all_output_names(args): 20 | return [os.path.join(args.output_dir, 21 | '{}_rank{}.json'.format(args.dataset_name, r)) for r in range(args.n_gpus)] 22 | 23 | 24 | 25 | def main(): 26 | parser = ArgumentParser() 27 | parser.add_argument('--config_arg', type=str, default=None) 28 | old_args = parser.parse_args() 29 | 30 | args = torch.load(old_args.config_arg) 31 | print(args) 32 | 33 | base_config_name = os.path.basename(args.eval_data) 34 | args.dataset_name = base_config_name[:-5] if base_config_name.endswith('.yaml') else base_config_name 35 | 36 | 37 | full_res = [] 38 | for fn in get_all_output_names(args): 39 | full_res.extend(json.load(open(fn, 'r'))) 40 | os.remove(fn) 41 | with open(get_output_name(args, mid_output=False), 'w') as wf: 42 | json.dump(full_res, wf) 43 | # saving the arguments 44 | torch.save(args, get_output_name(args, mid_output=False)[:-4]+'args.bin') 45 | 46 | 47 | if __name__=='__main__': 48 | main() 
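`eval/merge_benchmark.py` above concatenates the per-rank prediction shards produced during distributed inference into a single `<dataset_name>.json` that the `eval_tools` scripts read (it additionally re-saves the evaluation arguments with `torch.save`, which is omitted here). Below is a minimal sketch of that merge step; the output directory, dataset name, and GPU count are hypothetical placeholders, and only the `<dataset_name>_rank<r>.json` naming scheme is taken from the script.

```python
import json
import os

# Assumed setup: a 2-GPU run wrote rank-sharded predictions into "outputs/".
output_dir, dataset_name, n_gpus = "outputs", "POPE_adversarial", 2

shard_paths = [
    os.path.join(output_dir, "{}_rank{}.json".format(dataset_name, r)) for r in range(n_gpus)
]

full_res = []
for fn in shard_paths:
    full_res.extend(json.load(open(fn)))  # concatenate the per-rank prediction lists
    os.remove(fn)                         # shards are deleted once merged, as in the script

with open(os.path.join(output_dir, "{}.json".format(dataset_name)), "w") as wf:
    json.dump(full_res, wf)               # single file consumed by eval/eval_tools/*.py
```

Because every record carries its own `item_id`, the merge is a plain concatenation; the metric scripts later recover the original sample index from the `item_id` suffix rather than from the record order.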
-------------------------------------------------------------------------------- /figs/model_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RupertLuo/VoCoT/d8177ca2e8e303f7d35609a000e5eff7bec319cb/figs/model_arch.png -------------------------------------------------------------------------------- /figs/sample_input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RupertLuo/VoCoT/d8177ca2e8e303f7d35609a000e5eff7bec319cb/figs/sample_input.jpg -------------------------------------------------------------------------------- /locals/datasets/image_caption/cc3m.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import json as js 3 | import math 4 | import random 5 | import os 6 | from PIL import Image 7 | from constants import * 8 | 9 | class FilteredCC3MI2TDataset(Dataset): 10 | 11 | def __init__(self, 12 | path: str, 13 | image_folder: str, 14 | meta_folder: str=None, 15 | instruct: bool = False, 16 | min_resize_res: int = 256, 17 | max_resize_res: int = 256, 18 | crop_res: int = 256, 19 | flip_prob: float = 0.5, 20 | sample_weight: float = 1.0, 21 | check: bool = False, 22 | output_mode: str = 'text', 23 | shuffle: bool = False, 24 | raw_image: bool = False, 25 | inference: bool = False, 26 | **kwargs): 27 | # load from json ../datasets/gqa-inpaint/meta_info.json 28 | self.path = path 29 | self.instruct = instruct 30 | self.inference = inference 31 | if path.endswith('txt'): 32 | self.meta = [] 33 | with open(path) as rf: 34 | for line in rf: 35 | self.meta.append(line.strip()) 36 | self.need_reload = True 37 | else: 38 | self.need_reload = False 39 | self.meta = js.load(open(path)) 40 | if meta_folder is not None: 41 | self.meta_folder = os.path.join(os.path.dirname(path), 'meta') 42 | else: 43 | self.meta_folder = meta_folder 44 | self.image_folder = image_folder 45 | self.min_resize_res = min_resize_res 46 | self.max_resize_res = max_resize_res 47 | self.crop_res = crop_res 48 | self.check = check 49 | self.raw_image = raw_image 50 | self.output_mode = output_mode 51 | self.shuffle = shuffle 52 | 53 | self.flip_prob = flip_prob 54 | self.sample_weight = sample_weight 55 | self.generation_prompts = [ 56 | "generate image with caption:", 57 | "can you give me the image with caption:", 58 | "help me to generate this image:", 59 | "generate image with according to caption:", 60 | "according to caption, generate image:", 61 | "an image with caption:", 62 | "can you visualize this caption:", 63 | ] 64 | print(f"CC3MDataset has {len(self)} samples!!") 65 | 66 | def __len__(self): 67 | return int(len(self.meta) * self.sample_weight) 68 | 69 | def __getitem__(self, i): 70 | 71 | if self.sample_weight >= 1: 72 | i = i % len(self.meta) 73 | else: 74 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight)) 75 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder) 76 | 77 | if self.need_reload: 78 | meta_fn = self.meta[i] 79 | item = js.load(open(os.path.join(self.meta_folder, meta_fn))) 80 | else: 81 | item = self.meta[i] 82 | # item = self.meta[i] 83 | tgt_img = Image.open(os.path.join(self.image_folder,item['image'])).convert('RGB') 84 | instruction = item['conversations'][1]['value'] 85 | # return image_0, image_1, instruction 86 | 87 | if self.output_mode == 'conversation': 88 | sources = [{'from': 'human', 
'value': random.choice(self.generation_prompts)+instruction}, 89 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}] 90 | return {'output_images': [tgt_img], 'conversation': sources, 'image_label_masks': [1]} 91 | elif self.output_mode == 'text': 92 | if self.shuffle: 93 | prob = random.random() 94 | if prob > 0.5: 95 | text = "{} {}".format(ALL_IMG_TOKENS_STR, instruction) 96 | image_label_masks = [0] 97 | else: 98 | text = "{} {}".format(instruction, ALL_IMG_TOKENS_STR) 99 | image_label_masks = [1] 100 | else: 101 | text = "{} {}".format(ALL_IMG_TOKENS_STR, instruction) 102 | image_label_masks = [0] 103 | if self.inference: 104 | text = ALL_IMG_TOKENS_STR 105 | return {'input_images': [tgt_img], 'text': text, 'image_label_masks': image_label_masks} -------------------------------------------------------------------------------- /locals/datasets/image_caption/coco.py: -------------------------------------------------------------------------------- 1 | from email.policy import default 2 | from torch.utils.data import Dataset 3 | import json as js 4 | import math 5 | import random 6 | import os 7 | from PIL import Image 8 | from collections import defaultdict 9 | from constants import * 10 | 11 | class COCOI2TDataset(Dataset): 12 | 13 | def __init__(self, 14 | path: str, 15 | image_folder: str, 16 | instruct: bool = False, 17 | min_resize_res: int = 256, 18 | max_resize_res: int = 256, 19 | crop_res: int = 256, 20 | flip_prob: float = 0.5, 21 | sample_weight: float = 1.0, 22 | check: bool = False, 23 | output_mode: str = 'text', 24 | raw_image: bool = False, 25 | sample_mode: str='text', 26 | shuffle: bool=False, 27 | inference: bool=False, 28 | **kwargs): 29 | # load from json ../datasets/gqa-inpaint/meta_info.json 30 | self.path = path 31 | self.instruct = instruct 32 | self.inference = inference 33 | self.meta = js.load(open(path)) 34 | if isinstance(self.meta, dict): 35 | self.meta = self.meta['annotations'] 36 | self.sample_mode = sample_mode 37 | if self.sample_mode == 'image': 38 | tmp = defaultdict(list) 39 | for item in self.meta: 40 | tmp[item['image_id']].append(item['caption']) 41 | self.meta = [{'image_id': k, 'captions': v} for k,v in tmp.items()] 42 | self.image_folder = image_folder 43 | self.min_resize_res = min_resize_res 44 | self.max_resize_res = max_resize_res 45 | self.crop_res = crop_res 46 | self.check = check 47 | self.raw_image = raw_image 48 | self.output_mode = output_mode 49 | self.shuffle = shuffle # shuffle is for interleaved image-text data 50 | 51 | self.flip_prob = flip_prob 52 | self.sample_weight = sample_weight 53 | self.generation_prompts = [ 54 | "generate image with caption:", 55 | "can you give me the image with caption:", 56 | "help me to generate this image:", 57 | "generate image with according to caption:", 58 | "according to caption, generate image:", 59 | "an image with caption:", 60 | "can you visualize this caption:", 61 | ] 62 | print(f"COCO has {len(self)} samples!!") 63 | 64 | def __len__(self): 65 | return int(len(self.meta) * self.sample_weight) 66 | 67 | def get_image_fn(self, image_id): 68 | return os.path.join(self.image_folder, 'train2014', 'COCO_train2014_{:012d}.jpg'.format(image_id)) 69 | 70 | def __getitem__(self, i): 71 | 72 | if self.sample_weight >= 1: 73 | i = i % len(self.meta) 74 | else: 75 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight)) 76 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder) 77 | 78 | item = self.meta[i] 79 | tgt_img = 
Image.open(self.get_image_fn(item['image_id'])).convert('RGB') 80 | if self.sample_mode == 'text': 81 | instruction = item['caption'] 82 | elif self.sample_mode == 'image': 83 | instruction = random.choice(item['captions']) 84 | else: 85 | raise ValueError 86 | # return image_0, image_1, instruction 87 | 88 | if self.output_mode == 'conversation': 89 | sources = [{'from': 'human', 'value': random.choice(self.generation_prompts)+instruction}, 90 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}] 91 | return {'output_images': [tgt_img], 'conversation': sources, 'image_label_masks': [1]} 92 | elif self.output_mode == 'text': 93 | if self.shuffle: 94 | prob = random.random() 95 | if prob > 0.5: 96 | text = "{} {}".format(ALL_IMG_TOKENS_STR, instruction) 97 | image_label_masks = [0] 98 | else: 99 | text = "{} {}".format(instruction, ALL_IMG_TOKENS_STR) 100 | image_label_masks = [1] 101 | else: 102 | text = "{} {}".format(ALL_IMG_TOKENS_STR, instruction) 103 | image_label_masks = [0] 104 | if self.inference: 105 | text = ALL_IMG_TOKENS_STR 106 | return {'input_images': [tgt_img], 'text': text, 'image_label_masks': image_label_masks} -------------------------------------------------------------------------------- /locals/datasets/image_caption/flicker30k.py: -------------------------------------------------------------------------------- 1 | from email.policy import default 2 | from torch.utils.data import Dataset 3 | import json as js 4 | import math 5 | import random 6 | import os 7 | from PIL import Image 8 | from collections import defaultdict 9 | from constants import * 10 | from ..utils.box_utils import box2str, reshape_box 11 | 12 | class FlickrDataset(Dataset): 13 | 14 | def __init__(self, 15 | path: str, 16 | image_folder: str, 17 | instruct: bool = False, 18 | sample_weight: float = 1.0, 19 | check: bool = False, 20 | output_mode: str = 'text', 21 | raw_image: bool = False, 22 | shuffle: bool=False, 23 | avoid_image_gen: bool=False, 24 | phrase_format: str='special_tokens', 25 | phrase_prec: int=2, 26 | expand2square: bool=False, 27 | object_format: str='image', 28 | phrase_space: bool=False, 29 | **kwargs): 30 | # load from json ../datasets/gqa-inpaint/meta_info.json 31 | self.path = path 32 | self.instruct = instruct 33 | self.meta = [js.loads(line) for line in open(self.path)] 34 | self.image_folder = image_folder 35 | self.raw_image = raw_image 36 | self.object_format = object_format 37 | self.output_mode = output_mode 38 | self.avoid_image_gen = avoid_image_gen 39 | self.phrase_format = phrase_format 40 | self.phrase_prec = phrase_prec 41 | self.shuffle = shuffle # shuffle is for interleaved image-text data 42 | self.expand2square = expand2square 43 | self.phrase_space = phrase_space 44 | 45 | self.sample_weight = sample_weight 46 | print(f"Flickr30k entities has {len(self)} samples!!") 47 | 48 | def __len__(self): 49 | return int(len(self.meta) * self.sample_weight) 50 | 51 | def proc_box(self, box, image): 52 | w, h = image.size 53 | x_min, y_min, x_max, y_max = box 54 | sub_image = image.crop((x_min, y_min, x_max, y_max)) 55 | new_box = [c / w if (i%2==0) else c / h for i,c in enumerate(box)] 56 | if self.expand2square: 57 | new_box = reshape_box(image, new_box) 58 | return new_box, sub_image 59 | 60 | def __len__(self): 61 | return len(self.meta) 62 | 63 | 64 | def __getitem__(self, i): 65 | 66 | if self.sample_weight >= 1: 67 | i = i % len(self.meta) 68 | else: 69 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight)) 70 | i = int(i / 
self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder) 71 | 72 | item = self.meta[i] 73 | image_id = item['image_id'] 74 | image = Image.open(os.path.join(self.image_folder, '{}.jpg'.format(image_id))).convert('RGB') 75 | caption = item['sentence'] 76 | 77 | box_infos = [] 78 | sub_image_infos = [] 79 | for b in item['boxes']: 80 | current_reshape_infos = self.proc_box(b, image) 81 | box_infos.append(current_reshape_infos[0]) 82 | sub_image_infos.append(current_reshape_infos[1]) 83 | all_box_input = [[0.0, 0.0, 1.0, 1.0]] 84 | image_label_masks = [0] 85 | sub_image = [] 86 | input_images = [image] 87 | 88 | # start from 89 | input_query = item['sentence'].replace('', '').split('') 90 | refined_query = ALL_IMG_TOKENS_STR + DEFAULT_GRD_TOKEN + ' ' 91 | for box_ind, current_box in enumerate(item['boxes_seq']): 92 | sub_image.extend(sub_image_infos[c] for c in current_box) 93 | current_box = [box_infos[c] for c in current_box] 94 | all_box_input.extend(current_box) 95 | image_label_masks.extend([0]*len(current_box)) 96 | if self.object_format == 'coordinate': 97 | box_in_str = DEFAULT_SEP_TOKEN.join([DEFAULT_BOC_TOKEN + box2str(c, mode=self.phrase_format, prec=self.phrase_prec, space=self.phrase_space) + DEFAULT_EOC_TOKEN for c in current_box]) 98 | else: 99 | box_in_str = DEFAULT_SEP_TOKEN.join([DEFAULT_BOC_TOKEN + box2str(c, mode=self.phrase_format, prec=self.phrase_prec, space=self.phrase_space) + DEFAULT_EOC_TOKEN + ALL_IMG_TOKENS_STR for c in current_box]) 100 | refined_query = refined_query + input_query[box_ind] + box_in_str 101 | refined_query = refined_query + input_query[-1] 102 | 103 | if self.output_mode == 'conversation': 104 | query = ALL_IMG_TOKENS_STR + DEFAULT_GRD_TOKEN + '\n' + 'Briefly describe this image. Locate objects and provide bounding boxes in your response.' 
105 | response = refined_query.replace(ALL_IMG_TOKENS_STR + DEFAULT_GRD_TOKEN + ' ', '') 106 | conversation = [{'from': 'human', 'value': query}, {'from': 'gpt', 'value': response}] 107 | return {'input_images': input_images, 'conversation': conversation, 'image_label_masks': image_label_masks, 'box': all_box_input} 108 | raise ValueError 109 | elif self.output_mode == 'text': 110 | # print(instruction) 111 | if self.object_format == 'image': 112 | return {'input_images': input_images + sub_image, 'text': refined_query, 'image_label_masks': image_label_masks} 113 | elif self.object_format == 'representation': 114 | return {'input_images': input_images, 115 | 'text': refined_query, 116 | 'image_label_masks': image_label_masks, 117 | 'box': all_box_input} 118 | elif self.object_format == 'coordinate': 119 | return {'input_images': input_images, 120 | 'text': refined_query, 121 | 'image_label_masks': image_label_masks, 122 | 'box': all_box_input[:1]} 123 | else: 124 | raise ValueError -------------------------------------------------------------------------------- /locals/datasets/image_caption/fusecap.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import json as js 3 | import math 4 | import random 5 | import os 6 | from PIL import Image 7 | from constants import * 8 | import tqdm 9 | 10 | class FuseCapCC3MI2TDataset(Dataset): 11 | 12 | def __init__(self, 13 | path: str, 14 | image_folder: str, 15 | instruct: bool = False, 16 | min_resize_res: int = 256, 17 | max_resize_res: int = 256, 18 | crop_res: int = 256, 19 | flip_prob: float = 0.5, 20 | sample_weight: float = 1.0, 21 | check: bool = False, 22 | output_mode: str = 'text', 23 | raw_image: bool = False, 24 | **kwargs): 25 | # load from json ../datasets/gqa-inpaint/meta_info.json 26 | self.path = path 27 | self.instruct = instruct 28 | self.meta = js.load(open(path)) 29 | self.image_folder = image_folder 30 | self.min_resize_res = min_resize_res 31 | self.max_resize_res = max_resize_res 32 | self.crop_res = crop_res 33 | self.check = check 34 | self.raw_image = raw_image 35 | self.output_mode = output_mode 36 | 37 | # check the image existence 38 | valid_items = [] 39 | for item in tqdm.tqdm(self.meta): 40 | if os.path.exists(self.get_image_name(item['image_id'])): 41 | valid_items.append(item) 42 | self.meta = valid_items 43 | 44 | self.flip_prob = flip_prob 45 | self.sample_weight = sample_weight 46 | self.generation_prompts = [ 47 | "generate image with caption:", 48 | "can you give me the image with caption:", 49 | "help me to generate this image:", 50 | "generate image with according to caption:", 51 | "according to caption, generate image:", 52 | "an image with caption:", 53 | "can you visualize this caption:", 54 | ] 55 | print(f"CC3MDataset has {len(self)} samples!!") 56 | 57 | def __len__(self): 58 | return int(len(self.meta) * self.sample_weight) 59 | 60 | def get_image_name(self, image_id): 61 | return os.path.join(self.image_folder, 'GCC_train_{:09d}.jpg'.format(image_id)) 62 | 63 | def __getitem__(self, i): 64 | 65 | if self.sample_weight >= 1: 66 | i = i % len(self.meta) 67 | else: 68 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight)) 69 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder) 70 | 71 | item = self.meta[i] 72 | tgt_img = Image.open(os.path.join(self.image_folder,item['image'])).convert('RGB') 73 | instruction = item['conversations'][1]['value'] 74 | # return 
image_0, image_1, instruction 75 | 76 | if self.output_mode == 'conversation': 77 | sources = [{'from': 'human', 'value': random.choice(self.generation_prompts)+instruction}, 78 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}] 79 | return {'output_images': [tgt_img], 'conversation': sources, 'image_label_masks': [1]} 80 | elif self.output_mode == 'text': 81 | return {'input_images': [tgt_img], 'text': "{} {}".format(ALL_IMG_TOKENS_STR, instruction), 'image_label_masks': [0]} 82 | 83 | class FuseCapCC12MI2TDataset(FuseCapCC3MI2TDataset): 84 | def get_image_name(self, image_id): 85 | image_fn = '{:08d}.jpg'.format(image_id) 86 | return os.path.join(self.image_folder, '{}'.format(int(image_fn[:2])), image_fn) -------------------------------------------------------------------------------- /locals/datasets/image_caption/lvis_gpt4v.py: -------------------------------------------------------------------------------- 1 | from email.policy import default 2 | from torch.utils.data import Dataset 3 | import json as js 4 | import math 5 | import random 6 | import os 7 | from PIL import Image 8 | from collections import defaultdict 9 | from constants import * 10 | 11 | def load_prompt(fn): 12 | prompts = [] 13 | with open(fn) as f: 14 | line=f.readline() 15 | while line: 16 | line=line.strip('\n') 17 | prompts.append(line) 18 | line=f.readline() 19 | return prompts 20 | 21 | class LvisCapDataset(Dataset): 22 | def __init__(self, 23 | path: str, 24 | image_folder: str, 25 | instruct: bool = False, 26 | min_resize_res: int = 256, 27 | max_resize_res: int = 256, 28 | crop_res: int = 256, 29 | flip_prob: float = 0.5, 30 | sample_weight: float = 1.0, 31 | check: bool = False, 32 | output_mode: str = 'conversation', 33 | raw_image: bool = False, 34 | shuffle: bool=False, 35 | shuffle_prob: float=0.5, 36 | inference: bool=False, 37 | caption_type: str='long', 38 | task_type: str='i2t', 39 | **kwargs): 40 | # load from json ../datasets/gqa-inpaint/meta_info.json 41 | self.path = path 42 | self.instruct = instruct 43 | self.inference = inference 44 | self.meta = js.load(open(path)) 45 | self.image_folder = image_folder 46 | self.min_resize_res = min_resize_res 47 | self.max_resize_res = max_resize_res 48 | self.crop_res = crop_res 49 | self.check = check 50 | self.raw_image = raw_image 51 | self.output_mode = output_mode 52 | self.shuffle = shuffle # shuffle is for interleaved image-text data 53 | self.shuffle_prob = shuffle_prob # the probability for shuffle 54 | 55 | self.flip_prob = flip_prob 56 | self.task_type = task_type 57 | self.sample_weight = sample_weight 58 | self.caption_type = caption_type 59 | self.i2t_long_prompts = load_prompt('locals/datasets/prompts/prompt_captioning_long.txt') 60 | self.i2t_short_prompts = load_prompt('locals/datasets/prompts/prompt_captioning_short.txt') 61 | self.t2i_prompts = load_prompt('locals/datasets/prompts/prompt_txt2img.txt') 62 | print(f"LVIS-GPT4V-Captions has {len(self)} samples!!") 63 | 64 | def __len__(self): 65 | return int(len(self.meta) * self.sample_weight) 66 | 67 | def get_image_fn(self, image_id): 68 | return os.path.join(self.image_folder, image_id) 69 | 70 | def __getitem__(self, i): 71 | 72 | if self.sample_weight >= 1: 73 | i = i % len(self.meta) 74 | else: 75 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight)) 76 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder) 77 | 78 | item = self.meta[i] 79 | tgt_img = 
Image.open(self.get_image_fn(item['image'])).convert('RGB') 80 | 81 | # find the caption 82 | if self.caption_type == 'random': 83 | prob = random.random() 84 | if prob > 0.5: 85 | caption_type = 'long' 86 | else: 87 | caption_type = 'short' 88 | else: 89 | caption_type = self.caption_type 90 | if caption_type == 'long': 91 | caption = item['caption'] 92 | prompt = random.choice(self.i2t_long_prompts) 93 | elif caption_type == 'short': 94 | caption = item['short_caption'] 95 | prompt = random.choice(self.i2t_short_prompts) 96 | else: 97 | raise ValueError 98 | 99 | if self.output_mode == 'conversation': 100 | if self.shuffle: 101 | prob = random.random() 102 | if prob > self.shuffle_prob: 103 | task_type = 'i2t' 104 | else: 105 | task_type = 't2i' 106 | else: 107 | task_type = self.task_type 108 | if task_type == 'i2t': 109 | instruction = '{} {}'.format(ALL_IMG_TOKENS_STR, prompt) 110 | response = caption 111 | label_masks = [0] 112 | elif task_type == 't2i': 113 | prompt = random.choice(self.t2i_prompts) 114 | instruction = '{} {}'.format(prompt, caption) 115 | response = ALL_IMG_TOKENS_STR 116 | label_masks = [1] 117 | else: 118 | raise ValueError 119 | sources = [{'from': 'human', 'value': instruction}, 120 | {'from': 'gpt', 'value': response}] 121 | return {'input_images': [tgt_img], 'conversation': sources, 'image_label_masks': label_masks} 122 | elif self.output_mode == 'text': 123 | # raise ValueError 124 | if self.shuffle: 125 | prob = random.random() 126 | if prob > self.shuffle_prob: 127 | task_type = 'i2t' 128 | else: 129 | task_type = 't2i' 130 | else: 131 | task_type = self.task_type 132 | if task_type == 'i2t': 133 | text = "{} {}".format(ALL_IMG_TOKENS_STR, caption) 134 | image_label_masks = [0] 135 | elif task_type == 't2i': 136 | text = "{} {}".format(caption, ALL_IMG_TOKENS_STR) 137 | image_label_masks = [1] 138 | else: 139 | raise ValueError 140 | if self.inference: 141 | text = ALL_IMG_TOKENS_STR 142 | return {'input_images': [tgt_img], 'text': text, 'image_label_masks': image_label_masks} -------------------------------------------------------------------------------- /locals/datasets/image_caption/mmc4.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import json as js 3 | import math 4 | import random 5 | import os 6 | from PIL import Image 7 | from constants import * 8 | import time 9 | import os 10 | from pathlib import Path 11 | class FilteredMMC4Dataset(Dataset): 12 | 13 | def __init__(self, 14 | path: str, 15 | image_folder: str = '', 16 | instruct: bool = False, 17 | min_resize_res: int = 256, 18 | max_resize_res: int = 256, 19 | crop_res: int = 256, 20 | flip_prob: float = 0.5, 21 | sample_weight: float = 1.0, 22 | check: bool = False, 23 | output_mode: str = 'text', 24 | raw_image: bool = False, 25 | meta_folder: str = None, 26 | expand2square: bool = False, 27 | avoid_image_gen: bool=False, 28 | **kwargs): 29 | 30 | time_start = time.time() 31 | self.path = path 32 | self.instruct = instruct 33 | self.meta_folder = meta_folder 34 | self.expand2square = expand2square 35 | self.avoid_image_gen = avoid_image_gen 36 | 37 | if os.path.isdir(path): 38 | self.meta = list(Path(path).rglob('*.json')) 39 | self.need_reload = False 40 | elif os.path.isfile(path) and path.endswith('jsonl'): 41 | self.meta_file = open(path) 42 | self.meta = [] 43 | _ = 0 44 | for line in self.meta_file: 45 | self.meta.append(js.loads(line)) 46 | _ +=1 47 | self.need_reload = False 48 | elif
os.path.isfile(path) and path.endswith('txt'): 49 | self.meta = [] 50 | with open(path) as rf: 51 | for line in rf: 52 | self.meta.append(line.strip()) 53 | self.need_reload = True 54 | assert self.meta_folder is not None 55 | else: 56 | raise ValueError('Invalid data path') 57 | # self.meta_file = open(path) 58 | # self.meta = [] 59 | # _ = 0 60 | # for line in self.meta_file: 61 | # self.meta.append(js.loads(line)) 62 | # _ +=1 63 | 64 | #### <<<<< Delete follow code for real training >>>> #### 65 | 66 | # if _ > 5000: 67 | # break 68 | 69 | ######################################################### 70 | 71 | self.image_folder = image_folder 72 | self.min_resize_res = min_resize_res 73 | self.max_resize_res = max_resize_res 74 | self.crop_res = crop_res 75 | self.check = check 76 | self.raw_image = raw_image 77 | self.output_mode = output_mode 78 | 79 | self.flip_prob = flip_prob 80 | self.sample_weight = sample_weight 81 | time_end = time.time() 82 | print(f"MMC4 Dataset has {len(self)} samples!!, initialized with {time_end-time_start}") 83 | 84 | def __len__(self): 85 | return int(len(self.meta) * self.sample_weight) 86 | 87 | def __getitem__(self, i): 88 | 89 | if self.sample_weight >= 1: 90 | i = i % len(self.meta) 91 | else: 92 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight)) 93 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder) 94 | 95 | item = self.meta[i] 96 | # if item is a path 97 | if isinstance(item,Path): 98 | item = js.load(open(item)) 99 | elif self.need_reload: 100 | item = js.load(open(os.path.join(self.meta_folder, item))) 101 | 102 | tgt_img_list = [] 103 | img_txt_match_index = [] 104 | for img_info in item['image_info']: 105 | tgt_img_list.append(Image.open(os.path.join(self.image_folder,img_info['img_path'])).convert('RGB')) 106 | img_txt_match_index.append(img_info['matched_text_index']) 107 | text_list = [txt for txt in item['text_list']] 108 | img_label_masks = [] 109 | for i,match_index in enumerate(img_txt_match_index): 110 | thr = random.random() 111 | if i == 0 and thr <= 0.5: 112 | img_label_masks.append(0) 113 | else: 114 | if self.avoid_image_gen: 115 | img_label_masks.append(0) 116 | else: 117 | img_label_masks.append(1) 118 | text_list[match_index] = "{} {}".format(text_list[match_index], ALL_IMG_TOKENS_STR) if thr > 0.5 else "{} {}".format(ALL_IMG_TOKENS_STR, text_list[match_index]) 119 | # return image_0, image_1, instruction 120 | 121 | assert self.output_mode == 'text' 122 | return {'input_images': tgt_img_list, 'text': ' '.join(text_list) , 'image_label_masks': img_label_masks} 123 | 124 | 125 | if __name__ == "__main__": 126 | dataset = FilteredMMC4Dataset('/mnt/bn/luoruipu-disk/meta_data/mmc4/filter_mmc4_meta_with_img_abs_path.jsonl') 127 | print(repr(dataset[1000])) -------------------------------------------------------------------------------- /locals/datasets/image_edit/low_level/clwd.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InstructDiffusion 3 | # Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix) 4 | # Modified by Chen Li (edward82@stu.xjtu.edu.cn) 5 | # -------------------------------------------------------- 6 | 7 | import os 8 | import numpy as np 9 | from torch.utils.data import Dataset 10 | import torch 11 | from PIL import Image 12 | import torchvision.transforms.functional as TF 13 | from pdb import set_trace as stx 14 | 
import random 15 | import cv2 16 | from PIL import Image 17 | import torchvision 18 | from constants import ALL_IMG_TOKENS_STR 19 | 20 | 21 | def is_image_file(filename): 22 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif']) 23 | 24 | 25 | class CLWD(Dataset): 26 | def __init__(self, path, split="train", size=256, interpolation="pil_lanczos", check=False, raw_image=False, 27 | flip_prob=0.5, sample_weight=1.0, instruct=False, prompt_path='datasets/image_edit/prompt/prompt_dewatermark.txt'): 28 | super(CLWD, self).__init__() 29 | 30 | inp_files = sorted(os.listdir(os.path.join(path, split, 'Watermarked_image'))) 31 | tar_files = sorted(os.listdir(os.path.join(path, split, 'Watermark_free_image'))) 32 | 33 | self.inp_filenames = [os.path.join(path, split, 'Watermarked_image', x) for x in inp_files if is_image_file(x)] 34 | self.tar_filenames = [os.path.join(path, split, 'Watermark_free_image', x) for x in tar_files if is_image_file(x)] 35 | 36 | self.size = size 37 | self.flip_prob = flip_prob 38 | self.sample_weight = sample_weight 39 | self.instruct = instruct 40 | self.check = check 41 | self.raw_image = raw_image 42 | self.sizex = len(self.tar_filenames) # get the size of target 43 | 44 | self.interpolation = { 45 | "cv_nearest": cv2.INTER_NEAREST, 46 | "cv_bilinear": cv2.INTER_LINEAR, 47 | "cv_bicubic": cv2.INTER_CUBIC, 48 | "cv_area": cv2.INTER_AREA, 49 | "cv_lanczos": cv2.INTER_LANCZOS4, 50 | "pil_nearest": Image.NEAREST, 51 | "pil_bilinear": Image.BILINEAR, 52 | "pil_bicubic": Image.BICUBIC, 53 | "pil_box": Image.BOX, 54 | "pil_hamming": Image.HAMMING, 55 | "pil_lanczos": Image.LANCZOS, 56 | }[interpolation] 57 | 58 | # prompt_path='dataset/prompt/prompt_dewatermark.txt' 59 | self.prompt_list=[] 60 | with open(prompt_path) as f: 61 | line=f.readline() 62 | while line: 63 | line=line.strip('\n') 64 | self.prompt_list.append(line) 65 | line=f.readline() 66 | 67 | print(f"CLWD has {len(self)} samples!!") 68 | 69 | def __len__(self): 70 | return int(self.sizex * self.sample_weight) 71 | 72 | def __getitem__(self, index): 73 | if self.raw_image: 74 | return self.get_raw_image(index) 75 | else: 76 | return self.get_processed_image(index) 77 | 78 | def get_processed_image(self, index): 79 | if self.sample_weight >= 1: 80 | index_ = index % self.sizex 81 | else: 82 | index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1) 83 | 84 | inp_path = self.inp_filenames[index_] 85 | tar_path = self.tar_filenames[index_] 86 | 87 | inp_img = Image.open(inp_path) 88 | tar_img = Image.open(tar_path) 89 | 90 | width, height = inp_img.size 91 | tar_width, tar_height = tar_img.size 92 | assert tar_width == width and tar_height == height, "Input and target image mismatch" 93 | aspect_ratio = float(width) / float(height) 94 | if width < height: 95 | new_width = self.size 96 | new_height = int(self.size / aspect_ratio) 97 | else: 98 | new_height = self.size 99 | new_width = int(self.size * aspect_ratio) 100 | inp_img = inp_img.resize((new_width, new_height), self.interpolation) 101 | tar_img = tar_img.resize((new_width, new_height), self.interpolation) 102 | 103 | inp_img = np.array(inp_img).astype(np.float32).transpose(2, 0, 1) 104 | inp_img_tensor = torch.tensor((inp_img / 127.5 - 1.0).astype(np.float32)) 105 | tar_img = np.array(tar_img).astype(np.float32).transpose(2, 0, 1) 106 | tar_img_tensor = torch.tensor((tar_img / 127.5 - 1.0).astype(np.float32)) 107 | crop = torchvision.transforms.RandomCrop(self.size) 
108 | flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob)) 109 | image_0, image_1 = flip(crop(torch.cat((inp_img_tensor, tar_img_tensor)))).chunk(2) 110 | 111 | prompt = random.choice(self.prompt_list) 112 | if self.instruct: 113 | prompt = "Watermark Removal: " + prompt 114 | 115 | if self.check: 116 | return dict(edited=Image.open(tar_path), edit=dict(source=Image.open(inp_path), instruction=prompt)) 117 | return dict(edited=image_1, edit=dict(source=image_0, instruction=prompt)) 118 | 119 | def get_raw_image(self, index): 120 | if self.sample_weight >= 1: 121 | index_ = index % self.sizex 122 | else: 123 | index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1) 124 | 125 | inp_path = self.inp_filenames[index_] 126 | tar_path = self.tar_filenames[index_] 127 | 128 | inp_img = Image.open(inp_path).convert('RGB') 129 | tar_img = Image.open(tar_path).convert('RGB') 130 | 131 | width, height = inp_img.size 132 | tar_width, tar_height = tar_img.size 133 | assert tar_width == width and tar_height == height, "Input and target image mismatch" 134 | 135 | prompt = random.choice(self.prompt_list) 136 | if self.instruct: 137 | prompt = "Watermark Removal: " + prompt 138 | 139 | sources = [{'from': 'human', 'value': '{}: .'.format(prompt)}, 140 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}] 141 | return {'input_images': [inp_img], 'output_images': [tar_img], 'output_cond_images': [inp_img], 'conversation': sources, 'image_label_masks': [0,1]} -------------------------------------------------------------------------------- /locals/datasets/image_edit/low_level/gopro.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InstructDiffusion 3 | # Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix) 4 | # Modified by Chen Li (edward82@stu.xjtu.edu.cn) 5 | # -------------------------------------------------------- 6 | 7 | import os 8 | import numpy as np 9 | from torch.utils.data import Dataset 10 | import torch 11 | from PIL import Image 12 | import torchvision.transforms.functional as TF 13 | from pdb import set_trace as stx 14 | import random 15 | import cv2 16 | from PIL import Image 17 | import torchvision 18 | from constants import ALL_IMG_TOKENS_STR 19 | 20 | 21 | def is_image_file(filename): 22 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif']) 23 | 24 | 25 | class GoPro(Dataset): 26 | def __init__(self, path, split="train", size=256, interpolation="pil_lanczos", check=False, raw_image=False, 27 | flip_prob=0.5, sample_weight=1.0, instruct=False, prompt_path='datasets/image_edit/prompt/prompt_deblur.txt'): 28 | super(GoPro, self).__init__() 29 | 30 | # inp_files = sorted(os.listdir(os.path.join(path, split, 'input'))) 31 | # tar_files = sorted(os.listdir(os.path.join(path, split, 'target'))) 32 | 33 | # self.inp_filenames = [os.path.join(path, split, 'input', x) for x in inp_files if is_image_file(x)] 34 | # self.tar_filenames = [os.path.join(path, split, 'target', x) for x in tar_files if is_image_file(x)] 35 | 36 | filenames_splits = sorted(os.listdir(os.path.join(path, split))) 37 | 38 | self.inp_filenames = [os.path.join(path, split, d, 'blur', x) for d in filenames_splits for x in sorted(os.listdir(os.path.join(path, split, d, 'blur'))) if is_image_file(x)] 39 | self.tar_filenames = [os.path.join(path, split, d, 'sharp', x) for d in 
filenames_splits for x in sorted(os.listdir(os.path.join(path, split, d, 'sharp'))) if is_image_file(x)] 40 | 41 | self.size = size 42 | self.flip_prob = flip_prob 43 | self.sample_weight = sample_weight 44 | self.instruct = instruct 45 | self.check = check 46 | self.raw_image = raw_image 47 | self.sizex = len(self.tar_filenames) # get the size of target 48 | 49 | self.interpolation = { 50 | "cv_nearest": cv2.INTER_NEAREST, 51 | "cv_bilinear": cv2.INTER_LINEAR, 52 | "cv_bicubic": cv2.INTER_CUBIC, 53 | "cv_area": cv2.INTER_AREA, 54 | "cv_lanczos": cv2.INTER_LANCZOS4, 55 | "pil_nearest": Image.NEAREST, 56 | "pil_bilinear": Image.BILINEAR, 57 | "pil_bicubic": Image.BICUBIC, 58 | "pil_box": Image.BOX, 59 | "pil_hamming": Image.HAMMING, 60 | "pil_lanczos": Image.LANCZOS, 61 | }[interpolation] 62 | 63 | # prompt_path='dataset/prompt/prompt_deblur.txt' 64 | self.prompt_list=[] 65 | with open(prompt_path) as f: 66 | line=f.readline() 67 | while line: 68 | line=line.strip('\n') 69 | self.prompt_list.append(line) 70 | line=f.readline() 71 | 72 | print(f"GoPro has {len(self)} samples!!") 73 | 74 | def __len__(self): 75 | return int(self.sizex * self.sample_weight) 76 | 77 | def __getitem__(self, index): 78 | if self.raw_image: 79 | return self.get_raw_image(index) 80 | else: 81 | return self.get_processed_image(index) 82 | 83 | def get_processed_image(self, index): 84 | if self.sample_weight >= 1: 85 | index_ = index % self.sizex 86 | else: 87 | index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1) 88 | 89 | inp_path = self.inp_filenames[index_] 90 | tar_path = self.tar_filenames[index_] 91 | 92 | inp_img = Image.open(inp_path) 93 | tar_img = Image.open(tar_path) 94 | 95 | width, height = inp_img.size 96 | tar_width, tar_height = tar_img.size 97 | assert tar_width == width and tar_height == height, "Input and target image mismatch" 98 | aspect_ratio = float(width) / float(height) 99 | if width < height: 100 | new_width = self.size 101 | new_height = int(self.size / aspect_ratio) 102 | else: 103 | new_height = self.size 104 | new_width = int(self.size * aspect_ratio) 105 | inp_img = inp_img.resize((new_width, new_height), self.interpolation) 106 | tar_img = tar_img.resize((new_width, new_height), self.interpolation) 107 | 108 | inp_img = np.array(inp_img).astype(np.float32).transpose(2, 0, 1) 109 | inp_img_tensor = torch.tensor((inp_img / 127.5 - 1.0).astype(np.float32)) 110 | tar_img = np.array(tar_img).astype(np.float32).transpose(2, 0, 1) 111 | tar_img_tensor = torch.tensor((tar_img / 127.5 - 1.0).astype(np.float32)) 112 | crop = torchvision.transforms.RandomCrop(self.size) 113 | flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob)) 114 | image_0, image_1 = flip(crop(torch.cat((inp_img_tensor, tar_img_tensor)))).chunk(2) 115 | 116 | prompt = random.choice(self.prompt_list) 117 | if self.instruct: 118 | prompt = "Image Deblurring: " + prompt 119 | 120 | if self.check: 121 | return dict(edited=Image.open(tar_path), edit=dict(source=Image.open(inp_path), instruction=prompt)) 122 | return dict(edited=image_1, edit=dict(source=image_0, instruction=prompt)) 123 | 124 | def get_raw_image(self, index): 125 | if self.sample_weight >= 1: 126 | index_ = index % self.sizex 127 | else: 128 | index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1) 129 | 130 | inp_path = self.inp_filenames[index_] 131 | tar_path = self.tar_filenames[index_] 132 | 133 | inp_img = Image.open(inp_path).convert('RGB') 134 | tar_img = 
Image.open(tar_path).convert('RGB') 135 | 136 | width, height = inp_img.size 137 | tar_width, tar_height = tar_img.size 138 | assert tar_width == width and tar_height == height, "Input and target image mismatch" 139 | 140 | prompt = random.choice(self.prompt_list) 141 | if self.instruct: 142 | prompt = "Image Deblurring: " + prompt 143 | 144 | sources = [{'from': 'human', 'value': '{}: .'.format(prompt)}, 145 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}] 146 | return {'input_images': [inp_img], 'output_images': [tar_img], 'output_cond_images': [inp_img], 'conversation': sources, 'image_label_masks': [0, 1]} -------------------------------------------------------------------------------- /locals/datasets/image_edit/low_level/reds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import torch 5 | from PIL import Image 6 | import torchvision.transforms.functional as TF 7 | from pdb import set_trace as stx 8 | import random 9 | import cv2 10 | from PIL import Image 11 | import torchvision 12 | from constants import ALL_IMG_TOKENS_STR 13 | 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif']) 17 | 18 | 19 | class REDS(Dataset): 20 | def __init__(self, path, split="train", size=256, interpolation="pil_lanczos", check=False, raw_image=False, 21 | flip_prob=0.5, sample_weight=1.0, instruct=False, prompt_path='datasets/image_edit/prompt/prompt_deblur.txt'): 22 | super(REDS, self).__init__() 23 | 24 | inp_files = sorted(os.listdir(os.path.join(path, split, 'train_blur'))) 25 | tar_files = sorted(os.listdir(os.path.join(path, split, 'train_sharp'))) 26 | 27 | if split == "train": 28 | self.inp_filenames = [os.path.join(path, split, 'train_blur', d, x) for d in inp_files for x in sorted(os.listdir(os.path.join(path, split, 'train_blur', d))) if is_image_file(x)] 29 | self.tar_filenames = [os.path.join(path, split, 'train_sharp', d, x) for d in tar_files for x in sorted(os.listdir(os.path.join(path, split, 'train_sharp', d))) if is_image_file(x)] 30 | else: 31 | self.inp_filenames = [os.path.join(path, split, 'blur', x) for x in inp_files if is_image_file(x)] 32 | self.tar_filenames = [os.path.join(path, split, 'sharp', x) for x in tar_files if is_image_file(x)] 33 | 34 | self.size = size 35 | self.flip_prob = flip_prob 36 | self.sample_weight = sample_weight 37 | self.instruct = instruct 38 | self.raw_image = raw_image 39 | assert len(self.inp_filenames) == len(self.tar_filenames) 40 | self.check = check 41 | self.sizex = len(self.tar_filenames) # get the size of target 42 | 43 | self.interpolation = { 44 | "cv_nearest": cv2.INTER_NEAREST, 45 | "cv_bilinear": cv2.INTER_LINEAR, 46 | "cv_bicubic": cv2.INTER_CUBIC, 47 | "cv_area": cv2.INTER_AREA, 48 | "cv_lanczos": cv2.INTER_LANCZOS4, 49 | "pil_nearest": Image.NEAREST, 50 | "pil_bilinear": Image.BILINEAR, 51 | "pil_bicubic": Image.BICUBIC, 52 | "pil_box": Image.BOX, 53 | "pil_hamming": Image.HAMMING, 54 | "pil_lanczos": Image.LANCZOS, 55 | }[interpolation] 56 | 57 | self.prompt_list=[] 58 | with open(prompt_path) as f: 59 | line=f.readline() 60 | while line: 61 | line=line.strip('\n') 62 | self.prompt_list.append(line) 63 | line=f.readline() 64 | 65 | print(f"REDS has {len(self)} samples!!") 66 | 67 | def __len__(self): 68 | return int(self.sizex * self.sample_weight) 69 | 70 | def __getitem__(self, index): 71 | if 
self.raw_image: 72 | return self.get_raw_image(index) 73 | else: 74 | return self.get_processed_image(index) 75 | 76 | def get_processed_image(self, index): 77 | if self.sample_weight >= 1: 78 | index_ = index % self.sizex 79 | else: 80 | index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1) 81 | 82 | inp_path = self.inp_filenames[index_] 83 | tar_path = self.tar_filenames[index_] 84 | 85 | inp_img = Image.open(inp_path) 86 | tar_img = Image.open(tar_path) 87 | 88 | width, height = inp_img.size 89 | tar_width, tar_height = tar_img.size 90 | assert tar_width == width and tar_height == height, "Input and target image mismatch" 91 | aspect_ratio = float(width) / float(height) 92 | if width < height: 93 | new_width = self.size 94 | new_height = int(self.size / aspect_ratio) 95 | else: 96 | new_height = self.size 97 | new_width = int(self.size * aspect_ratio) 98 | inp_img = inp_img.resize((new_width, new_height), self.interpolation) 99 | tar_img = tar_img.resize((new_width, new_height), self.interpolation) 100 | 101 | inp_img = np.array(inp_img).astype(np.float32).transpose(2, 0, 1) 102 | inp_img_tensor = torch.tensor((inp_img / 127.5 - 1.0).astype(np.float32)) 103 | tar_img = np.array(tar_img).astype(np.float32).transpose(2, 0, 1) 104 | tar_img_tensor = torch.tensor((tar_img / 127.5 - 1.0).astype(np.float32)) 105 | crop = torchvision.transforms.RandomCrop(self.size) 106 | flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob)) 107 | image_0, image_1 = flip(crop(torch.cat((inp_img_tensor, tar_img_tensor)))).chunk(2) 108 | 109 | prompt = random.choice(self.prompt_list) 110 | if self.instruct: 111 | prompt = "Image Deblurring: " + prompt 112 | 113 | if self.check: 114 | return dict(edited=Image.open(tar_path), edit=dict(source=Image.open(inp_path), instruction=prompt)) 115 | return dict(edited=image_1, edit=dict(source=image_0, instruction=prompt)) 116 | 117 | def get_raw_image(self, index): 118 | if self.sample_weight >= 1: 119 | index_ = index % self.sizex 120 | else: 121 | index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1) 122 | 123 | inp_path = self.inp_filenames[index_] 124 | tar_path = self.tar_filenames[index_] 125 | 126 | inp_img = Image.open(inp_path).convert('RGB') 127 | tar_img = Image.open(tar_path).convert('RGB') 128 | 129 | width, height = inp_img.size 130 | tar_width, tar_height = tar_img.size 131 | assert tar_width == width and tar_height == height, "Input and target image mismatch" 132 | 133 | prompt = random.choice(self.prompt_list) 134 | if self.instruct: 135 | prompt = "Image Deblurring: " + prompt 136 | 137 | sources = [{'from': 'human', 'value': '{}: .'.format(prompt)}, 138 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}] 139 | return {'input_images': [inp_img], 'output_images': [tar_img], 'output_cond_images': [inp_img], 'conversation': sources, 'image_label_masks': [0, 1]} -------------------------------------------------------------------------------- /locals/datasets/image_edit/low_level/sidd.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InstructDiffusion 3 | # Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix) 4 | # Modified by Chen Li (edward82@stu.xjtu.edu.cn) 5 | # -------------------------------------------------------- 6 | 7 | import os 8 | import numpy as np 9 | from torch.utils.data import Dataset 10 | import torch 11 
| from PIL import Image 12 | import torchvision.transforms.functional as TF 13 | from pdb import set_trace as stx 14 | import random 15 | import cv2 16 | from PIL import Image 17 | import torchvision 18 | from constants import ALL_IMG_TOKENS_STR 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif']) 22 | 23 | 24 | class SIDD(Dataset): 25 | def __init__(self, path, split="train", size=256, interpolation="pil_lanczos", check=False, raw_image=False, 26 | flip_prob=0.5, sample_weight=1.0, instruct=False, prompt_path='datasets/image_edit/prompt/prompt_denoise.txt'): 27 | super(SIDD, self).__init__() 28 | 29 | # inp_files = sorted(os.listdir(os.path.join(path, split, 'input'))) 30 | # tar_files = sorted(os.listdir(os.path.join(path, split, 'gt'))) 31 | 32 | # self.inp_filenames = [os.path.join(path, split, 'input', x) for x in inp_files if is_image_file(x)] 33 | # self.tar_filenames = [os.path.join(path, split, 'gt', x) for x in tar_files if is_image_file(x)] 34 | 35 | filenames_splits = sorted(os.listdir(os.path.join(path))) 36 | 37 | self.inp_filenames = [os.path.join(path, d, x) for d in filenames_splits for x in sorted(os.listdir(os.path.join(path, d))) if is_image_file(x) and 'NOISY' in x] 38 | self.tar_filenames = [os.path.join(path, d, x) for d in filenames_splits for x in sorted(os.listdir(os.path.join(path, d))) if is_image_file(x) and 'GT' in x] 39 | 40 | self.size = size 41 | self.flip_prob = flip_prob 42 | self.sample_weight = sample_weight 43 | self.instruct = instruct 44 | self.check = check 45 | self.raw_image = raw_image 46 | self.sizex = len(self.tar_filenames) # get the size of target 47 | 48 | self.interpolation = { 49 | "cv_nearest": cv2.INTER_NEAREST, 50 | "cv_bilinear": cv2.INTER_LINEAR, 51 | "cv_bicubic": cv2.INTER_CUBIC, 52 | "cv_area": cv2.INTER_AREA, 53 | "cv_lanczos": cv2.INTER_LANCZOS4, 54 | "pil_nearest": Image.NEAREST, 55 | "pil_bilinear": Image.BILINEAR, 56 | "pil_bicubic": Image.BICUBIC, 57 | "pil_box": Image.BOX, 58 | "pil_hamming": Image.HAMMING, 59 | "pil_lanczos": Image.LANCZOS, 60 | }[interpolation] 61 | 62 | # prompt_path='dataset/prompt/prompt_denoise.txt' 63 | self.prompt_list=[] 64 | with open(prompt_path) as f: 65 | line=f.readline() 66 | while line: 67 | line=line.strip('\n') 68 | self.prompt_list.append(line) 69 | line=f.readline() 70 | print(f"SIDD has {len(self)} samples!!") 71 | 72 | def __len__(self): 73 | return int(self.sizex * self.sample_weight) 74 | 75 | def __getitem__(self, index): 76 | if self.raw_image: 77 | return self.get_raw_image(index) 78 | else: 79 | return self.get_processed_image(index) 80 | 81 | def get_processed_image(self, index): 82 | if self.sample_weight >= 1: 83 | index_ = index % self.sizex 84 | else: 85 | index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1) 86 | 87 | inp_path = self.inp_filenames[index_] 88 | tar_path = self.tar_filenames[index_] 89 | 90 | inp_img = Image.open(inp_path) 91 | tar_img = Image.open(tar_path) 92 | 93 | width, height = inp_img.size 94 | tar_width, tar_height = tar_img.size 95 | assert tar_width == width and tar_height == height, "Input and target image mismatch" 96 | 97 | inp_img = np.array(inp_img).astype(np.float32).transpose(2, 0, 1) 98 | inp_img_tensor = torch.tensor((inp_img / 127.5 - 1.0).astype(np.float32)) 99 | tar_img = np.array(tar_img).astype(np.float32).transpose(2, 0, 1) 100 | tar_img_tensor = torch.tensor((tar_img / 127.5 - 
1.0).astype(np.float32)) 101 | crop = torchvision.transforms.RandomCrop(self.size) 102 | flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob)) 103 | image_0, image_1 = flip(crop(torch.cat((inp_img_tensor, tar_img_tensor)))).chunk(2) 104 | 105 | prompt = random.choice(self.prompt_list) 106 | if self.instruct: 107 | prompt = "Image Denoising: " + prompt 108 | 109 | if self.check: 110 | return dict(edited=Image.open(tar_path), edit=dict(source=Image.open(inp_path), instruction=prompt)) 111 | return dict(edited=image_1, edit=dict(source=image_0, instruction=prompt)) 112 | 113 | def get_raw_image(self, index): 114 | if self.sample_weight >= 1: 115 | index_ = index % self.sizex 116 | else: 117 | index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1) 118 | 119 | inp_path = self.inp_filenames[index_] 120 | tar_path = self.tar_filenames[index_] 121 | 122 | inp_img = Image.open(inp_path).convert('RGB') 123 | tar_img = Image.open(tar_path).convert('RGB') 124 | 125 | width, height = inp_img.size 126 | tar_width, tar_height = tar_img.size 127 | assert tar_width == width and tar_height == height, "Input and target image mismatch" 128 | 129 | prompt = random.choice(self.prompt_list) 130 | if self.instruct: 131 | prompt = "Image Denoising: " + prompt 132 | 133 | sources = [{'from': 'human', 'value': '{}: .'.format(prompt)}, 134 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}] 135 | return {'input_images': [inp_img, tar_img], 'output_images': [tar_img], 'output_cond_images': [inp_img], 'conversation': sources, 'image_label_masks': [0, 1]} -------------------------------------------------------------------------------- /locals/datasets/image_edit/prompt/color_list_train_small.txt: -------------------------------------------------------------------------------- 1 | Red 纯红 #FF0000 255,0,0 2 | 3 | Purple 紫色 #800080 128,0,128 4 | 5 | Blue 纯蓝 #0000FF 0,0,255 6 | 7 | Green 纯绿 #008000 0,128,0 8 | 9 | Yellow 纯黄 #FFFF00 255,255,0 10 | 11 | White 纯白 #FFFFFF 255,255,255 12 | 13 | Black 纯黑 #000000 0,0,0 14 | 15 | Gray 灰色 #808080 128,128,128 16 | 17 | 18 | -------------------------------------------------------------------------------- /locals/datasets/image_edit/prompt/prompt_deblur.txt: -------------------------------------------------------------------------------- 1 | Sharpen this blurry image 2 | Increase the sharpness of this unclear photo 3 | Correct the lack of focus in this misty picture 4 | Heighten the definition of this smeared image 5 | Clear up this fuzzy picture 6 | Refine this indistinct photograph 7 | Improve the focus of this hazy image 8 | Amend the softness of this out-of-focus photograph 9 | Polish the murkiness of this low-definition photo 10 | Rectify the vagueness of this blurred image -------------------------------------------------------------------------------- /locals/datasets/image_edit/prompt/prompt_denoise.txt: -------------------------------------------------------------------------------- 1 | Remove noise from this image 2 | Eliminate the noise in this picture 3 | Purify this photo by removing noise 4 | Clear up the image by filtering out noise 5 | Eradicate the noise from this photograph 6 | Minimize the noise present in this picture 7 | Cancel out the noise within this image 8 | Clean this photo by discarding the noise 9 | Suppress the noise in this visual representation 10 | Rectify the noise interference in this image -------------------------------------------------------------------------------- 
/locals/datasets/image_edit/prompt/prompt_dewatermark.txt: -------------------------------------------------------------------------------- 1 | Remove watermark from this picture 2 | Erase the watermark from this photograph. 3 | Extract the watermark from this image. 4 | Take out the watermark overlay from this photo. 5 | Wipe off the watermark imprint on this image. 6 | Detach the watermark from this visual representation. 7 | Get rid of the watermarking on this picture. 8 | Withdraw the watermark applied to this photograph. 9 | Clean up this image by deleting the watermark. 10 | Unmark this photo by removing the watermark. -------------------------------------------------------------------------------- /locals/datasets/image_edit/prompt/prompt_pose.txt: -------------------------------------------------------------------------------- 1 | Circle the {joint} of the people with the color {color}, 2 | Use the {color} color to draw circles around the {joint} of the people, 3 | Make {color} circles around the {joint} of the people, 4 | Put {color} circles on the {joint} of the people, 5 | Draw {color} circles over the {joint} of the people, 6 | Surround the {joint} of the people with {color} circles, 7 | Use the color {color} to make circles on the {joint} of the people, 8 | Mark the {joint} of the people with {color} circles, 9 | Create {color} circles around the {joint} of the people, 10 | Use the color {color} to encircle the {joint} of the people, -------------------------------------------------------------------------------- /locals/datasets/image_edit/prompt/prompt_seg.txt: -------------------------------------------------------------------------------- 1 | Mark the pixels of {object} in {color} and leave the rest unchanged. 2 | Color the {object}'s pixels in {color}, keeping the remaining pixels unaltered. 3 | Apply {color} to the pixels of {object} while maintaining the current state of other pixels. 4 | Assign {color} to the pixels belonging to {object}, preserving the rest as they are. 5 | For {object}, set its pixels to {color} and let the others remain the same. 6 | Modify the pixels of {object} to {color} without affecting any other pixels. 7 | Set the {object} pixels to {color} and keep the other pixels in their original state. 8 | Update the pixels of {object} to {color}, but leave the other pixels untouched. 9 | Fill in the pixels of {object} with {color}, retaining the existing colors of the remaining pixels. 10 | Change the {object} pixels to {color}, while keeping the other pixels constant. 11 | Paint the pixels of {object} in {color} and maintain the current appearance of the other pixels. 
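The templates in prompt_seg.txt (together with the colour table in color_list_train_small.txt) are consumed by the segmentation datasets below, which fill the {color} and {object} slots with a sampled colour name and a referring expression via str.format. A minimal sketch of that expansion, with made-up sample values and the prompt path assumed to be resolved relative to the repository root:

import random

# hypothetical sampled values, mirroring how GrefCOCODataset/RefCOCODataset build their prompts
templates = [l.strip('\n') for l in open('locals/datasets/image_edit/prompt/prompt_seg.txt') if l.strip()]
color_name = 'Blue'                           # first column of a color_list_train_small.txt row
referring_expression = 'the dog on the left'  # one annotation sentence

prompt = random.choice(templates).format(color=color_name.lower(), object=referring_expression.lower())
print(prompt)  # e.g. "Mark the pixels of the dog on the left in blue and leave the rest unchanged."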
-------------------------------------------------------------------------------- /locals/datasets/image_edit/seg/grefcoco_seg.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InstructDiffusion 3 | # Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix) 4 | # Modified by Binxin Yang (tennyson@mail.ustc.edu.cn) 5 | # -------------------------------------------------------- 6 | 7 | from __future__ import annotations 8 | 9 | import os 10 | import random 11 | import copy 12 | import json 13 | import math 14 | from pathlib import Path 15 | from typing import Any 16 | 17 | import numpy as np 18 | import torch 19 | import torchvision 20 | from einops import rearrange 21 | from PIL import Image 22 | from torch.utils.data import Dataset 23 | 24 | from .grefcoco import G_REFER 25 | from constants import ALL_IMG_TOKENS_STR 26 | 27 | 28 | class GrefCOCODataset(Dataset): 29 | def __init__( 30 | self, 31 | path: str, 32 | split: str = "train", 33 | min_resize_res: int = 256, 34 | max_resize_res: int = 256, 35 | crop_res: int = 256, 36 | flip_prob: float = 0.0, 37 | transparency: float = 0.0, 38 | test: bool = False, 39 | image_path: str = None 40 | ): 41 | assert split in ("train", "val", "test") 42 | self.path = path 43 | self.min_resize_res = min_resize_res 44 | self.max_resize_res = max_resize_res 45 | self.crop_res = crop_res 46 | self.flip_prob = flip_prob 47 | self.G_ref_dataset=G_REFER(data_root=path) 48 | self.IMAGE_DIR = os.path.join(image_path, 'train2014') 49 | self.list_ref=self.G_ref_dataset.getRefIds(split=split) 50 | self.transparency = transparency 51 | self.test = test 52 | 53 | seg_diverse_prompt_path = 'datasets/image_edit/prompt/prompt_seg.txt' 54 | self.seg_diverse_prompt_list=[] 55 | with open(seg_diverse_prompt_path) as f: 56 | line=f.readline() 57 | while line: 58 | line=line.strip('\n') 59 | self.seg_diverse_prompt_list.append(line) 60 | line=f.readline() 61 | 62 | color_list_file_path='datasets/image_edit/prompt/color_list_train_small.txt' 63 | self.color_list=[] 64 | with open(color_list_file_path) as f: 65 | line = f.readline() 66 | while line: 67 | line_split = line.strip('\n').split(" ") 68 | if len(line_split)>1: 69 | temp = [] 70 | for i in range(4): 71 | temp.append(line_split[i]) 72 | self.color_list.append(temp) 73 | line = f.readline() 74 | 75 | def __len__(self) -> int: 76 | return len(self.list_ref) 77 | 78 | def _augmentation_new(self, image, label): 79 | 80 | # Cropping 81 | h, w = label.shape 82 | if h > w: 83 | start_h = random.randint(0, h - w) 84 | end_h = start_h + w 85 | image = image[start_h:end_h] 86 | label = label[start_h:end_h] 87 | elif h < w: 88 | start_w = random.randint(0, w - h) 89 | end_w = start_w + h 90 | image = image[:, start_w:end_w] 91 | label = label[:, start_w:end_w] 92 | else: 93 | pass 94 | image = Image.fromarray(image).resize((self.min_resize_res, self.min_resize_res), resample=Image.Resampling.LANCZOS) 95 | image = np.asarray(image, dtype=np.uint8) 96 | label = Image.fromarray(label).resize((self.min_resize_res, self.min_resize_res), resample=Image.Resampling.NEAREST) 97 | label = np.asarray(label, dtype=np.int64) 98 | return image, label 99 | 100 | def __getitem__(self, i: int) -> dict[str, Any]: 101 | 102 | ref_ids = self.list_ref[i] 103 | ref = self.G_ref_dataset.loadRefs(ref_ids)[0] 104 | sentences = random.choice(ref['sentences'])['sent'] 105 | 106 | prompt = random.choice(self.seg_diverse_prompt_list) 
107 | 108 | color = random.choice(self.color_list) 109 | color_name = color[0] 110 | prompt = prompt.format(color=color_name.lower(), object=sentences.lower()) 111 | 112 | R, G, B = color[3].split(",") 113 | R = int(R) 114 | G = int(G) 115 | B = int(B) 116 | 117 | image_name = self.G_ref_dataset.loadImgs(ref['image_id'])[0]['file_name'] 118 | image_path = os.path.join(self.IMAGE_DIR,image_name) 119 | mask = self.G_ref_dataset.getMaskByRef(ref=ref,merge=True)['mask'] 120 | 121 | image = Image.open(image_path).convert("RGB") 122 | image = np.asarray(image) 123 | 124 | # image, mask = self._augmentation_new(image,mask) 125 | 126 | mask = (mask == 1) 127 | 128 | image_0 = Image.fromarray(image) 129 | image_1 = copy.deepcopy(image) 130 | image_1[:,:,0][mask]=self.transparency*image_1[:,:,0][mask]+(1-self.transparency)*R 131 | image_1[:,:,1][mask]=self.transparency*image_1[:,:,1][mask]+(1-self.transparency)*G 132 | image_1[:,:,2][mask]=self.transparency*image_1[:,:,2][mask]+(1-self.transparency)*B 133 | image_1 = Image.fromarray(image_1) 134 | 135 | reize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item() 136 | image_0 = image_0.resize((reize_res, reize_res), Image.Resampling.LANCZOS) 137 | image_1 = image_1.resize((reize_res, reize_res), Image.Resampling.LANCZOS) 138 | 139 | 140 | sources = [{'from': 'human', 'value': '{}: .'.format(prompt)}, 141 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}] 142 | 143 | return {'input_images': [image_0], 'output_images': [image_1], 144 | 'output_cond_images': [image_0], 'conversation': sources, 'image_label_masks': [0, 1]} -------------------------------------------------------------------------------- /locals/datasets/image_edit/seg/refcoco_seg.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InstructDiffusion 3 | # Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix) 4 | # Modified by Binxin Yang (tennyson@mail.ustc.edu.cn) 5 | # -------------------------------------------------------- 6 | 7 | from __future__ import annotations 8 | 9 | import os 10 | import random 11 | import copy 12 | import json 13 | import math 14 | from pathlib import Path 15 | from typing import Any 16 | 17 | import numpy as np 18 | import torch 19 | import torchvision 20 | from einops import rearrange 21 | from PIL import Image 22 | from torch.utils.data import Dataset 23 | 24 | from .refcoco import REFER 25 | from constants import ALL_IMG_TOKENS_STR 26 | 27 | 28 | class RefCOCODataset(Dataset): 29 | def __init__( 30 | self, 31 | path: str, 32 | split: str = "train", 33 | min_resize_res: int = 256, 34 | max_resize_res: int = 256, 35 | crop_res: int = 256, 36 | flip_prob: float = 0.0, 37 | transparency: float = 0.0, 38 | test: bool = False, 39 | image_path: str = None, 40 | ): 41 | assert split in ("train", "val", "test") 42 | self.path = path 43 | self.min_resize_res = min_resize_res 44 | self.max_resize_res = max_resize_res 45 | self.crop_res = crop_res 46 | self.flip_prob = flip_prob 47 | self.G_ref_dataset=REFER(data_root=path) 48 | self.IMAGE_DIR = os.path.join(image_path, 'train2014') 49 | self.list_ref=self.G_ref_dataset.getRefIds(split=split) 50 | self.transparency = transparency 51 | self.test = test 52 | 53 | seg_diverse_prompt_path = 'datasets/image_edit/prompt/prompt_seg.txt' 54 | self.seg_diverse_prompt_list=[] 55 | with open(seg_diverse_prompt_path) as f: 56 | line=f.readline() 57 | while line: 58 | 
line=line.strip('\n') 59 | self.seg_diverse_prompt_list.append(line) 60 | line=f.readline() 61 | 62 | color_list_file_path='datasets/image_edit/prompt/color_list_train_small.txt' 63 | self.color_list=[] 64 | with open(color_list_file_path) as f: 65 | line = f.readline() 66 | while line: 67 | line_split = line.strip('\n').split(" ") 68 | if len(line_split)>1: 69 | temp = [] 70 | for i in range(4): 71 | temp.append(line_split[i]) 72 | self.color_list.append(temp) 73 | line = f.readline() 74 | 75 | def __len__(self) -> int: 76 | return len(self.list_ref) 77 | 78 | def _augmentation_new(self, image, label): 79 | 80 | # Cropping 81 | h, w = label.shape 82 | if h > w: 83 | start_h = random.randint(0, h - w) 84 | end_h = start_h + w 85 | image = image[start_h:end_h] 86 | label = label[start_h:end_h] 87 | elif h < w: 88 | start_w = random.randint(0, w - h) 89 | end_w = start_w + h 90 | image = image[:, start_w:end_w] 91 | label = label[:, start_w:end_w] 92 | else: 93 | pass 94 | image = Image.fromarray(image).resize((self.min_resize_res, self.min_resize_res), resample=Image.Resampling.LANCZOS) 95 | image = np.asarray(image, dtype=np.uint8) 96 | label = Image.fromarray(label).resize((self.min_resize_res, self.min_resize_res), resample=Image.Resampling.NEAREST) 97 | label = np.asarray(label, dtype=np.int64) 98 | return image, label 99 | 100 | def __getitem__(self, i: int) -> dict[str, Any]: 101 | 102 | ref_ids = self.list_ref[i] 103 | ref = self.G_ref_dataset.loadRefs(ref_ids)[0] 104 | sentences = random.choice(ref['sentences'])['sent'] 105 | 106 | prompt = random.choice(self.seg_diverse_prompt_list) 107 | 108 | color = random.choice(self.color_list) 109 | color_name = color[0] 110 | prompt = prompt.format(color=color_name.lower(), object=sentences.lower()) 111 | 112 | R, G, B = color[3].split(",") 113 | R = int(R) 114 | G = int(G) 115 | B = int(B) 116 | 117 | image_name = self.G_ref_dataset.loadImgs(ref['image_id'])[0]['file_name'] 118 | image_path = os.path.join(self.IMAGE_DIR,image_name) 119 | mask = self.G_ref_dataset.getMask(ref=ref)['mask'] 120 | 121 | image = Image.open(image_path).convert("RGB") 122 | image = np.asarray(image) 123 | 124 | image, mask = self._augmentation_new(image,mask) 125 | 126 | mask = (mask == 1) 127 | 128 | image_0 = Image.fromarray(image) 129 | image_1 = copy.deepcopy(image) 130 | image_1[:,:,0][mask]=self.transparency*image_1[:,:,0][mask]+(1-self.transparency)*R 131 | image_1[:,:,1][mask]=self.transparency*image_1[:,:,1][mask]+(1-self.transparency)*G 132 | image_1[:,:,2][mask]=self.transparency*image_1[:,:,2][mask]+(1-self.transparency)*B 133 | image_1 = Image.fromarray(image_1) 134 | 135 | reize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item() 136 | image_0 = image_0.resize((reize_res, reize_res), Image.Resampling.LANCZOS) 137 | image_1 = image_1.resize((reize_res, reize_res), Image.Resampling.LANCZOS) 138 | 139 | # image_1 = Image.fromarray(image_1) 140 | sources = [{'from': 'human', 'value': '{}: .'.format(prompt)}, 141 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}] 142 | 143 | return {'input_images': [image_0], 'output_images': [image_1], 144 | 'output_cond_images': [image_0], 'conversation': sources, 'image_label_masks': [0, 1]} -------------------------------------------------------------------------------- /locals/datasets/multimodal_tasks/llava_R.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import math 3 | import random 4 | from PIL 
import Image 5 | # use read_parquet to load the parquet file 6 | from pandas import read_parquet 7 | import io 8 | 9 | from constants import ALL_IMG_TOKENS_STR 10 | 11 | class LlavaRInstructDataset(Dataset): 12 | def __init__(self, 13 | path: str, 14 | instruct: bool = False, 15 | min_resize_res: int = 256, 16 | max_resize_res: int = 256, 17 | crop_res: int = 256, 18 | flip_prob: float = 0.5, 19 | sample_weight: float = 1.0, 20 | check: bool = False, 21 | output_mode: str = 'conversation', 22 | shuffle: bool = False, 23 | raw_image: bool = False, 24 | inference: bool = False, 25 | min_size: int = 50, 26 | **kwargs): 27 | self.path = path 28 | self.instruct = instruct 29 | self.inference = inference 30 | self.meta = read_parquet(path) 31 | self.min_resize_res = min_resize_res 32 | self.max_resize_res = max_resize_res 33 | self.crop_res = crop_res 34 | self.check = check 35 | self.raw_image = raw_image 36 | self.output_mode = output_mode 37 | self.shuffle = shuffle 38 | self.min_size = min_size 39 | self.flip_prob = flip_prob 40 | self.sample_weight = sample_weight 41 | print(f"LlavaR Instruct has {len(self)} samples!!") 42 | 43 | def __len__(self): 44 | return int(len(self.meta) * self.sample_weight) 45 | 46 | def get_sampler_index(self,i): 47 | if self.sample_weight >= 1: 48 | i = i % len(self.meta) 49 | else: 50 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight)) 51 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder) 52 | return i 53 | 54 | def __getitem__(self, i): 55 | assert self.output_mode == 'conversation' 56 | ini_i = i 57 | i = self.get_sampler_index(i) 58 | item = self.meta.iloc[i] 59 | image_bytes = item['image']['bytes'] 60 | tgt_img = [Image.open(io.BytesIO(image_bytes))] 61 | new_conversation = [] 62 | assert len(item['user_texts']) == len(item['bot_texts']) 63 | for i in range(len(item['user_texts'])): 64 | if i == 0: 65 | new_conversation.append({'from':"human",'value': ALL_IMG_TOKENS_STR + item['user_texts'][i]}) 66 | else: 67 | new_conversation.append({'from':"human",'value':item['user_texts'][i]}) 68 | new_conversation.append({'from':"gpt",'value':item['bot_texts'][i]}) 69 | 70 | return {'input_images': tgt_img, 'conversation': new_conversation,'id':f'llava_r_{ini_i}','image_label_masks': [0]} 71 | 72 | if __name__ == '__main__': 73 | test_path = '/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/LLaVAR-Instruct-16K/data/train-00000-of-00001-890199abde0ec4ff.parquet' 74 | test = LlavaRInstructDataset(path=test_path) 75 | print(len(test)) 76 | print(test[1000]) 77 | 78 | -------------------------------------------------------------------------------- /locals/datasets/multimodal_tasks/lvis_instruct4v.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import json as js 3 | import math 4 | import random 5 | import os 6 | from PIL import Image 7 | from constants import ALL_IMG_TOKENS_STR 8 | from locals.datasets.multimodal_tasks.single_image_base import SingleImageDataset 9 | from copy import deepcopy 10 | 11 | class LVISinstruct4vDataset(SingleImageDataset): 12 | def __init__(self, 13 | path: str, 14 | image_folder: str, 15 | instruct: bool = False, 16 | min_resize_res: int = 256, 17 | max_resize_res: int = 256, 18 | crop_res: int = 256, 19 | flip_prob: float = 0.5, 20 | sample_weight: float = 1.0, 21 | check: bool = False, 22 | output_mode: str = 'text', 23 | shuffle: bool = False, 24 | raw_image: bool = False, 25 | inference: bool = 
False, 26 | min_size: int = 50, 27 | **kwargs): 28 | super().__init__(path, image_folder, instruct, min_resize_res, max_resize_res, crop_res, flip_prob, sample_weight, check, output_mode, shuffle, raw_image, inference, min_size, **kwargs) 29 | print(f"LVIS-instruct4v has {len(self)} samples!!") 30 | 31 | def __len__(self): 32 | return int(len(self.meta) * self.sample_weight) 33 | 34 | def __getitem__(self, i): 35 | assert self.output_mode == 'conversation' 36 | 37 | i = self.get_sampler_index(i) 38 | 39 | item = self.meta[i] 40 | if 'image' in item: 41 | image_fn = item['image'] 42 | dirname = image_fn.split('/')[0] 43 | image_fn = '/'.join(image_fn.split('/')[1:]) 44 | all_sub_images = [] 45 | if dirname == 'coco': 46 | dirname = 'COCO2017' 47 | else: 48 | raise ValueError 49 | 50 | tgt_img = [Image.open(os.path.join(self.image_folder, dirname, image_fn)).convert('RGB')] 51 | 52 | new_conversation = deepcopy(item['conversations']) 53 | for conv in new_conversation: 54 | if '' in conv['value']: 55 | conv['value'] = conv['value'].replace('', ALL_IMG_TOKENS_STR) 56 | return {'input_images': tgt_img + all_sub_images, 'conversation': new_conversation,'id':'LVIS_'+item['id'],'image_label_masks': [0]} 57 | else: 58 | tgt_img = None 59 | return {'conversation': item['conversations'],'id':'LVIS_'+item['id']} 60 | 61 | if __name__ == '__main__': 62 | test_path = '/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/LVIS-Instruct4V/lvis_instruct4v_220k.json' 63 | image_folder = '/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/' 64 | test = LVISinstruct4vDataset(path=test_path, image_folder=image_folder, output_mode = 'conversation') 65 | print(len(test)) 66 | print(test[1000]) 67 | 68 | 69 | -------------------------------------------------------------------------------- /locals/datasets/multimodal_tasks/m3it.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import json as js 3 | import math 4 | import random 5 | import os 6 | from PIL import Image 7 | from typing import List 8 | from constants import * 9 | import re, copy 10 | from base64 import b64encode, b64decode 11 | import io 12 | from datasets import load_dataset, concatenate_datasets 13 | 14 | def byte2image(byte_data): 15 | """ 16 | convert byte to PIL image 17 | """ 18 | if isinstance(byte_data, str): 19 | byte_data = b64decode(byte_data) 20 | image = Image.open(io.BytesIO(byte_data)) 21 | return image 22 | 23 | class M3ITDataset(Dataset): 24 | 25 | def __init__(self, 26 | path: str, 27 | dataset_names: List[str], 28 | split: str = 'train', 29 | crop_res: int = 256, 30 | flip_prob: float = 0.5, 31 | sample_weight: float = 1.0, 32 | check: bool = False, 33 | output_mode: str = 'text', 34 | shuffle: bool = False, 35 | raw_image: bool = False, 36 | inference: bool = False, 37 | min_size: int = 50, 38 | **kwargs): 39 | # load from json ../datasets/gqa-inpaint/meta_info.json 40 | self.path = path 41 | self.inference = inference 42 | tmp_ds = [] 43 | for name in dataset_names: 44 | print(name) 45 | tmp_ds.append(load_dataset(os.path.join(path, name))[split]) 46 | self.meta = concatenate_datasets(tmp_ds) 47 | self.crop_res = crop_res 48 | self.check = check 49 | self.raw_image = raw_image 50 | self.output_mode = output_mode 51 | self.shuffle = shuffle 52 | self.min_size = min_size 53 | 54 | self.flip_prob = flip_prob 55 | self.sample_weight = sample_weight 56 | print(f"LLaVA Academic has {len(self)} samples!!") 57 | 58 | def __len__(self): 59 | return 
int(len(self.meta) * self.sample_weight) 60 | 61 | def __getitem__(self, i): 62 | assert self.output_mode == 'conversation' 63 | 64 | if self.sample_weight >= 1: 65 | i = i % len(self.meta) 66 | else: 67 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight)) 68 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder) 69 | 70 | item = self.meta[i] 71 | image = item['image_base64_str'] 72 | image = byte2image(image[0]).convert('RGB') 73 | 74 | if len(item['inputs']): 75 | prompt = item['instruction'] + ' ' + item['inputs'] 76 | else: 77 | prompt = item['instruction'] 78 | 79 | prompt = ALL_IMG_TOKENS_STR + ' ' + prompt 80 | conversation = [ 81 | {'from': 'human', 'value': prompt}, 82 | {'from': 'gpt', 'value': item['outputs']} 83 | ] 84 | return {'input_images': [image], 'conversation': conversation, 'image_label_masks': [0]} 85 | 86 | -------------------------------------------------------------------------------- /locals/datasets/multimodal_tasks/sharegpt4v.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from locals.datasets.multimodal_tasks.single_image_base import SingleImageDataset 4 | from copy import deepcopy 5 | from constants import ALL_IMG_TOKENS_STR 6 | 7 | class Sharegpt4vDataset(SingleImageDataset): 8 | def __init__(self, 9 | path: str, 10 | image_folder: str, 11 | llava_pretrain_folder: str, 12 | coco_img_folder: str, 13 | instruct: bool = False, 14 | min_resize_res: int = 256, 15 | max_resize_res: int = 256, 16 | crop_res: int = 256, 17 | flip_prob: float = 0.5, 18 | sample_weight: float = 1.0, 19 | check: bool = False, 20 | output_mode: str = 'text', 21 | shuffle: bool = False, 22 | raw_image: bool = False, 23 | inference: bool = False, 24 | min_size: int = 50, 25 | **kwargs): 26 | super().__init__(path, image_folder, instruct, min_resize_res, max_resize_res, crop_res, flip_prob, sample_weight, check, output_mode, shuffle, raw_image, inference, min_size, **kwargs) 27 | self.llava_pretrain_folder = llava_pretrain_folder 28 | self.coco_img_folder = coco_img_folder 29 | print(f"Sharegpt4vDataset has {len(self)} samples!!") 30 | 31 | def __len__(self): 32 | return int(len(self.meta) * self.sample_weight) 33 | 34 | def __getitem__(self, i): 35 | assert self.output_mode == 'conversation' 36 | 37 | i = self.get_sampler_index(i) 38 | 39 | item = self.meta[i] 40 | if 'image' in item: 41 | image_fn = item['image'] 42 | sub_dir = image_fn.split('/')[0] 43 | 44 | if sub_dir == 'llava': 45 | image_folder = self.llava_pretrain_folder 46 | image_fn = '/'.join(image_fn.split('/')[3:]) 47 | elif sub_dir == 'coco': 48 | image_folder = self.coco_img_folder 49 | image_fn = '/'.join(image_fn.split('/')[1:]) 50 | else: 51 | image_folder = self.image_folder 52 | tgt_img = [Image.open(os.path.join(image_folder, image_fn)).convert('RGB')] 53 | 54 | new_conversation = deepcopy(item['conversations']) 55 | for conv in new_conversation: 56 | if '' in conv['value']: 57 | conv['value'] = conv['value'].replace('', ALL_IMG_TOKENS_STR) 58 | 59 | return {'input_images': tgt_img, 'conversation': new_conversation,'id':'Sharegpt4v_'+item['id'],'image_path':item['image'],'image_label_masks': [0]} 60 | else: 61 | tgt_img = None 62 | return {'conversation': item['conversations'],'id':'Sharegpt4v_'+item['id'],'image_path':item['image']} 63 | 64 | if __name__ == '__main__': 65 | from tqdm import tqdm 66 | test_path = 
'/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/ShareGPT4V/sharegpt4v_instruct_gpt4-vision_cap100k.json' 67 | image_folder = '/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/ShareGPT4V/data/' 68 | coco_image_folder = '/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/COCO2017' 69 | llava_pretrain_folder = '/mnt/bn/luoruipu-disk/meta_data/pretrain_data/LLaVA-Pretrain/' 70 | test = Sharegpt4vDataset(path=test_path, image_folder=image_folder, llava_pretrain_folder=llava_pretrain_folder,coco_img_folder=coco_image_folder, output_mode = 'conversation') 71 | print(len(test)) 72 | for i in tqdm(range(len(test))): 73 | test[i] 74 | -------------------------------------------------------------------------------- /locals/datasets/multimodal_tasks/single_image_base.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import json as js 3 | import math 4 | import random 5 | import os 6 | from PIL import Image 7 | 8 | class SingleImageDataset(Dataset): 9 | 10 | def __init__(self, 11 | path: str, 12 | image_folder: str, 13 | instruct: bool = False, 14 | min_resize_res: int = 256, 15 | max_resize_res: int = 256, 16 | crop_res: int = 256, 17 | flip_prob: float = 0.5, 18 | sample_weight: float = 1.0, 19 | check: bool = False, 20 | output_mode: str = 'text', 21 | shuffle: bool = False, 22 | raw_image: bool = False, 23 | inference: bool = False, 24 | min_size: int = 50, 25 | **kwargs): 26 | # load from json ../datasets/gqa-inpaint/meta_info.json 27 | self.path = path 28 | self.instruct = instruct 29 | self.inference = inference 30 | self.meta = js.load(open(path)) 31 | self.image_folder = image_folder 32 | self.min_resize_res = min_resize_res 33 | self.max_resize_res = max_resize_res 34 | self.crop_res = crop_res 35 | self.check = check 36 | self.raw_image = raw_image 37 | self.output_mode = output_mode 38 | self.shuffle = shuffle 39 | self.min_size = min_size 40 | self.flip_prob = flip_prob 41 | self.sample_weight = sample_weight 42 | 43 | def __len__(self): 44 | return int(len(self.meta) * self.sample_weight) 45 | 46 | def get_sampler_index(self,i): 47 | if self.sample_weight >= 1: 48 | i = i % len(self.meta) 49 | else: 50 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight)) 51 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder) 52 | return i 53 | 54 | def __getitem__(self, i): 55 | raise NotImplementedError -------------------------------------------------------------------------------- /locals/datasets/prompts/prompt_captioning_long.txt: -------------------------------------------------------------------------------- 1 | Write a description of the image, capturing its main components, the relationships between them, and any notable details. 2 | Create a caption that accurately describes the elements in the image provided. 3 | Write a comprehensive, yet brief, description of the image. 4 | Describe the image in a clear and detailed manner. 5 | For the given image, describe the visual content, try to include every components. 6 | Generate a detailed caption for the picture. 7 | Write a detailed and informative description that highlights the primary subjects and actions occurring in the given image. 8 | Write a clear description of the image, make sure the key features are well covered. 9 | Offer a comprehensive explanation of the picture presented. 
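Nearly every dataset in this repository scales its length with sample_weight: __len__ returns int(len(meta) * sample_weight), and get_sampler_index (defined in SingleImageDataset above and repeated in the other loaders) maps an external index back onto the underlying records, wrapping around for weights >= 1 and sampling within a bucket of roughly 1/sample_weight records for weights < 1. A small standalone sketch of that mapping, with made-up sizes:

import math
import random

def map_index(i, n_meta, sample_weight):
    # mirrors get_sampler_index: wrap for over-sampling, bucket-sample for under-sampling
    if sample_weight >= 1:
        return i % n_meta
    remainder = math.ceil(i / sample_weight - int(i / sample_weight))
    return int(i / sample_weight) + random.randint(0, int(1 / sample_weight) - 1 + remainder)

# with 1000 records and sample_weight=2.0, the dataset reports 2000 samples and indices wrap twice
print(map_index(1500, n_meta=1000, sample_weight=2.0))  # -> 500
# with sample_weight=0.5, it reports 500 samples and index 100 draws from record 200 or 201
print(map_index(100, n_meta=1000, sample_weight=0.5))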
-------------------------------------------------------------------------------- /locals/datasets/prompts/prompt_captioning_short.txt: -------------------------------------------------------------------------------- 1 | Describe the image briefly. 2 | Write a succinct description of the image, capturing its main components, the relationships between them, and any notable details. 3 | Create a concise caption that accurately describes the main elements in the image provided. 4 | Write a brief, yet comprehensive, description of the image. 5 | Describe the image in a clear and concise manner. 6 | For the given image, provide a one-sentence summary that captures the most important details. 7 | Generate a short caption for the picture. 8 | Write a short and informative description that highlights the primary subjects and actions occurring in the given image. 9 | Provide a concise and informative caption for the image, focusing on the primary subjects. 10 | Write a clear description of the image, make sure the key features are well covered. 11 | Offer a succinct explanation of the picture presented. -------------------------------------------------------------------------------- /locals/datasets/prompts/prompt_kosmosg.txt: -------------------------------------------------------------------------------- 1 | Help me to generate this image: {desc} 2 | Generate image according to the following information: {desc} 3 | Based on the provided description and corresponding images of objects: {desc}, generate a complete image. 4 | Create an image of {desc} 5 | I will provide you with descriptions of the images and images of the objects. help me reconstruct the entire picture: {desc} -------------------------------------------------------------------------------- /locals/datasets/prompts/prompt_ref_i2t.txt: -------------------------------------------------------------------------------- 1 | Briefly describe the region {region}. 2 | Provide a description for the region {region} to distinguish it with other regions. 3 | Express what is showed in this region: {region}. 4 | What object is in this image segment {region}. 5 | How will you depict this visual region {region} with natural language? 6 | What is in {region}? 7 | Can you describe the region {region} for me? -------------------------------------------------------------------------------- /locals/datasets/prompts/prompt_ref_t2i.txt: -------------------------------------------------------------------------------- 1 | Locate the region in the image describing: {phrase}. 2 | Find the region or regions that correspond to {phrase}. 3 | Can you find {phrase} in the image? 4 | Please detect {phrase} in this picture. 5 | In this image, which object matches the description "{phrase}"? 6 | Which region in the figure best corresponds to the following description: {phrase}? 7 | Which part of the illustration best matches {phrase}? 8 | Please locate the area described as {phrase} in the picture. 9 | Identify the region in the image referred as {phrase}. 10 | Please help me find {phrase} in the picture. 
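The grounding templates above expose a {region} slot (region-to-text) and a {phrase} slot (text-to-region); how a region is serialized depends on box2str in locals/datasets/utils/box_utils.py further down in this listing, which can emit either location tokens or plain-text coordinates. A rough, purely illustrative filling of the two templates (the phrase, box values, and bracket formatting are invented for the example):

# illustrative only: filling one template of each kind
t2i_template = 'Locate the region in the image describing: {phrase}.'
i2t_template = 'Briefly describe the region {region}.'

phrase = 'the red umbrella'         # hypothetical referring expression
box = [0.120, 0.340, 0.560, 0.780]  # hypothetical normalized box, cf. box2str(..., mode='text')
region_str = '[{:.3f}, {:.3f}, {:.3f}, {:.3f}]'.format(*box)

print(t2i_template.format(phrase=phrase))
print(i2t_template.format(region=region_str))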
-------------------------------------------------------------------------------- /locals/datasets/prompts/prompt_txt2img.txt: -------------------------------------------------------------------------------- 1 | generate an image with caption: 2 | can you give me the image with caption: 3 | help me to generate this image: 4 | generate image with according to caption: 5 | according to caption, generate image: 6 | an image with caption: 7 | can you visualize this caption: 8 | Please generate an image corresponding to my description: 9 | -------------------------------------------------------------------------------- /locals/datasets/text/sharegpt.py: -------------------------------------------------------------------------------- 1 | from locals.datasets.text.text_data_base import TextDataset 2 | 3 | class ShareGPTDataset(TextDataset): 4 | def __init__(self, 5 | path: str, 6 | instruct: bool = False, 7 | sample_weight: float = 1, 8 | output_mode: str = 'conversation', 9 | shuffle: bool = False, 10 | inference: bool = False, 11 | **kwargs): 12 | path = [path] 13 | super().__init__(path, instruct, sample_weight, output_mode, shuffle, inference, **kwargs) 14 | print(f"ShareGPTDataset has {len(self)} samples!!") 15 | 16 | def __getitem__(self, i): 17 | assert self.output_mode == 'conversation' 18 | i = self.get_sampler_index(i) 19 | item = self.meta[i] 20 | return {'conversation': item['conversations'],'id':item['id']} 21 | 22 | class ShareGPTCodeDataset(TextDataset): 23 | def __init__(self, 24 | path: str, 25 | instruct: bool = False, 26 | sample_weight: float = 1, 27 | output_mode: str = 'conversation', 28 | shuffle: bool = False, 29 | inference: bool = False, 30 | **kwargs): 31 | path = [path] 32 | super().__init__(path, instruct, sample_weight, output_mode, shuffle, inference, **kwargs) 33 | print(f"ShareGPTCodeDataset has {len(self)} samples!!") 34 | 35 | def __getitem__(self, i): 36 | assert self.output_mode == 'conversation' 37 | i = self.get_sampler_index(i) 38 | item = self.meta[i] 39 | return {'conversation': item['conversations'],'id':'sharegpt_code_'+item['id']} 40 | 41 | 42 | 43 | if __name__ == '__main__': 44 | test_path = '/mnt/bn/yangmin-priv/luoruipu/data/text-dataset/sharegpt_184k_no_repeat.json' 45 | test = ShareGPTDataset(path=test_path) 46 | print(len(test)) 47 | print(test[100000]) -------------------------------------------------------------------------------- /locals/datasets/text/text_data_base.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import json as js 3 | import math 4 | import random 5 | from tqdm import tqdm 6 | 7 | class TextDataset(Dataset): 8 | 9 | def __init__(self, 10 | pathlist: list, 11 | instruct: bool = False, 12 | sample_weight: float = 1.0, 13 | output_mode: str = 'text', 14 | shuffle: bool = False, 15 | inference: bool = False, 16 | **kwargs): 17 | # load from json ../datasets/gqa-inpaint/meta_info.json 18 | self.pathlist = pathlist 19 | self.instruct = instruct 20 | self.inference = inference 21 | self.meta = [] 22 | for path in tqdm(pathlist): 23 | if path.endswith('json'): 24 | self.meta += js.load(open(path)) 25 | elif path.endswith('jsonl'): 26 | with open(path) as f: 27 | for line in f: 28 | try: 29 | self.meta.append(js.loads(line)) 30 | except: 31 | continue 32 | else: 33 | raise ValueError('json or jsonl file type is supported. 
') 34 | self.output_mode = output_mode 35 | self.shuffle = shuffle 36 | self.sample_weight = sample_weight 37 | 38 | def __len__(self): 39 | return int(len(self.meta) * self.sample_weight) 40 | 41 | def get_sampler_index(self,i): 42 | if self.sample_weight >= 1: 43 | i = i % len(self.meta) 44 | else: 45 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight)) 46 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder) 47 | return i 48 | 49 | def __getitem__(self, i): 50 | raise NotImplementedError -------------------------------------------------------------------------------- /locals/datasets/text/txt_cot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import copy 4 | import json 5 | import math 6 | from pathlib import Path 7 | from tkinter import E 8 | from typing import Any 9 | 10 | import numpy as np 11 | import torch 12 | import torchvision 13 | import pandas as pd 14 | from einops import rearrange 15 | from PIL import Image 16 | from torch.utils.data import Dataset 17 | from utils.util import byte2image 18 | from ..utils.box_utils import * 19 | from datasets import load_dataset 20 | 21 | from constants import * 22 | 23 | class CoTCollectionDataset(Dataset): 24 | def __init__(self, 25 | path:str = None, 26 | sample_weight: float=1.0, 27 | ): 28 | self.meta = load_dataset(path)['train'] 29 | self.sample_weight = sample_weight 30 | 31 | def __len__(self): 32 | return int(len(self.meta) * self.sample_weight) 33 | 34 | def __getitem__(self, i): 35 | 36 | if self.sample_weight >= 1: 37 | i = i % len(self.meta) 38 | else: 39 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight)) 40 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder) 41 | 42 | item = self.meta[i] 43 | question = item['source'] + ' Answer the question and include the reasoning process.' 
44 | thought = item['rationale'] 45 | answer = item['target'] 46 | 47 | conv = [{'from': 'human', 'value': question}, 48 | {'from': 'gpt', 'value': thought}, 49 | {'from': 'human', 'value': 'What is your final answer?'}, 50 | {'from': 'gpt', 'value': answer}] 51 | return {'conversation': conv} -------------------------------------------------------------------------------- /locals/datasets/text/ultrachat.py: -------------------------------------------------------------------------------- 1 | from locals.datasets.text.text_data_base import TextDataset 2 | from pathlib import Path 3 | class UltraChatDataset(TextDataset): 4 | def __init__(self, 5 | path: str, 6 | instruct: bool = False, 7 | sample_weight: float = 1, 8 | output_mode: str = 'conversation', 9 | shuffle: bool = False, 10 | inference: bool = False, 11 | **kwargs): 12 | path = [str(p)for p in Path(path).glob('*.jsonl')] 13 | super().__init__(path, instruct, sample_weight, output_mode, shuffle, inference, **kwargs) 14 | self.filter_dataset() 15 | print(f"UltraChat Dataset has {len(self)} samples!!") 16 | 17 | def filter_dataset(self,): 18 | new_data = [] 19 | for item in self.meta: 20 | try: 21 | assert len(item['data'])%2==0 22 | new_data.append(item) 23 | except: 24 | continue 25 | self.meta = new_data 26 | 27 | def __getitem__(self, i): 28 | assert self.output_mode == 'conversation' 29 | i = self.get_sampler_index(i) 30 | item = self.meta[i] 31 | conversations = [] 32 | for i in range(0,len(item['data']),2): 33 | conversations.append({'from':'human','value':item['data'][i]}) 34 | conversations.append({'from':'gpt','value':item['data'][i+1]}) 35 | return {'conversation': conversations,'id':'UltraChat'+str(item['id'])} 36 | 37 | class UltraChatJsonDataset(TextDataset): 38 | def __init__(self, 39 | path: str, 40 | instruct: bool = False, 41 | sample_weight: float = 1, 42 | output_mode: str = 'conversation', 43 | shuffle: bool = False, 44 | inference: bool = False, 45 | **kwargs): 46 | path = [path] 47 | super().__init__(path, instruct, sample_weight, output_mode, shuffle, inference, **kwargs) 48 | self.filter_dataset() 49 | print(f"UltraChat Dataset has {len(self)} samples!!") 50 | 51 | 52 | def filter_dataset(self,): 53 | new_data = [] 54 | length_list = [] 55 | for item in self.meta: 56 | try: 57 | assert len(item['conversations'])%2==0 and len(item['conversations'])!=0 58 | new_data.append(item) 59 | except: 60 | continue 61 | finally: 62 | cur_len = sum(len(conv['value'].split()) for conv in item['conversations']) 63 | cur_len = cur_len if 'image' in item else -cur_len 64 | length_list.append(cur_len) 65 | self.meta = new_data 66 | self.length = length_list 67 | 68 | def __getitem__(self, i): 69 | assert self.output_mode == 'conversation' 70 | item = self.meta[i] 71 | return {'conversation': item['conversations'],'id':'UltraChat'+str(item['id'])} 72 | 73 | 74 | if __name__ == '__main__': 75 | ultrachat_path = '/mnt/bn/yangmin-priv/luoruipu/data/text-dataset/ultrachat/' 76 | test = UltraChatDataset(path=ultrachat_path) 77 | print(len(test)) 78 | print(test[0]) -------------------------------------------------------------------------------- /locals/datasets/text2image/kosmosg.py: -------------------------------------------------------------------------------- 1 | from email.policy import default 2 | from torch.utils.data import Dataset 3 | import json as js 4 | import math 5 | import random 6 | import os 7 | from PIL import Image 8 | from collections import defaultdict 9 | from constants import * 10 | 11 | def load_prompt(fn): 12 
| prompts = [] 13 | with open(fn) as f: 14 | line=f.readline() 15 | while line: 16 | line=line.strip('\n') 17 | prompts.append(line) 18 | line=f.readline() 19 | return prompts 20 | 21 | class KosMosGDataset(Dataset): 22 | def __init__(self, 23 | path: str, 24 | image_folder: str, 25 | instruct: bool = False, 26 | min_resize_res: int = 256, 27 | max_resize_res: int = 256, 28 | crop_res: int = 256, 29 | flip_prob: float = 0.5, 30 | sample_weight: float = 1.0, 31 | check: bool = False, 32 | output_mode: str = 'conversation', 33 | raw_image: bool = False, 34 | object_type: str = 'mask', 35 | **kwargs): 36 | # load from json ../datasets/gqa-inpaint/meta_info.json 37 | self.path = path 38 | self.instruct = instruct 39 | self.meta = js.load(open(path)) 40 | self.image_folder = image_folder 41 | self.min_resize_res = min_resize_res 42 | self.max_resize_res = max_resize_res 43 | self.crop_res = crop_res 44 | self.check = check 45 | self.raw_image = raw_image 46 | self.output_mode = output_mode 47 | self.flip_prob = flip_prob 48 | self.sample_weight = sample_weight 49 | self.object_type = object_type 50 | self.prompts = load_prompt('locals/datasets/prompts/prompt_kosmosg.txt') 51 | print(f"KosMos-G has {len(self)} samples!!") 52 | 53 | def __len__(self): 54 | return int(len(self.meta) * self.sample_weight) 55 | 56 | def get_image_fn(self, image_id): 57 | return os.path.join(self.image_folder, image_id) 58 | 59 | def __getitem__(self, i): 60 | 61 | if self.sample_weight >= 1: 62 | i = i % len(self.meta) 63 | else: 64 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight)) 65 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder) 66 | 67 | item = self.meta[i] 68 | caption = item['source'].replace('', ' '+ALL_IMG_TOKENS_STR) 69 | tgt_img = Image.open(item['target']).convert('RGB') 70 | prompt = random.choice(self.prompts) 71 | 72 | # find the caption 73 | if self.object_type == 'random': 74 | prob = random.random() 75 | if prob > 0.5: 76 | object_type = 'mask' 77 | else: 78 | object_type = 'obj' 79 | else: 80 | object_type = self.object_type 81 | if object_type == 'mask': 82 | object_images = [Image.open(img).convert('RGB') for img in item['masks']] 83 | elif object_type == 'obj': 84 | object_images = [Image.open(img).convert('RGB') for img in item['objects']] 85 | else: 86 | raise ValueError 87 | 88 | if self.output_mode == 'conversation': 89 | sources = [{'from': 'human', 'value': prompt.format(desc=caption)}, 90 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}] 91 | return {'input_images': object_images+[tgt_img], 'conversation': sources, 'image_label_masks': [0]*len(object_images)+[1]} 92 | elif self.output_mode == 'text': 93 | text = caption + ' ' + '{}.'.format(ALL_IMG_TOKENS_STR) 94 | return {'input_images': object_images+[tgt_img], 'text': text, 'image_label_masks': [0]*len(object_images)+[1]} -------------------------------------------------------------------------------- /locals/datasets/text2image/midjourney.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from locals.datasets.multimodal_tasks.single_image_base import SingleImageDataset 4 | from constants import ALL_IMG_TOKENS_STR 5 | import random 6 | class MidjourneyDataset(SingleImageDataset): 7 | def __init__(self, 8 | path: str, 9 | image_folder: str, 10 | instruct_prompt_path: str = None, 11 | instruct: bool = False, 12 | min_resize_res: int = 256, 13 | max_resize_res: int = 256, 
14 | crop_res: int = 256, 15 | flip_prob: float = 0.5, 16 | sample_weight: float = 1.0, 17 | check: bool = False, 18 | shuffle: bool = False, 19 | raw_image: bool = False, 20 | inference: bool = False, 21 | min_size: int = 50, 22 | **kwargs): 23 | output_mode = 'conversation' 24 | super().__init__(path, image_folder, instruct, min_resize_res, max_resize_res, crop_res, flip_prob, sample_weight, check, output_mode, shuffle, raw_image, inference, min_size, **kwargs) 25 | self.prompt_list = open(instruct_prompt_path).readlines() 26 | print(f"MidjourneyDataset has {len(self)} samples!!") 27 | 28 | def __len__(self): 29 | return int(len(self.meta) * self.sample_weight) 30 | 31 | def __getitem__(self, i): 32 | i = self.get_sampler_index(i) 33 | 34 | item = self.meta[i] 35 | 36 | image_fn = item['id'] 37 | 38 | tgt_img = [Image.open(os.path.join(self.image_folder, image_fn+'.png')).convert('RGB')] 39 | instruct = random.choice(self.prompt_list) 40 | new_conversation = [{'from': 'human', 'value': instruct + ','.join(item['event']['textPrompt'])}, 41 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}] 42 | 43 | return {'output_images': tgt_img, 'conversations': new_conversation,'id':'midjourney_'+item['id'],'image_label_masks': [1]} 44 | 45 | 46 | if __name__ == '__main__': 47 | from tqdm import tqdm 48 | test_path = '/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/midjourney/midjourney_sample_50k_new.json' 49 | image_folder = '/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/midjourney/images' 50 | instruct_prompt_path = '/mnt/bn/yangmin-priv/luoruipu/code/Edit-GPT4/locals/datasets/prompts/prompt_txt2img.txt' 51 | test = MidjourneyDataset(path=test_path, image_folder=image_folder,instruct_prompt_path=instruct_prompt_path) 52 | print(test[0]) 53 | # import json as js 54 | # data = js.load(open(test_path)) 55 | # new_data = [] 56 | # for item in tqdm(data): 57 | # image_fn = item['id'] 58 | # try: 59 | # img = Image.open(os.path.join(image_folder, image_fn+'.png')) 60 | # new_data.append(item) 61 | # except: 62 | # continue 63 | # print(len(new_data)) 64 | # js.dump(new_data, open('/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/midjourney/midjourney_filtered.json','w')) -------------------------------------------------------------------------------- /locals/datasets/utils/box_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import WindowsPath 2 | from constants import * 3 | from utils.eval_util import extract_all_box_str 4 | 5 | def process_thought(thought, mistral=False): 6 | new_thought = thought.replace(' ', '') 7 | # if ' ' in new_thought: 8 | 9 | new_thought = new_thought.replace(' ', '' + ALL_IMG_TOKENS_STR).replace('', '') 10 | all_box = extract_all_box_str(thought, mistral) 11 | return new_thought, all_box 12 | 13 | 14 | def box2str(box, mode='special_tokens', prec=2, space=False): 15 | if mode == 'special_tokens': 16 | # using tokens to represent the locations 17 | num_blocks = len(ALL_LOC_TOKENS) 18 | size_per_block = 1 / num_blocks 19 | block_no = [int(c / size_per_block) if c < 1.0 else len(ALL_LOC_TOKENS)-1 for c in box] 20 | return ''.join([ALL_LOC_TOKENS[i] for i in block_no]) 21 | elif mode == 'text': 22 | # using text to represent the box 23 | if space: 24 | sep = ', ' 25 | else: 26 | sep = ',' 27 | tmp_format = sep.join(['{' + ':.{}f'.format(prec)+'}']*4) 28 | a_box = [float(o) for o in box] 29 | return tmp_format.format(*a_box) 30 | else: 31 | raise NotImplementedError 32 | 33 | def 
allbox2str(objects): 34 | s = [] 35 | for obj in objects: 36 | s.append('{}: [{}]'.format(obj['class'], box2str(obj['bbox'], 'text', 3, True))) 37 | return ', '.join(s) 38 | 39 | def reshape_box(image, box): 40 | width, height = image.size 41 | abs_box = [c*width if i%2==0 else c*height for i,c in enumerate(box)] 42 | if width == height: 43 | return box 44 | elif width > height: 45 | abs_box[1] += (width - height) // 2 46 | abs_box[3] += (width - height) // 2 47 | max_size = width 48 | else: 49 | abs_box[0] += (height - width) // 2 50 | abs_box[2] += (height - width) // 2 51 | max_size = height 52 | norm_box = [c/max_size for c in abs_box] 53 | return norm_box 54 | 55 | def reshape_box_reverse(image, box): 56 | width, height = image.size 57 | max_side = max(width, height) 58 | abs_box = [c*max_side if i%2==0 else c*max_side for i,c in enumerate(box)] 59 | if width == height: 60 | return box 61 | elif width > height: 62 | abs_box[1] -= (width - height) // 2 63 | abs_box[3] -= (width - height) // 2 64 | max_size = width 65 | else: 66 | abs_box[0] -= (height - width) // 2 67 | abs_box[2] -= (height - width) // 2 68 | max_size = height 69 | norm_box = [c/width if i%2 ==0 else c/height for i,c in enumerate(abs_box)] 70 | return norm_box 71 | 72 | def resize_image_to_square(image): 73 | width, height = image.size 74 | max_side = max(width, height) 75 | image = image.resize((max_side, max_side)) 76 | return image 77 | 78 | 79 | def expand2square_fn(pil_img, background_color): 80 | from PIL import Image 81 | width, height = pil_img.size 82 | if width == height: 83 | return pil_img 84 | elif width > height: 85 | result = Image.new(pil_img.mode, (width, width), background_color) 86 | result.paste(pil_img, (0, (width - height) // 2)) 87 | return result 88 | else: 89 | result = Image.new(pil_img.mode, (height, height), background_color) 90 | result.paste(pil_img, ((height - width) // 2, 0)) 91 | return result -------------------------------------------------------------------------------- /locals/datasets/utils/zip_manager.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | import os.path as osp 3 | # import lmdb 4 | import logging 5 | from PIL import Image 6 | import pickle 7 | import io 8 | import glob 9 | import os 10 | from pathlib import Path 11 | import time 12 | from threading import Thread 13 | from PIL import ImageFile 14 | ImageFile.LOAD_TRUNCATED_IMAGES = True 15 | 16 | home = str(Path.home()) 17 | abs_blob_path=os.path.realpath("/mnt/blob/") 18 | CACHE_FOLDER=os.path.join(home,"caching") 19 | USE_CACHE=True 20 | 21 | def norm(path): 22 | assert "*" not in path 23 | return os.path.realpath(os.path.abspath(path)) 24 | 25 | def in_blob(file): 26 | if abs_blob_path in file: 27 | return True 28 | else: 29 | return False 30 | 31 | def map_name(file): 32 | path=norm(file) 33 | path=path.lstrip(abs_blob_path+"/") 34 | path=path.replace("/","_") 35 | assert len(path)<250 36 | return path 37 | 38 | 39 | def preload(db,sync=False): 40 | if sync: 41 | db.initialize() 42 | else: 43 | p = Thread(target=db.initialize) 44 | p.start() 45 | 46 | def get_keys_from_lmdb(db): 47 | with db.begin(write=False) as txn: 48 | return list(txn.cursor().iternext(values=False)) 49 | 50 | def decode_img(byteflow): 51 | try: 52 | img=Image.open(io.BytesIO(byteflow)).convert("RGB") 53 | img.load() 54 | except: 55 | img = Image.open("white.jpeg").convert("RGB") 56 | img.load() 57 | return img 58 | 59 | def decode_text(byteflow): 60 | return pickle.loads(byteflow) 61 | 
62 | decode_funcs={ 63 | "image": decode_img, 64 | "text": decode_text 65 | } 66 | 67 | 68 | class ZipManager: 69 | def __init__(self, zip_path,data_type,prefix=None) -> None: 70 | self.decode_func=decode_funcs[data_type] 71 | self.zip_path=zip_path 72 | self._init=False 73 | preload(self) 74 | 75 | def deinitialze(self): 76 | self.zip_fd.close() 77 | del self.zip_fd 78 | self._init = False 79 | 80 | def initialize(self,close=True): 81 | self.zip_fd = zipfile.ZipFile(self.zip_path, mode="r") 82 | if not hasattr(self,"_keys"): 83 | self._keys = self.zip_fd.namelist() 84 | self._init = True 85 | if close: 86 | self.deinitialze() 87 | 88 | @property 89 | def keys(self): 90 | while not hasattr(self,"_keys"): 91 | time.sleep(0.1) 92 | return self._keys 93 | 94 | def get(self, name): 95 | if not self._init: 96 | self.initialize(close=False) 97 | byteflow = self.zip_fd.read(name) 98 | return self.decode_func(byteflow) 99 | 100 | 101 | class MultipleZipManager: 102 | def __init__(self, files: list, data_type, sync=True): 103 | self.files = files 104 | self._is_init = False 105 | self.data_type=data_type 106 | if sync: 107 | print("sync",files) 108 | self.initialize() 109 | else: 110 | print("async",files) 111 | preload(self) 112 | print("initialize over") 113 | 114 | 115 | def initialize(self): 116 | self.mapping={} 117 | self.managers={} 118 | for file in self.files: 119 | manager = ZipManager(file, self.data_type) 120 | self.managers[file]=manager 121 | 122 | for file,manager in self.managers.items(): 123 | print(file) 124 | # print("loading") 125 | logging.info(f"{file} loading") 126 | keys=manager.keys 127 | for key in keys: 128 | self.mapping[key]=file 129 | logging.info(f"{file} loaded, size = {len(keys)}") 130 | print("loaded") 131 | 132 | self._keys=list(self.mapping.keys()) 133 | self._is_init=True 134 | 135 | @property 136 | def keys(self): 137 | while not self._is_init: 138 | time.sleep(0.1) 139 | return self._keys 140 | 141 | def get(self, name): 142 | data = self.managers[self.mapping[name]].get(name) 143 | return data -------------------------------------------------------------------------------- /model/front_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | from .Qformer import BertConfig, BertLMHeadModel 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"front_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | def init_Qformer(num_query_token, vision_width, cross_attention_freq=2): 33 | encoder_config = BertConfig.from_pretrained("bert-base-uncased") 34 | encoder_config.encoder_width = vision_width 35 | # insert cross-attention layer every other block 36 | encoder_config.add_cross_attention = True 37 | encoder_config.cross_attention_freq = cross_attention_freq 38 | encoder_config.query_length = num_query_token 39 | Qformer = BertLMHeadModel(config=encoder_config) 40 | query_tokens = nn.Parameter( 41 | torch.zeros(1, num_query_token, encoder_config.hidden_size) 42 | ) 43 | 
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) 44 | return Qformer, query_tokens 45 | 46 | def build_front_projector(config, delay_load=False, **kwargs): 47 | projector_type = getattr(config, 'front_projector_type', 'linear') 48 | 49 | if projector_type == 'linear': 50 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 51 | elif projector_type == 'q_former': 52 | Qformer, query_tokens = init_Qformer( 53 | kwargs['num_query_token'], kwargs['visual_encoder'].num_features 54 | ) 55 | Qformer.cls = None 56 | Qformer.bert.embeddings.word_embeddings = None 57 | Qformer.bert.embeddings.position_embeddings = None 58 | for layer in Qformer.bert.encoder.layer: 59 | layer.output = None 60 | layer.intermediate = None 61 | checkpoint = torch.load(config.front_projector, map_location="cpu") 62 | state_dict = checkpoint["model"] 63 | new_state_dict = {} 64 | for key in state_dict: 65 | if 'Qformer' in key: 66 | new_state_dict[key.split('Qformer.')[1]] = state_dict[key] 67 | Qformer.load_state_dict(new_state_dict) 68 | return Qformer, query_tokens 69 | 70 | else: 71 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 72 | if mlp_gelu_match: 73 | mlp_depth = int(mlp_gelu_match.group(1)) 74 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 75 | for _ in range(1, mlp_depth): 76 | modules.append(nn.GELU()) 77 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 78 | return nn.Sequential(*modules) 79 | 80 | if projector_type == 'identity': 81 | return IdentityMap() 82 | 83 | raise ValueError(f'Unknown projector type: {projector_type}') 84 | -------------------------------------------------------------------------------- /model/load_model.py: -------------------------------------------------------------------------------- 1 | from genericpath import samestat 2 | from torch.utils.data import ConcatDataset, DataLoader 3 | from typing import Optional, Dict 4 | from dataclasses import dataclass, field 5 | from locals.datasets import SFT_DataCollator, WrappedDataset 6 | from lightning.pytorch import seed_everything 7 | from torchvision import transforms 8 | from constants import * 9 | from PIL import Image 10 | from torch.utils.data.distributed import DistributedSampler 11 | from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler 12 | 13 | from locals.datasets.preprocessor import VoCoT_InputProcessor 14 | from omegaconf import OmegaConf 15 | from utils.util import instantiate_from_config 16 | from model.language_model.volcano_llama import VolCanoLlamaForCausalLM,VolCanoConfig 17 | from model.language_model.volcano_mistral import VolCanoMistralForCausalLM, VolCanoMistralConfig 18 | from transformers import LlamaTokenizer, AutoTokenizer 19 | import transformers 20 | from peft import PeftConfig, PeftModel 21 | from argparse import ArgumentParser 22 | import os 23 | import torch.distributed as dist 24 | from utils.logger import setup_logger 25 | import json 26 | import tqdm 27 | 28 | def rank0_print(args, res): 29 | if args.local_rank==0 or args.local_rank == -1: 30 | print(res) 31 | 32 | def get_output_name(args, mid_output=True): 33 | if mid_output: 34 | return os.path.join(args.output_dir, 35 | '{}_rank{}.json'.format(args.dataset_name, args.local_rank)) 36 | else: 37 | return os.path.join(args.output_dir, 38 | '{}.json'.format(args.dataset_name)) 39 | 40 | def get_all_output_names(args): 41 | return [os.path.join(args.output_dir, 42 | '{}_rank{}.json'.format(args.dataset_name, r)) for r in 
range(args.n_gpus)] 43 | 44 | class CLIPTransform: 45 | def __init__(self, transform, square_size=None): 46 | self.transform = transform 47 | self.square_size = square_size 48 | self.image_mean = transform.image_mean 49 | 50 | def __call__(self, image): 51 | if self.square_size is not None: 52 | image = image.resize((self.square_size, self.square_size)) 53 | try: 54 | tmp = torch.tensor(self.transform(image)['pixel_values'][0]) 55 | except: 56 | tmp = torch.tensor(self.transform(Image.new(image.mode, (32, 32), (0,0,0)))['pixel_values'][0]) 57 | return tmp 58 | 59 | 60 | 61 | def load_model(model_path, device='cuda:0', precision='bf16'): 62 | config_class = VolCanoMistralConfig 63 | model_class = VolCanoMistralForCausalLM 64 | tokenizer_class = AutoTokenizer 65 | device = torch.device(device) 66 | tokenizer = AutoTokenizer.from_pretrained( 67 | model_path, 68 | cache_dir=None, 69 | use_fast=True, 70 | trust_remote_code=True 71 | ) 72 | 73 | llama_config = config_class.from_pretrained(model_path) 74 | model = model_class.from_pretrained(model_path, config=llama_config) 75 | 76 | model.input_img_id = tokenizer.convert_tokens_to_ids(DEFAULT_IMG_TOKEN) 77 | model.eoc_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_EOC_TOKEN) 78 | model.boc_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_BOC_TOKEN) 79 | model.tokenizer = tokenizer 80 | model.sub_image_bind = False 81 | 82 | if precision == 'bf16': 83 | model.to(torch.bfloat16) 84 | elif precision == 'fp16': 85 | model.to(torch.float16) 86 | elif precision == 'fp32': 87 | pass 88 | else: 89 | raise ValueError('precision must be fp16, bf16, or fp32') 90 | model.eval() 91 | model.to(device) 92 | 93 | resize2square = False 94 | output_vis_processor = transforms.Compose( 95 | [ 96 | transforms.Resize(1024, interpolation=transforms.InterpolationMode.BILINEAR), 97 | transforms.CenterCrop(1024), 98 | # transforms.RandomHorizontalFlip(), # comment here 99 | transforms.ToTensor(), 100 | transforms.Normalize([0.5], [0.5]), 101 | ] 102 | ) 103 | input_vis_processor = transforms.Compose( 104 | [ 105 | transforms.Resize((448, 448) if resize2square else 448, interpolation=transforms.InterpolationMode.BILINEAR), 106 | transforms.CenterCrop(448), 107 | # transforms.RandomHorizontalFlip(), comment here 108 | transforms.ToTensor(), 109 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 110 | ] 111 | ) 112 | if hasattr(model.vision_encoder, 'image_processor'): 113 | input_vis_processor = model.vision_encoder.image_processor 114 | if resize2square: 115 | tmp_size = input_vis_processor.size['shortest_edge'] 116 | else: 117 | tmp_size = None 118 | input_vis_processor = CLIPTransform(input_vis_processor, square_size=tmp_size) 119 | # tokenizer = LlamaTokenizer.from_pretrained('eval/debug/edit_gpt_emu_tokenizer') 120 | 121 | model.image_processor = None 122 | preprocessor = VoCoT_InputProcessor(tokenizer=tokenizer, input_image_processor = input_vis_processor, use_mistral=True, 123 | output_image_processor= output_vis_processor, merge_in_out_image=True, expand2square=True, inference = True) 124 | 125 | return model, preprocessor 126 | 127 | def infer(model, preprocessor, image, query, cot=True, max_new_tokens=1024, temperature=0.0): 128 | if cot: 129 | query = ALL_IMG_TOKENS_STR + DEFAULT_GRD_TOKEN + '\n' + query + COT_ACTIVATION 130 | else: 131 | query = ALL_IMG_TOKENS_STR + '\n' + query 132 | conv = [{'from': 'human', 'value':query}] 133 | item = {'input_images': [image], 'conversation': conv} 134 | input_item = 
preprocessor(item) 135 | data_collator = SFT_DataCollator(tokenizer=preprocessor.tokenizer, sd_tokenizer=None) 136 | batch = data_collator([input_item]) 137 | txt_res, out_imgs, txt_ids = model.condition_completion(batch, avoid_image_gen=True, 138 | max_new_tokens=max_new_tokens, temperature=temperature) 139 | 140 | return txt_res 141 | 142 | 143 | if __name__=='__main__': 144 | from PIL import Image 145 | tmp_image = Image.open('eval/debug/tmp.jpg') 146 | model_path = '/mnt/bn/yangmin-priv/luoruipu/checkpoints/LLaVA-clip336px-obj-represent-Mistral-1e-5-3072-instruct_llava+shikraCoT75per+GPTQTA+lvis-cot/' 147 | model, preprocessor = load_model(model_path,precision='fp16') 148 | res1 = infer(model, preprocessor, tmp_image, 'Is there a event "the cat is below the bed" in this image?', cot=True) 149 | res = infer(model, preprocessor, tmp_image, 'Why is the cat on the bed?', cot=True) 150 | res_no_cot = infer(model, preprocessor, tmp_image, 'Describe the image.', cot=True) 151 | print(res1) 152 | print(res) 153 | print(res_no_cot) -------------------------------------------------------------------------------- /model/vision_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | from .eva_vit import create_eva_vit_g 4 | from .eva_vit_emu import create_eva_vit_emu 5 | 6 | def build_vision_encoder(vision_encoder_cfg, **kwargs): 7 | vision_tower = getattr(vision_encoder_cfg, 'vision_encoder', None) 8 | is_absolute_path_exists = os.path.exists(vision_tower) 9 | 10 | if vision_tower.startswith("openai") or (is_absolute_path_exists and (vision_tower.startswith("OFA-Sys") or vision_tower.startswith("laion"))): 11 | return CLIPVisionTower(vision_tower, args=vision_encoder_cfg, **kwargs) 12 | elif is_absolute_path_exists and 'eva' in vision_tower: 13 | return create_eva_vit_g(vision_encoder_cfg.vision_encoder, **kwargs) 14 | elif vision_tower == 'eva_vit_emu': 15 | # using the emu-pre-trained vit 16 | vision_encoder_path = getattr(vision_encoder_cfg, 'vision_encoder_path', None) 17 | load_encoder_ckpt = not getattr(vision_encoder_cfg, 'skip_load_vision_encoder', False) 18 | assert vision_encoder_path is not None, 'please specify the model path for emu-pre-trained vision encoder' 19 | return create_eva_vit_emu(vision_encoder_path, load_ckpt=load_encoder_ckpt) 20 | raise ValueError(f'Unknown vision tower: {vision_tower}') 21 | -------------------------------------------------------------------------------- /model/vision_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | 6 | class CLIPVisionTower(nn.Module): 7 | def __init__(self, vision_tower, args, delay_load=False): 8 | super().__init__() 9 | 10 | self.is_loaded = False 11 | 12 | self.vision_tower_name = vision_tower 13 | self.select_layer = args.mm_vision_select_layer 14 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 15 | self.language = getattr(args, 'language', 'english') 16 | if self.language == 'chinese': 17 | from transformers import ChineseCLIPConfig as CLIPVisionConfig 18 | else: 19 | from transformers import CLIPVisionConfig 20 | if not delay_load: 21 | self.load_model() 22 | else: 23 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 24 | self.num_features = self.embed_dim = self.vision_tower.config.hidden_size 25 | 26 | def load_model(self): 27 | if self.language == 
'chinese': 28 | from transformers import ChineseCLIPVisionModel as CLIPVisionModel 29 | from transformers import ChineseCLIPImageProcessor as CLIPImageProcessor 30 | else: 31 | from transformers import CLIPVisionModel, CLIPImageProcessor 32 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 33 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name) 34 | self.vision_tower.requires_grad_(False) 35 | self.image_size = self.vision_tower.config.image_size 36 | 37 | self.is_loaded = True 38 | 39 | def feature_select(self, image_forward_outs): 40 | image_features = image_forward_outs.hidden_states[self.select_layer] 41 | if self.select_feature == 'patch': 42 | image_features = image_features[:, 1:] 43 | elif self.select_feature == 'cls_patch': 44 | image_features = image_features 45 | else: 46 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 47 | return image_features 48 | 49 | @torch.no_grad() 50 | def forward(self, images): 51 | if type(images) is list: 52 | image_features = [] 53 | for image in images: 54 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 55 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 56 | image_features.append(image_feature) 57 | else: 58 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 59 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 60 | 61 | return image_features 62 | 63 | @property 64 | def dummy_feature(self): 65 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 66 | 67 | @property 68 | def dtype(self): 69 | return self.vision_tower.dtype 70 | 71 | @property 72 | def device(self): 73 | return self.vision_tower.device 74 | 75 | @property 76 | def config(self): 77 | if self.is_loaded: 78 | return self.vision_tower.config 79 | else: 80 | return self.cfg_only 81 | 82 | @property 83 | def hidden_size(self): 84 | return self.config.hidden_size 85 | 86 | @property 87 | def num_patches(self): 88 | return (self.config.image_size // self.config.patch_size) ** 2 89 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.8.4 2 | aiosignal==1.3.1 3 | async-timeout==4.0.2 4 | attrs==22.2.0 5 | chardet==5.1.0 6 | contourpy==1.0.7 7 | cycler==0.11.0 8 | filelock==3.9.0 9 | fonttools==4.38.0 10 | frozenlist==1.3.3 11 | huggingface-hub 12 | importlib-resources==5.12.0 13 | kiwisolver==1.4.4 14 | matplotlib==3.7.0 15 | multidict==6.0.4 16 | packaging==23.0 17 | psutil==5.9.4 18 | pycocotools==2.0.6 19 | pyparsing==3.0.9 20 | python-dateutil==2.8.2 21 | pyyaml==6.0 22 | regex==2022.10.31 23 | tqdm==4.64.1 24 | timm==0.9.2 25 | spacy==3.5.1 26 | webdataset==0.2.48 27 | scikit-learn 28 | scipy 29 | yarl==1.8.2 30 | zipp==3.14.0 31 | omegaconf==2.3.0 32 | iopath==0.1.10 33 | decord==0.6.0 34 | tenacity==8.2.2 35 | pycocoevalcap 36 | sentence-transformers 37 | umap-learn 38 | notebook 39 | gradio==3.24.1 40 | gradio-client==0.0.8 41 | torch>=2.0 42 | torchvision>=0.15.0 43 | transformers==4.37.2 44 | pyiqa 45 | xformers 46 | accelerate>=0.20.3 47 | peft==0.6.2 48 | tokenizers 49 | lightning>=2.0.2 50 | open_clip_torch 51 | rouge 52 | pytorch-fid 53 | torch-fidelity 54 | torchmetrics 55 | diffusers==0.14.0 56 | opencv-python-headless 57 | prettytable 
58 | deepspeed>=0.9.3 59 | datasets 60 | # flash-attn --no-build-isolation -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RupertLuo/VoCoT/d8177ca2e8e303f7d35609a000e5eff7bec319cb/utils/__init__.py -------------------------------------------------------------------------------- /utils/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RupertLuo/VoCoT/d8177ca2e8e303f7d35609a000e5eff7bec319cb/utils/__init__.pyc -------------------------------------------------------------------------------- /utils/count_line.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--data', type=str, default=None) 6 | 7 | args = parser.parse_args() 8 | res = json.load(open(args.data)) 9 | print(len(res)) -------------------------------------------------------------------------------- /utils/format_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json, os 3 | from PIL import Image 4 | from utils.eval_util import cal_nn_iou, draw_all_box, colors, draw_all_box_colored 5 | 6 | def box2str(box, mode='special_tokens', prec=2, space=False): 7 | if mode == 'text': 8 | # using text to represent the box 9 | if space: 10 | sep = ', ' 11 | else: 12 | sep = ',' 13 | tmp_format = sep.join(['{' + ':.{}f'.format(prec)+'}']*4) 14 | a_box = [float(o) for o in box] 15 | return tmp_format.format(*a_box) 16 | else: 17 | raise NotImplementedError 18 | 19 | def allbox2str(objects, colored=False): 20 | s = [] 21 | for i,obj in enumerate(objects): 22 | if colored: 23 | s.append('{}: {} box [{}]'.format(obj['class'], colors[i], box2str(obj['bbox'], 'text', 3, True))) 24 | else: 25 | s.append('{}: [{}]'.format(obj['class'], box2str(obj['bbox'], 'text', 3, True))) 26 | return ', '.join(s) 27 | 28 | def filter_box(box_list, iou = 0.5): 29 | ''' 30 | box_list = [] 31 | ''' 32 | box_tensor = [] 33 | for i in range(len(box_list)): 34 | box_tensor.append(list(map(float,box_list[i]['bbox']))) 35 | iou_matrix = cal_nn_iou(box_tensor) 36 | result_box = [] 37 | for i in range(len(iou_matrix)): 38 | flag = 0 39 | for j in result_box: 40 | if iou_matrix[i,j]> iou: 41 | flag = 1 42 | break 43 | if flag==0: 44 | result_box.append(i) 45 | result_box = [box_list[i] for i in result_box] 46 | return result_box 47 | 48 | import random 49 | from collections import defaultdict 50 | def merge_object(info): 51 | class2box = defaultdict(list) 52 | for obj in info: 53 | class2box[obj['class']].append([float(t) for t in obj['bbox']]) 54 | return class2box 55 | 56 | class OpenImagesSource: 57 | def __init__(self, path='/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/OpenImages/balance_sample_merge_filter_box.json', image_base='/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/OpenImages/', draw_box=False): 58 | self.meta = json.load(open(path)) 59 | self.image_base = image_base 60 | self.draw_box = draw_box 61 | 62 | def __len__(self,): 63 | return len(self.meta) 64 | 65 | def __getitem__(self, index): 66 | item = self.meta[index] 67 | split_name = 'train_{}'.format(item['image'][0]) 68 | image = os.path.join(self.image_base, split_name, '{}.jpg'.format(item['image'])) 69 | if self.draw_box: 70 | image = 
draw_all_box_colored(Image.open(image), item['object']) 71 | return {'id': item['image'], 'image': image, 'box_info':allbox2str(item['object'], colored=self.draw_box)} 72 | 73 | 74 | class LvisSource: 75 | def __init__(self, path='/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/LVIS/lvis_v1_train_for_generation.json', image_base='/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/LVIS/train2017', draw_box=False): 76 | self.meta = json.load(open(path)) 77 | self.image_base = image_base 78 | self.draw_box = draw_box 79 | 80 | def __len__(self,): 81 | return len(self.meta) 82 | 83 | def __getitem__(self, index): 84 | item = self.meta[index] 85 | # split_name = 'train_{}'.format(item['image'][0]) 86 | # image = os.path.join(self.image_base, split_name, '{}.jpg'.format(item['image'])) 87 | image = os.path.join(self.image_base, '{:012d}.jpg'.format(int(item['image']))) 88 | if self.draw_box: 89 | image = draw_all_box_colored(Image.open(image), item['object']) 90 | return {'id': item['image'], 'image': image, 'box_info':allbox2str(item['object'], colored=self.draw_box)} 91 | 92 | 93 | if __name__ == "__main__": 94 | d = OpenImagesSource() 95 | print(d[3]) 96 | 97 | -------------------------------------------------------------------------------- /utils/llava_flash_attn.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 25 | if output_attentions: 26 | warnings.warn( 27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
28 | ) 29 | 30 | bsz, q_len, _ = hidden_states.size() 31 | 32 | query_states = ( 33 | self.q_proj(hidden_states) 34 | .view(bsz, q_len, self.num_heads, self.head_dim) 35 | .transpose(1, 2) 36 | ) 37 | key_states = ( 38 | self.k_proj(hidden_states) 39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 40 | .transpose(1, 2) 41 | ) 42 | value_states = ( 43 | self.v_proj(hidden_states) 44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 45 | .transpose(1, 2) 46 | ) # shape: (b, num_heads, s, head_dim) 47 | 48 | kv_seq_len = key_states.shape[-2] 49 | if past_key_value is not None: 50 | kv_seq_len += past_key_value[0].shape[-2] 51 | 52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 53 | query_states, key_states = apply_rotary_pos_emb( 54 | query_states, key_states, cos, sin, position_ids 55 | ) 56 | 57 | if past_key_value is not None: 58 | # reuse k, v 59 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 60 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 61 | 62 | past_key_value = (key_states, value_states) if use_cache else None 63 | 64 | # repeat k/v heads if n_kv_heads < n_heads 65 | key_states = repeat_kv(key_states, self.num_key_value_groups) 66 | value_states = repeat_kv(value_states, self.num_key_value_groups) 67 | 68 | # Transform the data into the format required by flash attention 69 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 71 | key_padding_mask = attention_mask 72 | 73 | if key_padding_mask is None: 74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 75 | cu_q_lens = torch.arange( 76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 77 | ) 78 | max_s = q_len 79 | output = flash_attn_unpadded_qkvpacked_func( 80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 81 | ) 82 | output = output.view(bsz, q_len, -1) 83 | else: 84 | qkv = qkv.reshape(bsz, q_len, -1) 85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 87 | output_unpad = flash_attn_unpadded_qkvpacked_func( 88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 89 | ) 90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 91 | output = pad_input(output_unpad, indices, bsz, q_len) 92 | 93 | return self.o_proj(output), None, past_key_value 94 | 95 | 96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 97 | # requires the attention mask to be the same as the key_padding_mask 98 | def _prepare_decoder_attention_mask( 99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 100 | ): 101 | # [bsz, seq_len] 102 | return attention_mask 103 | 104 | 105 | def replace_llama_attn_with_flash_attn(): 106 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 107 | if cuda_major < 8: 108 | warnings.warn( 109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 111 | ) 112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 113 | _prepare_decoder_attention_mask 114 | ) 115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from logging import StreamHandler, Handler, getLevelName 3 | import os 4 | import sys 5 | 6 | 7 | # this class is a copy of logging.FileHandler except we end self.close() 8 | # at the end of each emit. While closing file and reopening file after each 9 | # write is not efficient, it allows us to see partial logs when writing to 10 | # fused Azure blobs, which is very convenient 11 | class FileHandler(StreamHandler): 12 | """ 13 | A handler class which writes formatted logging records to disk files. 14 | """ 15 | def __init__(self, filename, mode='a', encoding=None, delay=False): 16 | """ 17 | Open the specified file and use it as the stream for logging. 18 | """ 19 | # Issue #27493: add support for Path objects to be passed in 20 | filename = os.fspath(filename) 21 | #keep the absolute path, otherwise derived classes which use this 22 | #may come a cropper when the current directory changes 23 | self.baseFilename = os.path.abspath(filename) 24 | self.mode = mode 25 | self.encoding = encoding 26 | self.delay = delay 27 | if delay: 28 | #We don't open the stream, but we still need to call the 29 | #Handler constructor to set level, formatter, lock etc. 30 | Handler.__init__(self) 31 | self.stream = None 32 | else: 33 | StreamHandler.__init__(self, self._open()) 34 | 35 | def close(self): 36 | """ 37 | Closes the stream. 38 | """ 39 | self.acquire() 40 | try: 41 | try: 42 | if self.stream: 43 | try: 44 | self.flush() 45 | finally: 46 | stream = self.stream 47 | self.stream = None 48 | if hasattr(stream, "close"): 49 | stream.close() 50 | finally: 51 | # Issue #19523: call unconditionally to 52 | # prevent a handler leak when delay is set 53 | StreamHandler.close(self) 54 | finally: 55 | self.release() 56 | 57 | def _open(self): 58 | """ 59 | Open the current base file with the (original) mode and encoding. 60 | Return the resulting stream. 61 | """ 62 | return open(self.baseFilename, self.mode, encoding=self.encoding) 63 | 64 | def emit(self, record): 65 | """ 66 | Emit a record. 67 | 68 | If the stream was not opened because 'delay' was specified in the 69 | constructor, open it before calling the superclass's emit. 
70 | """ 71 | if self.stream is None: 72 | self.stream = self._open() 73 | StreamHandler.emit(self, record) 74 | self.close() 75 | 76 | def __repr__(self): 77 | level = getLevelName(self.level) 78 | return '<%s %s (%s)>' % (self.__class__.__name__, self.baseFilename, level) 79 | 80 | 81 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt"): 82 | logger = logging.getLogger(name) 83 | logger.setLevel(logging.DEBUG) 84 | # don't log results for the non-master process 85 | if distributed_rank > 0: 86 | return logger 87 | ch = logging.StreamHandler(stream=sys.stdout) 88 | ch.setLevel(logging.DEBUG) 89 | # logging.disable(logging.WARNING) 90 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 91 | ch.setFormatter(formatter) 92 | logger.addHandler(ch) 93 | 94 | if save_dir: 95 | fh = FileHandler(os.path.join(save_dir, filename)) 96 | fh.setLevel(logging.DEBUG) 97 | fh.setFormatter(formatter) 98 | logger.addHandler(fh) 99 | 100 | return logger -------------------------------------------------------------------------------- /utils/logger.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RupertLuo/VoCoT/d8177ca2e8e303f7d35609a000e5eff7bec319cb/utils/logger.pyc -------------------------------------------------------------------------------- /utils/time_check.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--tgt_time', type=str, default=None) 6 | args = parser.parse_args() 7 | tgt_time = time.strptime(args.tgt_time, '%Y-%m-%d %X') 8 | while True: 9 | current_time = time.localtime() 10 | if current_time >= tgt_time: 11 | print('its high noon!') 12 | break 13 | print('not the right time, currently {}'.format(current_time)) 14 | time.sleep(30) -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import re 6 | import textwrap 7 | import importlib 8 | from prettytable import PrettyTable 9 | import torch.distributed as dist 10 | import transformers 11 | import torch 12 | from safetensors import safe_open 13 | from PIL import Image 14 | from base64 import b64encode, b64decode 15 | import io 16 | 17 | def disable_torch_init(): 18 | """ 19 | Disable the redundant torch default initialization to accelerate model creation. 
20 | """ 21 | import torch 22 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 23 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 24 | 25 | def load_safetensor(path): 26 | tmp_dict = {} 27 | with safe_open(path, framework='pt', device=0) as f: 28 | for k in f.keys(): 29 | tmp_dict[k] = f.get_tensor(k) 30 | return tmp_dict 31 | 32 | 33 | def sanitize_filename(filename): 34 | return re.sub('[^0-9a-zA-Z]+', '_', filename) 35 | 36 | def plot_images_and_text(predicted_image1, predicted_image2, groundtruth_image, generated_text, gt_text, save_dir, task_name, input_texts, input_images): 37 | task_path = os.path.join(save_dir, task_name) 38 | if not os.path.exists(task_path): 39 | os.makedirs(task_path) 40 | max_width = 50 # adjust this value based on your needs 41 | 42 | fig, ax = plt.subplots() 43 | ax.imshow(predicted_image1) 44 | generated_text = generated_text.replace("###", "").replace("[IMG0]", "") 45 | wrapped_generated_text = textwrap.fill(generated_text, max_width) 46 | ax.set_title(wrapped_generated_text, pad=20) 47 | ax.axis('off') 48 | plt.savefig(os.path.join(task_path, f"generated.jpg"), bbox_inches='tight') 49 | plt.close(fig) 50 | 51 | gt_text = gt_text.replace("$", "\$") 52 | wrapped_gt = textwrap.fill(gt_text, max_width) 53 | if predicted_image2 is not None: 54 | fig, ax = plt.subplots() 55 | ax.imshow(predicted_image2) 56 | ax.set_title(wrapped_gt, pad=20) 57 | ax.axis('off') 58 | plt.savefig(os.path.join(task_path, f"sd_baseline.jpg"), bbox_inches='tight') 59 | plt.close(fig) 60 | 61 | if groundtruth_image is not None: 62 | fig, ax = plt.subplots() 63 | groundtruth_image = groundtruth_image.float().cpu().numpy().squeeze() 64 | groundtruth_image = np.transpose(groundtruth_image, (1, 2, 0)) 65 | groundtruth_image = np.uint8(groundtruth_image*255) 66 | ax.imshow(groundtruth_image) 67 | ax.set_title(wrapped_gt, pad=20) 68 | ax.axis('off') 69 | plt.savefig(os.path.join(task_path, f"gt.jpg"), bbox_inches='tight') 70 | plt.close(fig) 71 | 72 | if len(input_texts): 73 | max_width = 30 74 | length = len(input_texts) 75 | if length > 1: 76 | fig, ax = plt.subplots(1, length, figsize=(10*length, 10)) 77 | for i in range(length): 78 | if i < len(input_images): 79 | ax[i].imshow(input_images[i]) 80 | ax[i].set_title(textwrap.fill(input_texts[i], max_width), fontsize=28) 81 | ax[i].axis('off') 82 | else: 83 | ax[i].text(0.5, 0.5, textwrap.fill(input_texts[i], max_width), horizontalalignment='center', verticalalignment='center', fontsize=28) 84 | ax[i].axis('off') 85 | else: 86 | fig, ax = plt.subplots() 87 | ax.imshow(input_images[0]) 88 | ax.set_title(textwrap.fill(input_texts[0], max_width), fontsize=28) 89 | ax.axis('off') 90 | plt.savefig(os.path.join(task_path, f"input.jpg"), bbox_inches='tight') 91 | plt.close(fig) 92 | 93 | return None 94 | 95 | def instantiate_from_config(config, inference = False, reload=False): 96 | if not "target" in config: 97 | if config == '__is_first_stage__': 98 | return None 99 | elif config == "__is_unconditional__": 100 | return None 101 | raise KeyError("Expected key `target` to instantiate.") 102 | return get_obj_from_str(config["target"], reload=reload)(**config.get("params", dict())) 103 | 104 | 105 | def get_obj_from_str(string, reload=False): 106 | module, cls = string.rsplit(".", 1) 107 | if reload: 108 | module_imp = importlib.import_module(module) 109 | importlib.reload(module_imp) 110 | return getattr(importlib.import_module(module, package=None), cls) 111 | 112 | def print_trainable_params(model): 113 | if 
dist.get_rank() == 0: 114 | trainable_params = [k for k,v in model.named_parameters() if v.requires_grad] 115 | trainable_params_group = {} 116 | for para in trainable_params: 117 | layer_num = re.findall(r'layers.(\d+)\.',para) 118 | if layer_num: 119 | cur_layer = int(layer_num[0]) 120 | if para.replace('layers.'+layer_num[0],'layers.*') not in trainable_params_group: 121 | trainable_params_group[para.replace('layers.'+layer_num[0],'layers.*')] = layer_num[0] 122 | elif cur_layer > int(trainable_params_group[para.replace('layers.'+layer_num[0],'layers.*')]): 123 | trainable_params_group[para.replace('layers.'+layer_num[0],'layers.*')] = layer_num[0] 124 | 125 | else: 126 | trainable_params_group[para] = '0' 127 | table = PrettyTable(['Parameter Name','Max Layer Number']) 128 | for key in trainable_params_group.keys(): 129 | table.add_row([key, str(int(trainable_params_group[key])+1)]) 130 | 131 | print(table) 132 | total_num = sum([v.numel() for k,v in model.named_parameters()]) 133 | trainable_num = sum([v.numel() for k,v in model.named_parameters() if v.requires_grad]) 134 | print('Total: {:.2f}M'.format(total_num/1e6), ' Trainable: {:.2f}M'.format(trainable_num/1e6)) 135 | 136 | def rank_0_print(output): 137 | if dist.get_rank() == 0: 138 | print(output) 139 | 140 | def safe_save_model_for_hf_trainer(trainer, 141 | output_dir): 142 | """Collects the state dict and dump to disk.""" 143 | 144 | if trainer.args.lora: 145 | if trainer.args.should_save: 146 | trainer.model.save_pretrained(output_dir) 147 | 148 | else: 149 | if trainer.deepspeed: 150 | print('saving deepspeed model...') 151 | torch.cuda.synchronize() 152 | trainer.save_model(output_dir) 153 | return 154 | 155 | state_dict = trainer.model.state_dict() 156 | if trainer.args.should_save: 157 | cpu_state_dict = { 158 | key: value.cpu() 159 | for key, value in state_dict.items() 160 | } 161 | del state_dict 162 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa 163 | 164 | def byte2image(byte_data): 165 | """ 166 | convert byte to PIL image 167 | """ 168 | if isinstance(byte_data, str): 169 | byte_data = b64decode(byte_data) 170 | image = Image.open(io.BytesIO(byte_data)) 171 | return image 172 | -------------------------------------------------------------------------------- /vocot_trainer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | import math 5 | import numpy as np 6 | from transformers import ( 7 | TrainerCallback, 8 | TrainingArguments, 9 | ) 10 | from torch import nn 11 | import datasets 12 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union 13 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, DistributedSampler 14 | from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model 15 | from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES 16 | from transformers import Trainer, is_datasets_available, PreTrainedModel 17 | from transformers.deepspeed import is_deepspeed_zero3_enabled 18 | from transformers.trainer_callback import TrainerControl, TrainerState 19 | from transformers.training_args import TrainingArguments 20 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 21 | import torch.distributed as dist 22 | from transformers.trainer_utils import EvalPrediction,seed_worker 23 | import torch 24 | import re 25 | from transformers.data.data_collator import 
DataCollator 26 | 27 | 28 | class VoCoTTrainer(Trainer): 29 | def __init__(self, 30 | model: Union[PreTrainedModel, nn.Module] = None, 31 | args: TrainingArguments = None, 32 | data_collator: Optional[DataCollator] = None, 33 | train_dataset: Optional[Dataset] = None, 34 | eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, 35 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 36 | model_init: Optional[Callable[[], PreTrainedModel]] = None, 37 | compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, 38 | callbacks: Optional[List[TrainerCallback]] = None, 39 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), 40 | preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, 41 | regression_text_loss_metrics: Optional[Tuple[float]]=(0.0, 0.0)): 42 | self.regression_text_loss_metrics = regression_text_loss_metrics 43 | 44 | super().__init__(model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics,) 45 | 46 | def log(self, logs: Dict[str, float], eval = False) -> None: 47 | """ 48 | Log `logs` on the various objects watching training. 49 | 50 | Subclass and override this method to inject custom behavior. 51 | 52 | Args: 53 | logs (`Dict[str, float]`): 54 | The values to log. 55 | """ 56 | if not eval: 57 | if self.state.epoch is not None: 58 | logs["epoch"] = round(self.state.epoch, 2) 59 | 60 | # logs['lora_lr'] = self.optimizer.param_groups[0]['lr'] 61 | # logs['other_lr'] = self.optimizer.param_groups[1]['lr'] 62 | txt_loss, reg_loss = self.regression_text_loss_metrics 63 | if txt_loss is not None: 64 | logs['text_loss'] = txt_loss 65 | if reg_loss is not None: 66 | logs['regression_loss'] = reg_loss 67 | output = {**logs, **{"step": self.state.global_step}} 68 | self.state.log_history.append(output) 69 | self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) 70 | else: 71 | if self.state.epoch is not None: 72 | logs["epoch"] = round(self.state.epoch, 2) 73 | self.state.log_history.append({**logs, **{"step": self.state.global_step}}) 74 | self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) 75 | 76 | # def create_optimizer(self,): 77 | # if self.args.lora and self.args.tune_mm_mlp_adapter: 78 | # from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS 79 | # from transformers.trainer_pt_utils import get_parameter_names 80 | # decay_parameters = get_parameter_names(self.model, ALL_LAYERNORM_LAYERS) 81 | # decay_parameters = [name for name in decay_parameters if "bias" not in name] 82 | # optimizer_grouped_parameters = [ 83 | # { 84 | # "params": [ 85 | # p for n, p in self.model.named_parameters() if ('lora' in n and p.requires_grad) 86 | # ], 87 | # "weight_decay": self.args.weight_decay, 88 | # 'lr': float(self.args.lora_lr) if self.args.lora_lr else self.args.learning_rate 89 | # }, 90 | # { 91 | # "params": [ 92 | # p for n, p in self.model.named_parameters() if (n in decay_parameters and 'lora' not in n and p.requires_grad) 93 | # ], 94 | # "weight_decay": self.args.weight_decay, 95 | # }, 96 | # { 97 | # "params": [ 98 | # p for n, p in self.model.named_parameters() if (n not in decay_parameters and p.requires_grad) 99 | # ], 100 | # "weight_decay": 0.0, 101 | # }, 102 | # ] 103 | # optimizer_cls, optimizer_kwargs = ValleyTrainer.get_optimizer_cls_and_kwargs(self.args) 104 | # self.optimizer = optimizer_cls(optimizer_grouped_parameters, 
**optimizer_kwargs) 105 | # else: 106 | # self.optimizer = super().create_optimizer() 107 | # return self.optimizer 108 | def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: 109 | if 'gt_label' in inputs: 110 | gt_label = inputs.pop('gt_label') 111 | return super()._prepare_inputs(inputs) 112 | 113 | def compute_loss(self, model, inputs, return_outputs=False): 114 | # return super().compute_loss(model, inputs, return_outputs=False) 115 | loss, outputs = super().compute_loss(model, inputs, return_outputs=True) 116 | # record the regression and text loss 117 | text_loss = outputs['text_loss'] 118 | if 'regression_loss' in outputs: 119 | regression_loss = outputs['regression_loss'] 120 | else: 121 | regression_loss = None 122 | if self.args.n_gpu > 1: 123 | text_loss = text_loss.mean() 124 | if regression_loss is not None: 125 | regression_loss = regression_loss.mean() 126 | text_loss = text_loss.item() 127 | if regression_loss is not None: 128 | regression_loss = regression_loss.item() 129 | else: 130 | regression_loss = self.regression_text_loss_metrics[1] 131 | self.regression_text_loss_metrics = (text_loss, regression_loss) 132 | return loss 133 | --------------------------------------------------------------------------------
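
The `ZipManager` and `MultipleZipManager` helpers in `locals/datasets/utils/zip_manager.py` above ship without usage documentation, so here is a minimal, hypothetical sketch of how they can read samples out of zipped image archives. The archive and member names below are placeholders, and `decode_img`/`preload` are helpers defined earlier in that module (not shown in the listing), assumed to decode the raw bytes into an image.

```python
# Hypothetical usage sketch for the zip-backed dataset utilities.
import zipfile
from PIL import Image
from locals.datasets.utils.zip_manager import MultipleZipManager

# Build a tiny archive with one dummy JPEG so the sketch can run end to end.
Image.new('RGB', (64, 64), (255, 0, 0)).save('dummy.jpg')
with zipfile.ZipFile('part0.zip', 'w') as zf:
    zf.write('dummy.jpg')

manager = MultipleZipManager(['part0.zip'], data_type='image', sync=True)
print(manager.keys)                # member names across all managed archives, e.g. ['dummy.jpg']
sample = manager.get('dummy.jpg')  # bytes read from the owning archive, passed through decode_img
```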